about summary refs log tree commit diff
path: root/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'main.py')
-rw-r--r--main.py103
1 files changed, 51 insertions, 52 deletions
diff --git a/main.py b/main.py
index 43ebac2..0b9f991 100644
--- a/main.py
+++ b/main.py
@@ -8,8 +8,7 @@ import gzip
 import socket
 import brotli
 import signal
-import asyncio
-import aiofiles
+import threading
 from src import APIv1
 from src.socks import *
 from src.request import Request
@@ -99,6 +98,8 @@ class HTTPServer:
         """
 
         self.socket.close()
+        for client in self.clients:
+            client.close()
 
     def start(self):
         """
@@ -114,21 +115,18 @@ class HTTPServer:
         self.socket.setblocking(False)
 
         # start listening
-        asyncio.run(self._listen_thread())
+        self._listen_thread()
 
-    async def _listen_thread(self):
+    def _listen_thread(self):
         """
         Listening for new connections
         """
 
-        # get event loop
-        loop = asyncio.get_event_loop()
-
         # start listening
         while True:
             # try to accept new connection
             try:
-                client = (await ssl_sock_accept(self.socket))[0]
+                client = ssl_sock_accept(self.socket)[0]
 
             # if socket was closed -> break
             except OSError as e:
@@ -137,28 +135,35 @@ class HTTPServer:
                 break
 
             # else append to client list and create new task
-            await loop.create_task(self.client_handle(client))
+            threading.Thread(target=self.client_handle, args=[client]).start()
 
-    async def client_handle(self, client: ssl.SSLSocket):
+    def client_handle(self, client: ssl.SSLSocket):
         """
         Handles client's connection
         """
 
-        # receive request from client
-        raw_request = await self._recvall(client)
+        while True:
+            # receive request from client
+            raw_request = self._recvall(client)
+
+            if raw_request == b'':
+                break
 
-        # decode request
-        request: Request = Request.create(raw_request)
+            # decode request
+            request: Request = Request.create(raw_request)
 
-        # # log request
-        # async with aiofiles.open("logs.log", "a") as f:
-        #     await f.write(f"IP: {client.getpeername()[0]}\n{request}\n\n")
+            # # log request
+            # async with aiofiles.open("logs.log", "a") as f:
+            #     await f.write(f"IP: {client.getpeername()[0]}\n{request}\n\n")
 
+            threading.Thread(target=self.handle_request, args=[client, request]).start()
+
+    def handle_request(self, client: ssl.SSLSocket, request: Request):
         # handle requests
         try:
             match request.type:
                 case "GET":
-                    await self.handle_get_request(client, request)
+                    self.handle_get_request(client, request)
                 case _:
                     pass
 
@@ -166,11 +171,11 @@ class HTTPServer:
         except Exception as e:
             print(e)
 
-        # close connection (stop page loading)
-        self._close_client(client)
+        # # close connection (stop page loading)
+        # self._close_client(client)
 
     @staticmethod
-    async def handle_get_request(client: ssl.SSLSocket, request: Request):
+    def handle_get_request(client: ssl.SSLSocket, request: Request):
         """
         Handles user's GET request
         """
@@ -181,8 +186,8 @@ class HTTPServer:
         # check if request path is in the PATH_MAP
         if request.path in PATH_MAP:
             # if it is -> return file from that path
-            async with aiofiles.open(PATH_MAP[request.path]["path"], "rb") as f:
-                data = await f.read()
+            with open(PATH_MAP[request.path]["path"], "rb") as f:
+                data = f.read()
 
             # pre-compress data for HTML files
             if PATH_MAP[request.path]["path"][-4:] == "html":
@@ -198,10 +203,7 @@ class HTTPServer:
                 headers["Content-Encoding"] = "gzip"
 
             # send 200 response with the file to the client
-            await HTTPServer._send(client, 200, data, headers)
-
-            # return after answer
-            return
+            HTTPServer._send(client, 200, data, headers)
 
         # if it's an API request
         elif (api_version := request.path.split("/")[1]) in API_VERSIONS:
@@ -209,43 +211,41 @@ class HTTPServer:
             headers = {}
             match api_version:
                 case "APIv1":
-                    status, data, headers = await APIv1.respond(client, request)
+                    status, data, headers = APIv1.respond(client, request)
                 case _:
                     status = 400
 
             # if status is not 200 -> send bad response
             if status != 200:
-                await HTTPServer._bad_response(client, status)
+                HTTPServer._bad_response(client, status)
                 return
 
             # send data if no error
-            await HTTPServer._send(client, status, data, headers)
-
-            # return after answer
-            return
+            HTTPServer._send(client, status, data, headers)
 
         # in case of error, return error page
-        await HTTPServer._bad_response(client, 404)
+        else:
+            HTTPServer._bad_response(client, 404)
 
     @staticmethod
-    async def _bad_response(client: ssl.SSLSocket, status_code: int):
+    def _bad_response(client: ssl.SSLSocket, status_code: int):
         """
         Sends a bad response page to the client.
         :param client: client
         :param status_code: status code
         """
 
-        async with aiofiles.open(I_PATH_MAP["/err/response.html"]["path"], "r") as f:
-            data = await f.read()
+        with open(I_PATH_MAP["/err/response.html"]["path"], "r") as f:
+            data = f.read()
 
         # format error response
         data = data.format(status_code=get_response_code(status_code).decode("ascii"))
 
         # send response to the client
-        await HTTPServer._send(client, status_code, data.encode("ascii"))
+        HTTPServer._send(client, status_code, data.encode("ascii"))
 
     @staticmethod
-    async def _send(client: ssl.SSLSocket, response: int, data: bytes = None, headers: dict[str, str] = None):
+    def _send(client: ssl.SSLSocket, response: int, data: bytes = None, headers: dict[str, str] = None):
         """
         Sends client response code + headers + data
         :param client: client
@@ -279,8 +279,7 @@ class HTTPServer:
             byte_header += f"{key}: {value}\r\n".encode("ascii")
 
         # send response to the client
-        await ssl_sock_sendall(
-            client,
+        client.sendall(
             b'HTTP/1.1 ' +
             get_response_code(response) +
             b'\r\n' +
@@ -289,16 +288,7 @@ class HTTPServer:
             data
         )
 
-    def _close_client(self, client: socket.socket):
-        """
-        Closes a client
-        """
-
-        client.close()
-        if client in self.clients:
-            self.clients.remove(client)
-
-    async def _recvall(self, client: ssl.SSLSocket) -> bytes:
+    def _recvall(self, client: ssl.SSLSocket) -> bytes:
         """
         Receive All (just receives the whole message, instead of 1 packet at a time)
         """
@@ -310,7 +300,7 @@ class HTTPServer:
         while True:
             try:
                 # fetch packet
-                message = await ssl_sock_recv(client, self.packet_size)
+                message = ssl_sock_recv(client, self.packet_size)
             except OSError:
                 break
 
@@ -329,6 +319,15 @@ class HTTPServer:
         # return empty buffer on error
         return b''
 
+    def _close_client(self, client: socket.socket):
+        """
+        Closes a client
+        """
+
+        client.close()
+        if client in self.clients:
+            self.clients.remove(client)
+
 
 def main():
     server = HTTPServer(port=13700)