diff options
| author | Qubik <89706156+UltraQbik@users.noreply.github.com> | 2024-08-23 20:27:27 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-23 20:27:27 +0200 |
| commit | 9aca00d7265b9d05f908b4201a03b7b0808c5ca1 (patch) | |
| tree | 47346be37cc68fc18aa8e5603eef58a48d9b2e04 | |
| parent | a96b13f6816ca0657d9f65097b97d0e87e1a0366 (diff) | |
| parent | fd8c46cac1c914851613cac425d2afe68d360d9d (diff) | |
| download | httpy-9aca00d7265b9d05f908b4201a03b7b0808c5ca1.tar.gz httpy-9aca00d7265b9d05f908b4201a03b7b0808c5ca1.zip | |
Merge pull request #2 from UltraQbik/threading-rewrite
Threading rewrite
| -rw-r--r-- | main.py | 532 | ||||
| -rw-r--r-- | src/APIv1.py | 22 | ||||
| -rw-r--r-- | src/request.py | 25 | ||||
| -rw-r--r-- | src/status_code.py | 34 |
4 files changed, 343 insertions, 270 deletions
diff --git a/main.py b/main.py index 0b9f991..3bbd9d4 100644 --- a/main.py +++ b/main.py @@ -5,328 +5,344 @@ The mighty silly webserver written in python for no good reason import ssl import gzip +import time import socket import brotli import signal import threading from src import APIv1 -from src.socks import * -from src.request import Request +from src.request import * +from src.status_code import * from src.minimizer import minimize_html # path mapping PATH_MAP = { - "/": - {"path": "www/index.html", - "compress": True}, - "/index.html": - {"path": "www/index.html", - "compress": True}, - "/robots.txt": - {"path": "www/robots.txt", - "compress": False}, - "/favicon.ico": - {"path": "www/favicon.ico", - "compress": False}, - "/css/styles.css": - {"path": "css/styles.css", - "compress": True}, - "/about": - {"path": "www/about.html", - "compress": True}, - "/test": - {"path": "www/test.html", - "compress": True}, + "/": {"path": "www/index.html"}, + "/index.html": {"path": "www/index.html"}, + "/robots.txt": {"path": "www/robots.txt"}, + "/favicon.ico": {"path": "www/favicon.ico"}, + "/css/styles.css": {"path": "css/styles.css"}, + "/about": {"path": "www/about.html"}, + "/test": {"path": "www/test.html"}, } # API API_VERSIONS = { - "APIv1" + "APIv1": {"supported": True} } # internal path map I_PATH_MAP = { - "/err/response.html": {"path": "www/err/response.html"} + "/err/response": {"path": "www/err/response.html"} } -def get_response_code(code: int) -> bytes: - match code: - case 200: - return b'200 OK' - case 400: - return b'400 Bad Request' - case 401: - return b'401 Unauthorized' - case 403: - return b'403 Forbidden' - case 404: - return b'404 Not Found' - case 6969: - return b'6969 UwU' - case _: # in any other case return bad request response - return get_response_code(400) - - class HTTPServer: """ - The mighty HTTP server + The mightier HTTP server! + Now uses threading """ def __init__(self, *, port: int, packet_size: int = 2048): - # ssl context - self.context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - self.context.load_cert_chain( - certfile=r"C:\Certbot\live\qubane.ddns.net\fullchain.pem", # use your own path here - keyfile=r"C:\Certbot\live\qubane.ddns.net\privkey.pem" # here too - ) - self.context.check_hostname = False - - # sockets - self.socket: ssl.SSLSocket = self.context.wrap_socket( - socket.socket(socket.AF_INET, socket.SOCK_STREAM), server_side=True) - self.packet_size: int = packet_size - self.bind_port: int = port - - # list of connected clients - self.clients: list[socket.socket] = [] - - def interrupt(self, *args, **kwargs): + # SSL context + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + context.check_hostname = False + context.load_cert_chain( + certfile=r"C:\Certbot\live\qubane.ddns.net\fullchain.pem", # use your own path here + keyfile=r"C:\Certbot\live\qubane.ddns.net\privkey.pem") # here too + + # Sockets + self.sock: ssl.SSLSocket = context.wrap_socket( + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + server_side=True) + self.buf_len: int = packet_size + self.port: int = port + + # client thread list and server thread + self.client_threads: list[threading.Thread] = [] + + # add signaling + self.stop_event = threading.Event() + signal.signal(signal.SIGINT, self._signal_interrupt) + + def _signal_interrupt(self, *args): """ - Interrupts the web server + Checks for CTRL+C keyboard interrupt, to properly stop the HTTP server """ - self.socket.close() - for client in self.clients: - client.close() + # stop all threads + self.stop_event.set() + for thread in self.client_threads: + thread.join() def start(self): """ Method to start the web server """ - # setup signaling - signal.signal(signal.SIGINT, self.interrupt) - # bind and start listening to port - self.socket.bind(('', self.bind_port)) - self.socket.listen() - self.socket.setblocking(False) + self.sock.bind(('', self.port)) + self.sock.setblocking(False) + self.sock.listen() + + # listen and respond handler + while not self.stop_event.is_set(): + # accept new client + client = self._accept() + if client is None: + break + + # create thread for new client and append it to the list + th = threading.Thread(target=self._client_thread, args=[client]) + self.client_threads.append(th) + th.start() - # start listening - self._listen_thread() + # close server socket + self.sock.close() - def _listen_thread(self): + def _client_thread(self, client: ssl.SSLSocket): """ - Listening for new connections + Handles getting client's requests + :param client: client ssl socket """ - # start listening - while True: - # try to accept new connection + # client.settimeout(5) + while not self.stop_event.is_set(): try: - client = ssl_sock_accept(self.socket)[0] - - # if socket was closed -> break - except OSError as e: - print(e) - print("Closed.") + # get client's request + request = self._recv_request(client) + if request is None: + break + + threading.Thread(target=self._client_request_handler, args=[client, request], daemon=True).start() + except TimeoutError: + print("Client timeout") break - - # else append to client list and create new task - threading.Thread(target=self.client_handle, args=[client]).start() - - def client_handle(self, client: ssl.SSLSocket): - """ - Handles client's connection - """ - - while True: - # receive request from client - raw_request = self._recvall(client) - - if raw_request == b'': + except Exception as e: + print(e) break - # decode request - request: Request = Request.create(raw_request) - - # # log request - # async with aiofiles.open("logs.log", "a") as f: - # await f.write(f"IP: {client.getpeername()[0]}\n{request}\n\n") - - threading.Thread(target=self.handle_request, args=[client, request]).start() - - def handle_request(self, client: ssl.SSLSocket, request: Request): - # handle requests - try: - match request.type: - case "GET": - self.handle_get_request(client, request) - case _: - pass - - # break on exception - except Exception as e: - print(e) - - # # close connection (stop page loading) - # self._close_client(client) + # close the connection once stop even was set or an error occurred + client.close() - @staticmethod - def handle_get_request(client: ssl.SSLSocket, request: Request): + def _client_request_handler(self, client: ssl.SSLSocket, request: Request): """ - Handles user's GET request + Handles responses to client's requests + :param client: client + :param request: client's request """ - # get available compression methods - compressions = [x.strip() for x in getattr(request, "Accept-Encoding", "").split(",")] - - # check if request path is in the PATH_MAP - if request.path in PATH_MAP: - # if it is -> return file from that path - with open(PATH_MAP[request.path]["path"], "rb") as f: - data = f.read() - - # pre-compress data for HTML files - if PATH_MAP[request.path]["path"][-4:] == "html": - data = minimize_html(data) - - # add brotli compression header (if supported) - headers = {} - if "br" in compressions: - headers["Content-Encoding"] = "br" - - # else add gzip compression (if supported) - elif "gzip" in compressions: - headers["Content-Encoding"] = "gzip" - - # send 200 response with the file to the client - HTTPServer._send(client, 200, data, headers) - - # if it's an API request - elif (api_version := request.path.split("/")[1]) in API_VERSIONS: - data = b'' - headers = {} - match api_version: - case "APIv1": - status, data, headers = APIv1.respond(client, request) - case _: - status = 400 - - # if status is not 200 -> send bad response - if status != 200: - HTTPServer._bad_response(client, status) - return - - # send data if no error - HTTPServer._send(client, status, data, headers) - - # in case of error, return error page - else: - HTTPServer._bad_response(client, 404) - - @staticmethod - def _bad_response(client: ssl.SSLSocket, status_code: int): + match request.type: + case "GET" | "HEAD": + response = self._handle_get(client, request) + # case "POST": # Not Implemented + # response = self._handle_post(client, request) + case _: + with open(I_PATH_MAP["/err/response"]["path"], "r", encoding="ascii") as file: + data = file.read().format(status_code=str(STATUS_CODE_NOT_FOUND)).encode("ascii") + response = Response(data, STATUS_CODE_NOT_FOUND) + + # process header data + if response.headers.get("Content-Encoding") is None and response.compress: + supported_compressions = [x.strip() for x in getattr(request, "Accept-Encoding", "").split(",")] + if "br" in supported_compressions: + response.headers["Content-Encoding"] = "br" + response.data = brotli.compress(response.data) + elif "gzip" in supported_compressions: + response.headers["Content-Encoding"] = "gzip" + response.data = gzip.compress(response.data) + if response.headers.get("Content-Length") is None: + response.headers["Content-Length"] = len(response.data) + if response.headers.get("Connection") is None: + response.headers["Connection"] = "close" + + # generate basic message + message = b'HTTP/1.1 ' + response.status.__bytes__() + b'\r\n' + for key, value in response.headers.items(): + message += f"{key}: {value}\r\n".encode("ascii") + message += b'\r\n' + response.data + + # send message + client.sendall(message) + + def _handle_get(self, client: ssl.SSLSocket, request: Request) -> Response: """ - Sends a bad response page to the client. - :param client: client - :param status_code: status code + Handles GET / HEAD requests from a client """ - with open(I_PATH_MAP["/err/response.html"]["path"], "r") as f: - data = f.read() - - # format error response - data = data.format(status_code=get_response_code(status_code).decode("ascii")) - - # send response to the client - HTTPServer._send(client, status_code, data.encode("ascii")) - - @staticmethod - def _send(client: ssl.SSLSocket, response: int, data: bytes = None, headers: dict[str, str] = None): + split_path = request.path.split("/", maxsplit=16)[1:] + if request.path in PATH_MAP: # assume browser + filepath = PATH_MAP[request.path]["path"] + with open(filepath, "rb") as file: + data = file.read() + + if request.type == "GET": + return Response(data, STATUS_CODE_OK) + elif request.type == "HEAD": + return Response(b'', STATUS_CODE_OK, {"Content-Length": len(data)}) + else: + raise TypeError("Called GET handler for non-GET request") + + elif len(split_path) >= 2 and split_path[0] in API_VERSIONS: # assume script + # unsupported API version + if not API_VERSIONS[split_path[0]]: + if request.type == "GET" or request.type == "HEAD": + return Response(b'', STATUS_CODE_BAD_REQUEST) + else: + raise TypeError("Called GET handler for non-GET request") + + return APIv1.api_call(client, request) + + else: # assume browser + with open(I_PATH_MAP["/err/response"]["path"], "r", encoding="ascii") as file: + data = file.read() + data = data.format(status_code=str(STATUS_CODE_NOT_FOUND)).encode("ascii") + return Response(data, STATUS_CODE_NOT_FOUND) + + def _handle_post(self, client: ssl.SSLSocket, request: Request) -> Response: """ - Sends client response code + headers + data - :param client: client - :param response: response code - :param data: data - :param headers: headers to include + Handles POSt request from a client """ - # if data was not given - if data is None: - data = bytes() - - # if headers were not given - if headers is None: - headers = dict() - - # check for 'content-encoding' header - if headers.get("Content-Encoding") == "br": - data = brotli.compress(data) - - elif headers.get("Content-Encoding") == "gzip": - data = gzip.compress(data) - - # add 'Content-Length' header if not present - if headers.get("Content-Length") is None: - headers["Content-Length"] = len(data) - - # format headers - byte_header = bytearray() - for key, value in headers.items(): - byte_header += f"{key}: {value}\r\n".encode("ascii") - - # send response to the client - client.sendall( - b'HTTP/1.1 ' + - get_response_code(response) + - b'\r\n' + - byte_header + # if empty, we'll just get b'\r\n\r\n' - b'\r\n' + - data - ) - - def _recvall(self, client: ssl.SSLSocket) -> bytes: + def _recv_request(self, client: ssl.SSLSocket) -> Request | None: """ - Receive All (just receives the whole message, instead of 1 packet at a time) + Receive request from client + :return: request + :raises: anything that recv raises """ - # create message buffer - buffer: bytearray = bytearray() - - # start fetching the message - while True: - try: - # fetch packet - message = ssl_sock_recv(client, self.packet_size) - except OSError: - break - - # that happens when user stops loading the page - if message == b'': + buffer = bytearray() + while not self.stop_event.is_set(): + msg = client.recv(self.buf_len) + if len(msg) == 0: break - - # append fetched message to the buffer - buffer += message - - # check for EoF + buffer += msg if buffer[-4:] == b'\r\n\r\n': - # return the received message - return buffer - - # return empty buffer on error - return b'' + return Request.create(buffer) + return None - def _close_client(self, client: socket.socket): + def _accept(self) -> ssl.SSLSocket | None: """ - Closes a client + socket.accept, but for more graceful closing """ - client.close() - if client in self.clients: - self.clients.remove(client) + while not self.stop_event.is_set(): + try: + return self.sock.accept()[0] + except BlockingIOError: + time.sleep(0.001) + return None + + +# class HTTPServer: +# +# def client_handle(self, client: ssl.SSLSocket): +# """ +# Handles client's connection +# """ +# +# while True: +# # receive request from client +# raw_request = self._recvall(client) +# +# if raw_request == b'': +# break +# +# # decode request +# request: Request = Request.create(raw_request) +# +# # # log request +# # async with aiofiles.open("logs.log", "a") as f: +# # await f.write(f"IP: {client.getpeername()[0]}\n{request}\n\n") +# +# threading.Thread(target=self.handle_request, args=[client, request]).start() +# +# def handle_request(self, client: ssl.SSLSocket, request: Request): +# # handle requests +# try: +# match request.type: +# case "GET": +# self.handle_get_request(client, request) +# case _: +# pass +# +# # break on exception +# except Exception as e: +# print(e) +# +# # # close connection (stop page loading) +# # self._close_client(client) +# +# @staticmethod +# def handle_get_request(client: ssl.SSLSocket, request: Request): +# """ +# Handles user's GET request +# """ +# +# # get available compression methods +# compressions = [x.strip() for x in getattr(request, "Accept-Encoding", "").split(",")] +# +# # check if request path is in the PATH_MAP +# if request.path in PATH_MAP: +# # if it is -> return file from that path +# with open(PATH_MAP[request.path]["path"], "rb") as f: +# data = f.read() +# +# # pre-compress data for HTML files +# if PATH_MAP[request.path]["path"][-4:] == "html": +# data = minimize_html(data) +# +# # add brotli compression header (if supported) +# headers = {} +# if "br" in compressions: +# headers["Content-Encoding"] = "br" +# +# # else add gzip compression (if supported) +# elif "gzip" in compressions: +# headers["Content-Encoding"] = "gzip" +# +# # send 200 response with the file to the client +# HTTPServer._send(client, 200, data, headers) +# +# # if it's an API request +# elif (api_version := request.path.split("/")[1]) in API_VERSIONS: +# data = b'' +# headers = {} +# match api_version: +# case "APIv1": +# status, data, headers = APIv1.respond(client, request) +# case _: +# status = 400 +# +# # if status is not 200 -> send bad response +# if status != 200: +# HTTPServer._bad_response(client, status) +# return +# +# # send data if no error +# HTTPServer._send(client, status, data, headers) +# +# # in case of error, return error page +# else: +# HTTPServer._bad_response(client, 404) +# +# @staticmethod +# def _bad_response(client: ssl.SSLSocket, status_code: int): +# """ +# Sends a bad response page to the client. +# :param client: client +# :param status_code: status code +# """ +# +# with open(I_PATH_MAP["/err/response.html"]["path"], "r") as f: +# data = f.read() +# +# # format error response +# data = data.format(status_code=get_response_code(status_code).decode("ascii")) +# +# # send response to the client +# HTTPServer._send(client, status_code, data.encode("ascii")) def main(): diff --git a/src/APIv1.py b/src/APIv1.py index db8c4f3..f163c5b 100644 --- a/src/APIv1.py +++ b/src/APIv1.py @@ -1,6 +1,6 @@ import random -from ssl import SSLSocket -from src.request import Request +from src.request import * +from src.status_code import * API_FILE_RANDOM_MIN_SIZE_LIMIT = 1 @@ -59,29 +59,27 @@ def decode_size(size: str) -> int: return size -def respond(client: SSLSocket, request: Request) -> tuple[int, bytes, dict]: +def api_call(client: SSLSocket, request: Request) -> Response: """ Respond to clients API request """ # decode API request - split_path = request.path.split("/") - api_level1 = split_path[2] - api_request = split_path[3] + split_path = request.path.split("/", maxsplit=16)[1:] # do something with it (oh god) - if api_level1 == "file": - if api_request == "random": + if len(split_path) > 1 and split_path[1] == "file": + if len(split_path) > 2 and split_path[2] == "random": # get size size_str = request.path_args.get("size", "16mib") size = decode_size(size_str) # check size if size < API_FILE_RANDOM_MIN_SIZE_LIMIT or size > API_FILE_RANDOM_MAX_SIZE_LIMIT: - return 400, b'', {} + return Response(b'', STATUS_CODE_BAD_REQUEST) - return 200, random_data_gen(size), {} + return Response(random_data_gen(size), STATUS_CODE_OK, compress=False) else: - return 400, b'', {} + return Response(b'', STATUS_CODE_BAD_REQUEST) else: - return 400, b'', {} + return Response(b'', STATUS_CODE_BAD_REQUEST) diff --git a/src/request.py b/src/request.py index 0437dfd..003783a 100644 --- a/src/request.py +++ b/src/request.py @@ -1,3 +1,8 @@ +from typing import Any +from ssl import SSLSocket +from src.status_code import StatusCode + + class Request: """ Just a request @@ -54,3 +59,23 @@ class Request: def __str__(self): return '\n'.join([f"{key}: {val}" for key, val in self.__dict__.items()]) + + +class Response: + """ + Server response + """ + + def __init__(self, data: bytes, status: StatusCode, headers: dict[str, Any] = None, **kwargs): + """ + + :param data: response data + :param status: response status code + :param headers: headers to include + :param kwarg: compress - whether to compress data or not + """ + + self.data: bytes = data + self.status: StatusCode = status + self.headers: dict[str, Any] = headers if headers is not None else dict() + self.compress: bool = kwargs.get("compress", True) diff --git a/src/status_code.py b/src/status_code.py new file mode 100644 index 0000000..a63712c --- /dev/null +++ b/src/status_code.py @@ -0,0 +1,34 @@ +class StatusCode: + """ + HTML status code + """ + + def __init__(self, code: int, message: str): + self._code: int = code + self._message: str = message + + def __bytes__(self): + return f"{self._code} {self._message}".encode("ascii") + + def __str__(self): + return f"{self._code} {self._message}" + + @property + def code(self): + return self._code + + @property + def message(self): + return self._message + + +# Status codes! +STATUS_CODE_OK = StatusCode(200, "OK") +STATUS_CODE_BAD_REQUEST = StatusCode(400, "Bad Request") +STATUS_CODE_UNAUTHORIZED = StatusCode(401, "Unauthorized") +STATUS_CODE_FORBIDDEN = StatusCode(403, "Forbidden") +STATUS_CODE_NOT_FOUND = StatusCode(404, "Not Found") +STATUS_CODE_PAYLOAD_TOO_LARGE = StatusCode(413, "Payload Too Large") +STATUS_CODE_URI_TOO_LONG = StatusCode(414, "URI Too Long") +STATUS_CODE_IM_A_TEAPOT = StatusCode(418, "I'm a teapot") # I followed mozilla's dev page, it was there +STATUS_CODE_FUNNY_NUMBER = StatusCode(6969, "UwU") |