about summary refs log tree commit diff
path: root/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'main.py')
-rw-r--r--main.py75
1 files changed, 39 insertions, 36 deletions
diff --git a/main.py b/main.py
index 715513e..f03b354 100644
--- a/main.py
+++ b/main.py
@@ -11,12 +11,16 @@ import brotli
 import signal
 import threading
 from src import APIv1
+from src.config import *
 from src.request import *
 from src.status_code import *
-from src.config import BUFFER_LENGTH
 from src.minimizer import minimize_html
 
 
+# typing
+usocket = socket.socket | ssl.SSLSocket
+
+
 # path mapping
 path_map = {
     "/":                    {"path": "www/index.html"},
@@ -72,19 +76,20 @@ class HTTPServer:
     Now uses threading
     """
 
-    def __init__(self, *, port: int, packet_size: int = BUFFER_LENGTH):
-        # SSL context
-        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
-        context.options &= ssl.OP_NO_SSLv3
-        context.check_hostname = False
-        context.load_cert_chain(
-            certfile=r"C:\Certbot\live\qubane.ddns.net\fullchain.pem",  # use your own path here
-            keyfile=r"C:\Certbot\live\qubane.ddns.net\privkey.pem")     # here too
-
+    def __init__(self, *, port: int, packet_size: int = BUFFER_LENGTH, enable_ssl: bool = True):
         # Sockets
-        self.sock: ssl.SSLSocket = context.wrap_socket(
-            socket.socket(socket.AF_INET, socket.SOCK_STREAM),
-            server_side=True)
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        if enable_ssl:
+            # SSL context
+            context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+            context.options &= ssl.OP_NO_SSLv3
+            context.check_hostname = False
+            context.load_cert_chain(
+                certfile=r"C:\Certbot\live\qubane.ddns.net\fullchain.pem",  # use your own path here
+                keyfile=r"C:\Certbot\live\qubane.ddns.net\privkey.pem")  # here too
+            self.sock: usocket = context.wrap_socket(sock, server_side=True)
+        else:
+            self.sock: usocket = sock
         self.buf_len: int = packet_size
         self.port: int = port
 
@@ -130,7 +135,7 @@ class HTTPServer:
         # close server socket
         self.sock.close()
 
-    def _client_thread(self, client: ssl.SSLSocket):
+    def _client_thread(self, client: usocket):
         """
         Handles getting client's requests
         :param client: client ssl socket
@@ -139,11 +144,11 @@ class HTTPServer:
         try:
             request = self._recv_request(client)
             if request is not None:
-                print(
-                    f"ip: {client.getpeername()[0]}\n"
-                    f"{request.type}\n"
-                    f"{request.path}\n"
-                    f"{request.path_args}", end="\n\n")
+                # print(
+                #     f"ip: {client.getpeername()[0]}\n"
+                #     f"{request.type}\n"
+                #     f"{request.path}\n"
+                #     f"{request.path_args}", end="\n\n")
                 self._client_request_handler(client, request)
         except ssl.SSLEOFError:
             pass
@@ -152,10 +157,11 @@ class HTTPServer:
         except Exception as e:
             print(e)
 
-        # close the connection once stop even was set or an error occurred
+        # Remove self from thread list and close the connection
+        self.client_threads.remove(threading.current_thread())
         client.close()
 
-    def _client_request_handler(self, client: ssl.SSLSocket, request: Request):
+    def _client_request_handler(self, client: usocket, request: Request):
         """
         Handles responses to client's requests
         :param client: client
@@ -163,7 +169,7 @@ class HTTPServer:
         """
 
         match request.type:
-            case "GET" | "HEAD":
+            case "GET":
                 response = self._handle_get(client, request)
             # case "POST":  # Not Implemented
             #     response = self._handle_post(client, request)
@@ -203,9 +209,9 @@ class HTTPServer:
             if self.stop_event.is_set():
                 break
 
-    def _handle_get(self, client: ssl.SSLSocket, request: Request) -> Response:
+    def _handle_get(self, client: usocket, request: Request) -> Response:
         """
-        Handles GET / HEAD requests from a client
+        Handles GET requests from a client
         """
 
         split_path = request.path.split("/", maxsplit=16)[1:]
@@ -217,13 +223,7 @@ class HTTPServer:
             if filepath[-4:] == "html":
                 data = minimize_html(data)
 
-            if request.type == "GET":
-                response = Response(data, STATUS_CODE_OK)
-            elif request.type == "HEAD":
-                response = Response(b'', STATUS_CODE_OK, {"Content-Length": len(data)})
-            else:
-                raise TypeError("Called GET handler for non-GET request")
-
+            response = Response(data, STATUS_CODE_OK)
             if filepath[-4:] == "webp":
                 response.compress = False
 
@@ -232,7 +232,7 @@ class HTTPServer:
         elif len(split_path) >= 2 and split_path[0] in API_VERSIONS:  # assume script
             # unsupported API version
             if not API_VERSIONS[split_path[0]]:
-                if request.type == "GET" or request.type == "HEAD":
+                if request.type == "GET":
                     return Response(b'', STATUS_CODE_BAD_REQUEST)
                 else:
                     raise TypeError("Called GET handler for non-GET request")
@@ -245,12 +245,12 @@ class HTTPServer:
             data = data.format(status_code=str(STATUS_CODE_NOT_FOUND)).encode("utf8")
             return Response(data, STATUS_CODE_NOT_FOUND)
 
-    def _handle_post(self, client: ssl.SSLSocket, request: Request) -> Response:
+    def _handle_post(self, client: usocket, request: Request) -> Response:
         """
         Handles POSt request from a client
         """
 
-    def _recv_request(self, client: ssl.SSLSocket) -> Request | None:
+    def _recv_request(self, client: usocket) -> Request | None:
         """
         Receive request from client
         :return: request
@@ -265,16 +265,19 @@ class HTTPServer:
             buffer += msg
             if buffer[-4:] == b'\r\n\r\n':
                 return Request.create(buffer)
+            if len(buffer) > BUFFER_MAX_SIZE:  # ignore big messages
+                break
         return None
 
-    def _accept(self) -> ssl.SSLSocket | None:
+    def _accept(self) -> usocket | None:
         """
         socket.accept, but for more graceful closing
         """
 
         while not self.stop_event.is_set():
             try:
-                return self.sock.accept()[0]
+                if len(self.client_threads) < CLIENT_MAX_AMOUNT:
+                    return self.sock.accept()[0]
             except BlockingIOError:
                 time.sleep(0.005)
             except ssl.SSLEOFError: