about summary refs log tree commit diff
path: root/main.py
diff options
context:
space:
mode:
authorUltraQbik <no1skill@yandex.ru>2024-08-21 04:58:18 +0300
committerUltraQbik <no1skill@yandex.ru>2024-08-21 04:58:18 +0300
commit12ceade05839f57c578c318cf476766a30869799 (patch)
treea8e3850a02ec0f66312b0d57fa6e2c00a98ab123 /main.py
parenta68496d2111c8971a77e52329ee3f22a9ddc4cea (diff)
downloadhttpy-12ceade05839f57c578c318cf476766a30869799.tar.gz
httpy-12ceade05839f57c578c318cf476766a30869799.zip
A bit more asynchronous code?
I have refactored some code and got rid of a few libraries

Disabled logging as it's not very important at the time
Diffstat (limited to 'main.py')
-rw-r--r--main.py68
1 files changed, 33 insertions, 35 deletions
diff --git a/main.py b/main.py
index 7270ddc..5aeae40 100644
--- a/main.py
+++ b/main.py
@@ -4,12 +4,10 @@ The mighty silly webserver written in python for no good reason
 
 
 import ssl
-import time
 import gzip
 import socket
 import asyncio
 import aiofiles
-import threading
 from src.request import Request
 
 
@@ -20,6 +18,7 @@ PATH_MAP = {
     "/robots.txt":              {"path": "www/robots.txt"},
     "/favicon.ico":             {"path": "www/favicon.ico"},
     "/css/styles.css":          {"path": "css/styles.css"},
+    "/about.html":              {"path": "www/about.html"},
 }
 
 # internal path map
@@ -77,50 +76,36 @@ class HTTPServer:
         self.socket.bind(('', self.bind_port))
         self.socket.listen()
 
-        # start the listening thread
-        threading.Thread(target=self._listen_thread, daemon=True).start()
+        # start listening
+        asyncio.run(self._listen_thread())
 
-        # keep alive
-        try:
-            while True:
-                # sleep 100 ms, otherwise the while true will 100% one of your cores
-                time.sleep(0.1)
-
-        # shutdown on keyboard interrupt
-        except KeyboardInterrupt:
-            self.socket.close()
-            print("Closed.")
-
-    def _listen_thread(self):
+    async def _listen_thread(self):
         """
         Listening for new connections
         """
 
-        # run the coroutine
-        asyncio.run(self._thread_listen_coro())
-
-    async def _thread_listen_coro(self):
+        loop = asyncio.get_event_loop()
         while True:
             # accept new connection, add to client list and start listening to it
-            client, _ = self.socket.accept()
+            client = (await self._accept(self.socket))[0]
             self.clients.append(client)
-            await self.client_handle(client)
+            await loop.create_task(self.client_handle(client))
 
-    async def client_handle(self, client: socket.socket):
+    async def client_handle(self, client: ssl.SSLSocket):
         """
         Handles client's connection
         """
 
         while True:
             # receive request from client
-            raw_request = self._recvall(client)
+            raw_request = await self._recvall(client)
 
             # decode request
             request: Request = Request.create(raw_request)
 
-            # log request
-            async with aiofiles.open("logs.log", "a") as f:
-                await f.write(f"IP: {client.getpeername()[0]}\n{request}\n\n")
+            # # log request
+            # async with aiofiles.open("logs.log", "a") as f:
+            #     await f.write(f"IP: {client.getpeername()[0]}\n{request}\n\n")
 
             # handle requests
             try:
@@ -142,7 +127,7 @@ class HTTPServer:
         self._close_client(client)
 
     @staticmethod
-    async def handle_get_request(client: socket.socket, request: Request):
+    async def handle_get_request(client: ssl.SSLSocket, request: Request):
         """
         Handles user's GET request
         :param client: client
@@ -164,7 +149,7 @@ class HTTPServer:
                 headers["Content-Encoding"] = "gzip"
 
             # send 200 response with the file to the client
-            HTTPServer._send(client, 200, data, headers)
+            await HTTPServer._send(client, 200, data, headers)
         else:
             # in case of error, return error page
             async with aiofiles.open(I_PATH_MAP["/err/response.html"]["path"], "r") as f:
@@ -177,10 +162,10 @@ class HTTPServer:
             data = data.format(status_code=get_response_code(status_code).decode("ascii"))
 
             # send 404 response to the client
-            HTTPServer._send(client, status_code, data.encode("ascii"))
+            await HTTPServer._send(client, status_code, data.encode("ascii"))
 
     @staticmethod
-    def _send(client: socket.socket, response: int, data: bytes = None, headers: dict[str, str] = None):
+    async def _send(client: ssl.SSLSocket, response: int, data: bytes = None, headers: dict[str, str] = None):
         """
         Sends client response code + headers + data
         :param client: client
@@ -208,11 +193,12 @@ class HTTPServer:
             byte_header += f"{key}: {value}\r\n".encode("ascii")
 
         # send response to the client
-        client.sendall(
+        await HTTPServer._sendall(
+            client,
             b'HTTP/1.1 ' +
             get_response_code(response) +
             b'\r\n' +
-            byte_header +       # if empty, we'll just get b'\r\n\r\n'
+            byte_header +  # if empty, we'll just get b'\r\n\r\n'
             b'\r\n' +
             data
         )
@@ -226,7 +212,7 @@ class HTTPServer:
         if client in self.clients:
             self.clients.remove(client)
 
-    def _recvall(self, client: socket.socket) -> bytes:
+    async def _recvall(self, client: ssl.SSLSocket) -> bytes:
         """
         Receive All (just receives the whole message, instead of 1 packet at a time)
         """
@@ -238,7 +224,7 @@ class HTTPServer:
         while True:
             try:
                 # fetch packet
-                message = client.recv(self.packet_size)
+                message = await self._recv(client, self.packet_size)
             except OSError:
                 break
 
@@ -257,6 +243,18 @@ class HTTPServer:
         # return empty buffer on error
         return b''
 
+    @staticmethod
+    async def _accept(sock: ssl.SSLSocket) -> tuple[ssl.SSLSocket, str]:
+        return sock.accept()
+
+    @staticmethod
+    async def _recv(sock: ssl.SSLSocket, buflen: int = 1024):
+        return sock.recv(buflen)
+
+    @staticmethod
+    async def _sendall(sock: ssl.SSLSocket, data: bytes):
+        sock.sendall(data)
+
 
 def main():
     server = HTTPServer(port=13700)