about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--main.py287
1 files changed, 96 insertions, 191 deletions
diff --git a/main.py b/main.py
index 4078d99..7889ef8 100644
--- a/main.py
+++ b/main.py
@@ -13,250 +13,155 @@ import aiofiles
 import threading
 
 
-# some constants
-PACKET_SIZE = 2048
-PORT = 13700            # using random port cuz why not
-
-
-# response status codes
-RESPONSE = {
-    200: b'OK',
-    400: b'Bad Request',
-    401: b'Unauthorized',
-    403: b'Forbidden',
-    404: b'Not Found',
-    6969: b'UwU'
-}
-
-
-def get_response(code: int) -> bytes:
-    return str(code).encode("ascii") + RESPONSE.get(code, b':(')
-
-
-def is_alive(sock: socket.socket) -> bool:
+class Request:
     """
-    Checks if the socket is still alive
-    :param sock: socket
-    :return: boolean (true if socket is alive, false otherwise)
+    Just a request
     """
-    return getattr(sock, '_closed', False)
 
+    def __init__(self):
+        self.type: str = ""
+        self.path: str = ""
+
+    @staticmethod
+    def create(raw_request: bytes):
+        """
+        Creates self class from raw request
+        :param raw_request: bytes
+        :return: self
+        """
+
+        # new request
+        request = Request()
 
-def decode_request(req: str) -> dict[str, str | list | None]:
-    # request dictionary
-    request = dict()
+        # fix type and path
+        request.type = raw_request[:raw_request.find(b' ')].decode("ascii")
+        request.path = raw_request[len(request.type)+1:raw_request.find(b' ', len(request.type)+1)].decode("ascii")
 
-    # request type and path
-    request["type"] = req[:6].split(" ")[0]
-    request["path"] = req[len(request["type"]) + 1:req.find("\r\n")].split(" ")[0]
+        # decode headers
+        for raw_header in raw_request.split(b'\r\n'):
+            if len(pair := raw_header.decode("ascii").split(":")) == 2:
+                key, val = pair
+                val = val.strip()
 
-    # decode other headers
-    for line in req.split("\r\n")[1:]:
-        if len(split := line.split(":")) == 2:
-            key, value = split
-            value = value.lstrip(" ")
+                # set attribute to key value pair
+                setattr(request, key, val)
 
-            # write key value pair
-            request[key] = value
+        # return request
+        return request
 
-    return request
+    def __str__(self):
+        return '\n'.join([f"{key}: {val}" for key, val in self.__dict__.items()])
 
 
-class HTMLServer:
+class HTTPServer:
     """
-    The very cool webserver
+    The mighty HTTP server
     """
 
-    def __init__(self):
-        self.sock: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        self.clients: list[socket.socket] = []
+    def __init__(self, *, port: int, packet_size: int = 2048):
+        self.bind_port: int = port
+        self.packet_size: int = packet_size
+        self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
-        # list of allowed paths
-        self.allowed_paths: dict[str, dict] = {
-            "/":                {"path": "www/index.html",      "encoding": "css/html"},
-            "/robots.txt":      {"path": "www/robots.txt",      "encoding": "text"},
-            "/favicon.ico":     {"path": "www/favicon.ico",     "encoding": "bin"},
-            "/css/styles.css":  {"path": None,                  "encoding": "css/html"},
-        }
+        self.clients: list[socket.socket] = []
 
-    def run(self):
+    def start(self):
         """
-        Function that starts the webserver
+        Method to start the web server
         """
 
-        # bind the server to port and start listening
-        self.sock.bind(('', PORT))
-        self.sock.listen()
+        # bind and start listening to port
+        self.socket.bind(('', self.bind_port))
+        self.socket.listen()
 
-        # start running thread
-        t = threading.Thread(target=self._run, daemon=True)
-        t.start()
+        # start the listening thread
+        threading.Thread(target=self._listen_thread, daemon=True).start()
 
         # keep alive
         try:
             while True:
+                # sleep 100 ms, otherwise the while true will 100% one of your cores
                 time.sleep(0.1)
+
+        # shutdown on keyboard interrupt
         except KeyboardInterrupt:
-            self.sock.close()
+            self.socket.close()
             print("Closed.")
 
-    def _run(self):
+    def _listen_thread(self):
         """
-        Run function for threads
-        :return:
+        Listening for new connections
         """
 
-        asyncio.run(self.server_listener())
+        # run the coroutine
+        asyncio.run(self._thread_listen_coro())
+
+    async def _thread_listen_coro(self):
+        while True:
+            # accept new connection, add to client list and start listening to it
+            client, _ = self.socket.accept()
+            self.clients.append(client)
+            await self.client_handle(client)
 
-    async def server_listener(self):
+    async def client_handle(self, client: socket.socket):
         """
-        Listens for new connections, and handles them
+        Handles client's connection
         """
 
         while True:
-            client, address = self.sock.accept()
-            self.clients.append(client)
-            await self.server_handle(client)
+            # receive request from client
+            raw_request = self._recvall(client)
+
+            # decode request
+            request: Request = Request.create(raw_request)
+
+            self._close_client(client)
+            break
 
-    async def server_handle(self, client: socket.socket):
+    def _close_client(self, client: socket.socket):
         """
-        Handles the actual connections (clients)
-        :param client: connection socket
+        Closes a client
         """
 
-        # message buffer
-        buffer = bytearray()
+        client.close()
+        if client in self.clients:
+            self.clients.remove(client)
+
+    def _recvall(self, client: socket.socket) -> bytes:
+        """
+        Receive All (just receives the whole message, instead of 1 packet at a time)
+        """
+
+        # create message buffer
+        buffer: bytearray = bytearray()
+
+        # start fetching the message
         while True:
-            # try to fetch a message
-            # die otherwise
             try:
-                message = client.recv(PACKET_SIZE)
+                # fetch packet
+                message = client.recv(self.packet_size)
             except OSError:
                 break
+
+            # that happens when user stops loading the page
             if message == b'':
                 break
 
-            # append packet to buffer
+            # append fetched message to the buffer
             buffer += message
 
-            # check EoF (2 blank lines)
+            # check for EoF
             if buffer[-4:] == b'\r\n\r\n':
-                # text buffer
-                text_buffer = buffer.decode("ascii")
-
-                # decode request
-                request = decode_request(text_buffer)
-
-                print(f"[{request['type']}] Request from client '{client.getpeername()[0]}'")
-
-                # log that request
-                async with aiofiles.open("logs.log", "a") as f:
-                    await f.write(
-                        json.dumps(
-                            {
-                                "client": client.getpeername()[0],
-                                "request": request
-                            },
-                            indent=2
-                        ) + "\n"
-                    )
-
-                # handle the request
-                if request["type"] == "GET":
-                    await self.handle_get_request(client, request)
-                else:
-                    await self.handle_other_request(client)
-
-                # clear buffer
-                buffer.clear()
-        client.close()
-        self.clients.remove(client)
-
-    async def handle_get_request(self, client: socket.socket, req: dict[str, str | None]):
-        # check if the path is too long
-        if len(req["path"]) > 255:
-            response = get_response(400)
-            data = b''
-
-        # if it's yandex
-        elif req.get("from") == "support@search.yandex.ru":
-            response = get_response(404)
-            data = b'Nothing...'
-
-        # check UwU path
-        elif req["path"] == "/UwU":
-            response = get_response(6969)
-            data = b'<h1>' + b'UwU ' * 2000 + b'</h1>'
-
-        # otherwise check access
-        elif req["path"] in self.allowed_paths:
-            # get path
-            path = self.allowed_paths[req["path"]]["path"]
-
-            # if path is None, return generic filepath
-            if path is None:
-                path = req["path"][1:]
-
-            # check encoding
-            if self.allowed_paths[req["path"]]["encoding"] == "css/html":
-                # return text data
-                async with aiofiles.open(path, "r") as f:
-                    data = htmlmin.minify(await f.read()).encode("ascii")
-            else:
-                # return binary / text data
-                async with aiofiles.open(path, "rb") as f:
-                    data = await f.read()
-            response = get_response(200)
-
-        # in any other case
-        else:
-            response = get_response(403)
-            data = b'Idk what you are trying to do here :/'
-
-        # make headers
-        headers = {}
-
-        # check if compression is supported
-        if req.get("Accept-Encoding"):
-            encoding_list = [enc.lstrip(" ") for enc in req["Accept-Encoding"].split(",")]
-
-            # check for gzip, and add to headers if present
-            if "gzip" in encoding_list:
-                headers["Content-Encoding"] = "gzip"
-
-        # send response
-        await self.send(client, response, data, headers)
-        client.close()
-
-    async def handle_other_request(self, client: socket.socket):
-        # just say 'no'
-        await self.send(client, get_response(403), b'No. Don\'t do that, that\'s cringe')
-        client.close()
-
-    async def send(self, client: socket.socket, response: bytes, data: bytes, headers: dict[str, str] | None = None):
-        # construct headers
-        formatted_headers = b''
-        if headers is not None:
-            formatted_headers = "".join([f"{key}: {val}\r\n" for key, val in headers.items()]).encode("ascii")
-
-            # check for compression
-            if headers.get("Content-Encoding") == "gzip":
-                # compress data
-                data = gzip.compress(data)
-
-        # construct message
-        if formatted_headers == b'':
-            message = b'HTTP/1.1 ' + response + b'\r\n\r\n' + data
-        else:
-            message = b'HTTP/1.1 ' + response + b'\r\n' + formatted_headers + b'\r\n' + data
+                # return the received message
+                return buffer
 
-        # send message to client
-        client.sendall(message)
+        # return empty buffer on error
+        return b''
 
 
 def main():
-    server = HTMLServer()
-    server.run()
+    server = HTTPServer(port=13701)
+    server.start()
 
 
 if __name__ == '__main__':