about summary refs log tree commit diff
path: root/main.py
diff options
context:
space:
mode:
authorUltraQbik <no1skill@yandex.ru>2024-08-23 15:44:06 +0300
committerUltraQbik <no1skill@yandex.ru>2024-08-23 15:44:06 +0300
commit46a76eff285459d25f142d2ce2628425d4e69e94 (patch)
treef8019e5cd8e53a72541102107521106478330bbe /main.py
parentc95930a8d1e9724a0b720c6fe4e648d0ee267800 (diff)
downloadhttpy-46a76eff285459d25f142d2ce2628425d4e69e94.tar.gz
httpy-46a76eff285459d25f142d2ce2628425d4e69e94.zip
More graceful shutdown
There is a weird issue when a user tries to load a page, doesn't get a response (because that wasn't implemented yet), and when the webserver is attempted to be shutdown, it still tries to fetch user's request
Diffstat (limited to 'main.py')
-rw-r--r--main.py480
1 files changed, 231 insertions, 249 deletions
diff --git a/main.py b/main.py
index 0ab634f..4edd2ba 100644
--- a/main.py
+++ b/main.py
@@ -4,16 +4,16 @@ The mighty silly webserver written in python for no good reason
 
 
 import ssl
+import time
 import gzip
 import socket
 import brotli
 import signal
 import threading
 from src import APIv1
-from src.socks import *
+from src.status_code import *
 from src.request import Request
 from src.minimizer import minimize_html
-from src.status_code import *
 
 
 # path mapping
@@ -56,12 +56,11 @@ class HTTPServer:
         self.sock: ssl.SSLSocket = context.wrap_socket(
             socket.socket(socket.AF_INET, socket.SOCK_STREAM),
             server_side=True)
-        self.packet_size: int = packet_size
+        self.buf_len: int = packet_size
         self.port: int = port
 
         # client thread list and server thread
         self.client_threads: list[threading.Thread] = []
-        self.server_thread: threading.Thread | None = None
 
         # add signaling
         self.stop_event = threading.Event()
@@ -76,10 +75,6 @@ class HTTPServer:
         self.stop_event.set()
         for thread in self.client_threads:
             thread.join()
-        self.server_thread.join()
-
-        # close server socket
-        self.sock.close()
 
     def start(self):
         """
@@ -94,271 +89,258 @@ class HTTPServer:
         # listen and respond handler
         while not self.stop_event.is_set():
             # accept new client
-            client = ssl_sock_accept(self.sock)[0]
+            client = self._accept()
+            if client is None:
+                break
 
             # create thread for new client and append it to the list
-            th = threading.Thread(target=lambda x: x, args=[client])  # TODO: this line
+            th = threading.Thread(target=self._client_thread, args=[client])
             self.client_threads.append(th)
             th.start()
 
+        # close server socket
+        self.sock.close()
 
-class HTTPServer:
-    """
-    The mighty HTTP server
-    """
-
-    def __init__(self, *, port: int, packet_size: int = 2048):
-        # ssl context
-        self.context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
-        self.context.load_cert_chain(
-            certfile=r"C:\Certbot\live\qubane.ddns.net\fullchain.pem",      # use your own path here
-            keyfile=r"C:\Certbot\live\qubane.ddns.net\privkey.pem"          # here too
-        )
-        self.context.check_hostname = False
-
-        # sockets
-        self.socket: ssl.SSLSocket = self.context.wrap_socket(
-            socket.socket(socket.AF_INET, socket.SOCK_STREAM), server_side=True)
-        self.packet_size: int = packet_size
-        self.bind_port: int = port
-
-        # list of connected clients
-        self.clients: list[socket.socket] = []
-
-    def interrupt(self, *args, **kwargs):
-        """
-        Interrupts the web server
-        """
-
-        self.socket.close()
-        for client in self.clients:
-            client.close()
-
-    def start(self):
-        """
-        Method to start the web server
-        """
-
-        # setup signaling
-
-        # bind and start listening to port
-        self.socket.bind(('', self.bind_port))
-        self.socket.listen()
-        self.socket.setblocking(False)
-
-        # start listening
-        self._listen_thread()
-
-    def _listen_thread(self):
+    def _client_thread(self, client: ssl.SSLSocket):
         """
-        Listening for new connections
+        Handles client's requests
+        :param client: client ssl socket
         """
 
-        # start listening
-        while True:
-            # try to accept new connection
+        while not self.stop_event.is_set():
             try:
-                client = ssl_sock_accept(self.socket)[0]
+                # get client's request
+                request = self._recv_request(client)
+                if request is None:
+                    break
 
-            # if socket was closed -> break
+                print(request, end="\n\n")
             except OSError as e:
                 print(e)
-                print("Closed.")
-                break
-
-            # else append to client list and create new task
-            threading.Thread(target=self.client_handle, args=[client]).start()
-
-    def client_handle(self, client: ssl.SSLSocket):
-        """
-        Handles client's connection
-        """
-
-        while True:
-            # receive request from client
-            raw_request = self._recvall(client)
-
-            if raw_request == b'':
                 break
 
-            # decode request
-            request: Request = Request.create(raw_request)
-
-            # # log request
-            # async with aiofiles.open("logs.log", "a") as f:
-            #     await f.write(f"IP: {client.getpeername()[0]}\n{request}\n\n")
-
-            threading.Thread(target=self.handle_request, args=[client, request]).start()
-
-    def handle_request(self, client: ssl.SSLSocket, request: Request):
-        # handle requests
-        try:
-            match request.type:
-                case "GET":
-                    self.handle_get_request(client, request)
-                case _:
-                    pass
-
-        # break on exception
-        except Exception as e:
-            print(e)
-
-        # # close connection (stop page loading)
-        # self._close_client(client)
-
-    @staticmethod
-    def handle_get_request(client: ssl.SSLSocket, request: Request):
-        """
-        Handles user's GET request
-        """
-
-        # get available compression methods
-        compressions = [x.strip() for x in getattr(request, "Accept-Encoding", "").split(",")]
-
-        # check if request path is in the PATH_MAP
-        if request.path in PATH_MAP:
-            # if it is -> return file from that path
-            with open(PATH_MAP[request.path]["path"], "rb") as f:
-                data = f.read()
-
-            # pre-compress data for HTML files
-            if PATH_MAP[request.path]["path"][-4:] == "html":
-                data = minimize_html(data)
-
-            # add brotli compression header (if supported)
-            headers = {}
-            if "br" in compressions:
-                headers["Content-Encoding"] = "br"
-
-            # else add gzip compression (if supported)
-            elif "gzip" in compressions:
-                headers["Content-Encoding"] = "gzip"
-
-            # send 200 response with the file to the client
-            HTTPServer._send(client, 200, data, headers)
-
-        # if it's an API request
-        elif (api_version := request.path.split("/")[1]) in API_VERSIONS:
-            data = b''
-            headers = {}
-            match api_version:
-                case "APIv1":
-                    status, data, headers = APIv1.respond(client, request)
-                case _:
-                    status = 400
-
-            # if status is not 200 -> send bad response
-            if status != 200:
-                HTTPServer._bad_response(client, status)
-                return
-
-            # send data if no error
-            HTTPServer._send(client, status, data, headers)
-
-        # in case of error, return error page
-        else:
-            HTTPServer._bad_response(client, 404)
-
-    @staticmethod
-    def _bad_response(client: ssl.SSLSocket, status_code: int):
-        """
-        Sends a bad response page to the client.
-        :param client: client
-        :param status_code: status code
-        """
-
-        with open(I_PATH_MAP["/err/response.html"]["path"], "r") as f:
-            data = f.read()
-
-        # format error response
-        data = data.format(status_code=get_response_code(status_code).decode("ascii"))
-
-        # send response to the client
-        HTTPServer._send(client, status_code, data.encode("ascii"))
-
-    @staticmethod
-    def _send(client: ssl.SSLSocket, response: int, data: bytes = None, headers: dict[str, str] = None):
-        """
-        Sends client response code + headers + data
-        :param client: client
-        :param response: response code
-        :param data: data
-        :param headers: headers to include
-        """
+        # close the connection once stop even was set or an error occurred
+        client.close()
 
-        # if data was not given
-        if data is None:
-            data = bytes()
-
-        # if headers were not given
-        if headers is None:
-            headers = dict()
-
-        # check for 'content-encoding' header
-        if headers.get("Content-Encoding") == "br":
-            data = brotli.compress(data)
-
-        elif headers.get("Content-Encoding") == "gzip":
-            data = gzip.compress(data)
-
-        # add 'Content-Length' header if not present
-        if headers.get("Content-Length") is None:
-            headers["Content-Length"] = len(data)
-
-        # format headers
-        byte_header = bytearray()
-        for key, value in headers.items():
-            byte_header += f"{key}: {value}\r\n".encode("ascii")
-
-        # send response to the client
-        client.sendall(
-            b'HTTP/1.1 ' +
-            get_response_code(response) +
-            b'\r\n' +
-            byte_header +  # if empty, we'll just get b'\r\n\r\n'
-            b'\r\n' +
-            data
-        )
-
-    def _recvall(self, client: ssl.SSLSocket) -> bytes:
+    def _recv_request(self, client: ssl.SSLSocket) -> Request | None:
         """
-        Receive All (just receives the whole message, instead of 1 packet at a time)
+        Receive request from client
+        :return: request
+        :raises: anything that recv raises
         """
 
-        # create message buffer
-        buffer: bytearray = bytearray()
-
-        # start fetching the message
-        while True:
-            try:
-                # fetch packet
-                message = ssl_sock_recv(client, self.packet_size)
-            except OSError:
-                break
-
-            # that happens when user stops loading the page
-            if message == b'':
-                break
-
-            # append fetched message to the buffer
-            buffer += message
-
-            # check for EoF
+        buffer = bytearray()
+        while not self.stop_event.is_set():
+            buffer += client.recv(self.buf_len)
             if buffer[-4:] == b'\r\n\r\n':
-                # return the received message
-                return buffer
+                return Request.create(buffer)
+        return None
 
-        # return empty buffer on error
-        return b''
-
-    def _close_client(self, client: socket.socket):
+    def _accept(self) -> ssl.SSLSocket | None:
         """
-        Closes a client
+        socket.accept, but for more graceful closing
         """
 
-        client.close()
-        if client in self.clients:
-            self.clients.remove(client)
+        while not self.stop_event.is_set():
+            try:
+                return self.sock.accept()[0]
+            except BlockingIOError:
+                time.sleep(0.001)
+        return None
+
+
+# class HTTPServer:
+#
+#     def client_handle(self, client: ssl.SSLSocket):
+#         """
+#         Handles client's connection
+#         """
+#
+#         while True:
+#             # receive request from client
+#             raw_request = self._recvall(client)
+#
+#             if raw_request == b'':
+#                 break
+#
+#             # decode request
+#             request: Request = Request.create(raw_request)
+#
+#             # # log request
+#             # async with aiofiles.open("logs.log", "a") as f:
+#             #     await f.write(f"IP: {client.getpeername()[0]}\n{request}\n\n")
+#
+#             threading.Thread(target=self.handle_request, args=[client, request]).start()
+#
+#     def handle_request(self, client: ssl.SSLSocket, request: Request):
+#         # handle requests
+#         try:
+#             match request.type:
+#                 case "GET":
+#                     self.handle_get_request(client, request)
+#                 case _:
+#                     pass
+#
+#         # break on exception
+#         except Exception as e:
+#             print(e)
+#
+#         # # close connection (stop page loading)
+#         # self._close_client(client)
+#
+#     @staticmethod
+#     def handle_get_request(client: ssl.SSLSocket, request: Request):
+#         """
+#         Handles user's GET request
+#         """
+#
+#         # get available compression methods
+#         compressions = [x.strip() for x in getattr(request, "Accept-Encoding", "").split(",")]
+#
+#         # check if request path is in the PATH_MAP
+#         if request.path in PATH_MAP:
+#             # if it is -> return file from that path
+#             with open(PATH_MAP[request.path]["path"], "rb") as f:
+#                 data = f.read()
+#
+#             # pre-compress data for HTML files
+#             if PATH_MAP[request.path]["path"][-4:] == "html":
+#                 data = minimize_html(data)
+#
+#             # add brotli compression header (if supported)
+#             headers = {}
+#             if "br" in compressions:
+#                 headers["Content-Encoding"] = "br"
+#
+#             # else add gzip compression (if supported)
+#             elif "gzip" in compressions:
+#                 headers["Content-Encoding"] = "gzip"
+#
+#             # send 200 response with the file to the client
+#             HTTPServer._send(client, 200, data, headers)
+#
+#         # if it's an API request
+#         elif (api_version := request.path.split("/")[1]) in API_VERSIONS:
+#             data = b''
+#             headers = {}
+#             match api_version:
+#                 case "APIv1":
+#                     status, data, headers = APIv1.respond(client, request)
+#                 case _:
+#                     status = 400
+#
+#             # if status is not 200 -> send bad response
+#             if status != 200:
+#                 HTTPServer._bad_response(client, status)
+#                 return
+#
+#             # send data if no error
+#             HTTPServer._send(client, status, data, headers)
+#
+#         # in case of error, return error page
+#         else:
+#             HTTPServer._bad_response(client, 404)
+#
+#     @staticmethod
+#     def _bad_response(client: ssl.SSLSocket, status_code: int):
+#         """
+#         Sends a bad response page to the client.
+#         :param client: client
+#         :param status_code: status code
+#         """
+#
+#         with open(I_PATH_MAP["/err/response.html"]["path"], "r") as f:
+#             data = f.read()
+#
+#         # format error response
+#         data = data.format(status_code=get_response_code(status_code).decode("ascii"))
+#
+#         # send response to the client
+#         HTTPServer._send(client, status_code, data.encode("ascii"))
+#
+#     @staticmethod
+#     def _send(client: ssl.SSLSocket, response: int, data: bytes = None, headers: dict[str, str] = None):
+#         """
+#         Sends client response code + headers + data
+#         :param client: client
+#         :param response: response code
+#         :param data: data
+#         :param headers: headers to include
+#         """
+#
+#         # if data was not given
+#         if data is None:
+#             data = bytes()
+#
+#         # if headers were not given
+#         if headers is None:
+#             headers = dict()
+#
+#         # check for 'content-encoding' header
+#         if headers.get("Content-Encoding") == "br":
+#             data = brotli.compress(data)
+#
+#         elif headers.get("Content-Encoding") == "gzip":
+#             data = gzip.compress(data)
+#
+#         # add 'Content-Length' header if not present
+#         if headers.get("Content-Length") is None:
+#             headers["Content-Length"] = len(data)
+#
+#         # format headers
+#         byte_header = bytearray()
+#         for key, value in headers.items():
+#             byte_header += f"{key}: {value}\r\n".encode("ascii")
+#
+#         # send response to the client
+#         client.sendall(
+#             b'HTTP/1.1 ' +
+#             get_response_code(response) +
+#             b'\r\n' +
+#             byte_header +  # if empty, we'll just get b'\r\n\r\n'
+#             b'\r\n' +
+#             data
+#         )
+#
+#     def _recvall(self, client: ssl.SSLSocket) -> bytes:
+#         """
+#         Receive All (just receives the whole message, instead of 1 packet at a time)
+#         """
+#
+#         # create message buffer
+#         buffer: bytearray = bytearray()
+#
+#         # start fetching the message
+#         while True:
+#             try:
+#                 # fetch packet
+#                 message = ssl_sock_recv(client, self.packet_size)
+#             except OSError:
+#                 break
+#
+#             # that happens when user stops loading the page
+#             if message == b'':
+#                 break
+#
+#             # append fetched message to the buffer
+#             buffer += message
+#
+#             # check for EoF
+#             if buffer[-4:] == b'\r\n\r\n':
+#                 # return the received message
+#                 return buffer
+#
+#         # return empty buffer on error
+#         return b''
+#
+#     def _close_client(self, client: socket.socket):
+#         """
+#         Closes a client
+#         """
+#
+#         client.close()
+#         if client in self.clients:
+#             self.clients.remove(client)
 
 
 def main():