diff options
| author | UltraQbik <no1skill@yandex.ru> | 2024-08-23 15:44:06 +0300 |
|---|---|---|
| committer | UltraQbik <no1skill@yandex.ru> | 2024-08-23 15:44:06 +0300 |
| commit | 46a76eff285459d25f142d2ce2628425d4e69e94 (patch) | |
| tree | f8019e5cd8e53a72541102107521106478330bbe | |
| parent | c95930a8d1e9724a0b720c6fe4e648d0ee267800 (diff) | |
| download | httpy-46a76eff285459d25f142d2ce2628425d4e69e94.tar.gz httpy-46a76eff285459d25f142d2ce2628425d4e69e94.zip | |
More graceful shutdown
There is a weird issue when a user tries to load a page, doesn't get a response (because that wasn't implemented yet), and when the webserver is attempted to be shutdown, it still tries to fetch user's request
| -rw-r--r-- | main.py | 480 |
1 files changed, 231 insertions, 249 deletions
diff --git a/main.py b/main.py index 0ab634f..4edd2ba 100644 --- a/main.py +++ b/main.py @@ -4,16 +4,16 @@ The mighty silly webserver written in python for no good reason import ssl +import time import gzip import socket import brotli import signal import threading from src import APIv1 -from src.socks import * +from src.status_code import * from src.request import Request from src.minimizer import minimize_html -from src.status_code import * # path mapping @@ -56,12 +56,11 @@ class HTTPServer: self.sock: ssl.SSLSocket = context.wrap_socket( socket.socket(socket.AF_INET, socket.SOCK_STREAM), server_side=True) - self.packet_size: int = packet_size + self.buf_len: int = packet_size self.port: int = port # client thread list and server thread self.client_threads: list[threading.Thread] = [] - self.server_thread: threading.Thread | None = None # add signaling self.stop_event = threading.Event() @@ -76,10 +75,6 @@ class HTTPServer: self.stop_event.set() for thread in self.client_threads: thread.join() - self.server_thread.join() - - # close server socket - self.sock.close() def start(self): """ @@ -94,271 +89,258 @@ class HTTPServer: # listen and respond handler while not self.stop_event.is_set(): # accept new client - client = ssl_sock_accept(self.sock)[0] + client = self._accept() + if client is None: + break # create thread for new client and append it to the list - th = threading.Thread(target=lambda x: x, args=[client]) # TODO: this line + th = threading.Thread(target=self._client_thread, args=[client]) self.client_threads.append(th) th.start() + # close server socket + self.sock.close() -class HTTPServer: - """ - The mighty HTTP server - """ - - def __init__(self, *, port: int, packet_size: int = 2048): - # ssl context - self.context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - self.context.load_cert_chain( - certfile=r"C:\Certbot\live\qubane.ddns.net\fullchain.pem", # use your own path here - keyfile=r"C:\Certbot\live\qubane.ddns.net\privkey.pem" # here too - ) - self.context.check_hostname = False - - # sockets - self.socket: ssl.SSLSocket = self.context.wrap_socket( - socket.socket(socket.AF_INET, socket.SOCK_STREAM), server_side=True) - self.packet_size: int = packet_size - self.bind_port: int = port - - # list of connected clients - self.clients: list[socket.socket] = [] - - def interrupt(self, *args, **kwargs): - """ - Interrupts the web server - """ - - self.socket.close() - for client in self.clients: - client.close() - - def start(self): - """ - Method to start the web server - """ - - # setup signaling - - # bind and start listening to port - self.socket.bind(('', self.bind_port)) - self.socket.listen() - self.socket.setblocking(False) - - # start listening - self._listen_thread() - - def _listen_thread(self): + def _client_thread(self, client: ssl.SSLSocket): """ - Listening for new connections + Handles client's requests + :param client: client ssl socket """ - # start listening - while True: - # try to accept new connection + while not self.stop_event.is_set(): try: - client = ssl_sock_accept(self.socket)[0] + # get client's request + request = self._recv_request(client) + if request is None: + break - # if socket was closed -> break + print(request, end="\n\n") except OSError as e: print(e) - print("Closed.") - break - - # else append to client list and create new task - threading.Thread(target=self.client_handle, args=[client]).start() - - def client_handle(self, client: ssl.SSLSocket): - """ - Handles client's connection - """ - - while True: - # receive request from client - raw_request = self._recvall(client) - - if raw_request == b'': break - # decode request - request: Request = Request.create(raw_request) - - # # log request - # async with aiofiles.open("logs.log", "a") as f: - # await f.write(f"IP: {client.getpeername()[0]}\n{request}\n\n") - - threading.Thread(target=self.handle_request, args=[client, request]).start() - - def handle_request(self, client: ssl.SSLSocket, request: Request): - # handle requests - try: - match request.type: - case "GET": - self.handle_get_request(client, request) - case _: - pass - - # break on exception - except Exception as e: - print(e) - - # # close connection (stop page loading) - # self._close_client(client) - - @staticmethod - def handle_get_request(client: ssl.SSLSocket, request: Request): - """ - Handles user's GET request - """ - - # get available compression methods - compressions = [x.strip() for x in getattr(request, "Accept-Encoding", "").split(",")] - - # check if request path is in the PATH_MAP - if request.path in PATH_MAP: - # if it is -> return file from that path - with open(PATH_MAP[request.path]["path"], "rb") as f: - data = f.read() - - # pre-compress data for HTML files - if PATH_MAP[request.path]["path"][-4:] == "html": - data = minimize_html(data) - - # add brotli compression header (if supported) - headers = {} - if "br" in compressions: - headers["Content-Encoding"] = "br" - - # else add gzip compression (if supported) - elif "gzip" in compressions: - headers["Content-Encoding"] = "gzip" - - # send 200 response with the file to the client - HTTPServer._send(client, 200, data, headers) - - # if it's an API request - elif (api_version := request.path.split("/")[1]) in API_VERSIONS: - data = b'' - headers = {} - match api_version: - case "APIv1": - status, data, headers = APIv1.respond(client, request) - case _: - status = 400 - - # if status is not 200 -> send bad response - if status != 200: - HTTPServer._bad_response(client, status) - return - - # send data if no error - HTTPServer._send(client, status, data, headers) - - # in case of error, return error page - else: - HTTPServer._bad_response(client, 404) - - @staticmethod - def _bad_response(client: ssl.SSLSocket, status_code: int): - """ - Sends a bad response page to the client. - :param client: client - :param status_code: status code - """ - - with open(I_PATH_MAP["/err/response.html"]["path"], "r") as f: - data = f.read() - - # format error response - data = data.format(status_code=get_response_code(status_code).decode("ascii")) - - # send response to the client - HTTPServer._send(client, status_code, data.encode("ascii")) - - @staticmethod - def _send(client: ssl.SSLSocket, response: int, data: bytes = None, headers: dict[str, str] = None): - """ - Sends client response code + headers + data - :param client: client - :param response: response code - :param data: data - :param headers: headers to include - """ + # close the connection once stop even was set or an error occurred + client.close() - # if data was not given - if data is None: - data = bytes() - - # if headers were not given - if headers is None: - headers = dict() - - # check for 'content-encoding' header - if headers.get("Content-Encoding") == "br": - data = brotli.compress(data) - - elif headers.get("Content-Encoding") == "gzip": - data = gzip.compress(data) - - # add 'Content-Length' header if not present - if headers.get("Content-Length") is None: - headers["Content-Length"] = len(data) - - # format headers - byte_header = bytearray() - for key, value in headers.items(): - byte_header += f"{key}: {value}\r\n".encode("ascii") - - # send response to the client - client.sendall( - b'HTTP/1.1 ' + - get_response_code(response) + - b'\r\n' + - byte_header + # if empty, we'll just get b'\r\n\r\n' - b'\r\n' + - data - ) - - def _recvall(self, client: ssl.SSLSocket) -> bytes: + def _recv_request(self, client: ssl.SSLSocket) -> Request | None: """ - Receive All (just receives the whole message, instead of 1 packet at a time) + Receive request from client + :return: request + :raises: anything that recv raises """ - # create message buffer - buffer: bytearray = bytearray() - - # start fetching the message - while True: - try: - # fetch packet - message = ssl_sock_recv(client, self.packet_size) - except OSError: - break - - # that happens when user stops loading the page - if message == b'': - break - - # append fetched message to the buffer - buffer += message - - # check for EoF + buffer = bytearray() + while not self.stop_event.is_set(): + buffer += client.recv(self.buf_len) if buffer[-4:] == b'\r\n\r\n': - # return the received message - return buffer + return Request.create(buffer) + return None - # return empty buffer on error - return b'' - - def _close_client(self, client: socket.socket): + def _accept(self) -> ssl.SSLSocket | None: """ - Closes a client + socket.accept, but for more graceful closing """ - client.close() - if client in self.clients: - self.clients.remove(client) + while not self.stop_event.is_set(): + try: + return self.sock.accept()[0] + except BlockingIOError: + time.sleep(0.001) + return None + + +# class HTTPServer: +# +# def client_handle(self, client: ssl.SSLSocket): +# """ +# Handles client's connection +# """ +# +# while True: +# # receive request from client +# raw_request = self._recvall(client) +# +# if raw_request == b'': +# break +# +# # decode request +# request: Request = Request.create(raw_request) +# +# # # log request +# # async with aiofiles.open("logs.log", "a") as f: +# # await f.write(f"IP: {client.getpeername()[0]}\n{request}\n\n") +# +# threading.Thread(target=self.handle_request, args=[client, request]).start() +# +# def handle_request(self, client: ssl.SSLSocket, request: Request): +# # handle requests +# try: +# match request.type: +# case "GET": +# self.handle_get_request(client, request) +# case _: +# pass +# +# # break on exception +# except Exception as e: +# print(e) +# +# # # close connection (stop page loading) +# # self._close_client(client) +# +# @staticmethod +# def handle_get_request(client: ssl.SSLSocket, request: Request): +# """ +# Handles user's GET request +# """ +# +# # get available compression methods +# compressions = [x.strip() for x in getattr(request, "Accept-Encoding", "").split(",")] +# +# # check if request path is in the PATH_MAP +# if request.path in PATH_MAP: +# # if it is -> return file from that path +# with open(PATH_MAP[request.path]["path"], "rb") as f: +# data = f.read() +# +# # pre-compress data for HTML files +# if PATH_MAP[request.path]["path"][-4:] == "html": +# data = minimize_html(data) +# +# # add brotli compression header (if supported) +# headers = {} +# if "br" in compressions: +# headers["Content-Encoding"] = "br" +# +# # else add gzip compression (if supported) +# elif "gzip" in compressions: +# headers["Content-Encoding"] = "gzip" +# +# # send 200 response with the file to the client +# HTTPServer._send(client, 200, data, headers) +# +# # if it's an API request +# elif (api_version := request.path.split("/")[1]) in API_VERSIONS: +# data = b'' +# headers = {} +# match api_version: +# case "APIv1": +# status, data, headers = APIv1.respond(client, request) +# case _: +# status = 400 +# +# # if status is not 200 -> send bad response +# if status != 200: +# HTTPServer._bad_response(client, status) +# return +# +# # send data if no error +# HTTPServer._send(client, status, data, headers) +# +# # in case of error, return error page +# else: +# HTTPServer._bad_response(client, 404) +# +# @staticmethod +# def _bad_response(client: ssl.SSLSocket, status_code: int): +# """ +# Sends a bad response page to the client. +# :param client: client +# :param status_code: status code +# """ +# +# with open(I_PATH_MAP["/err/response.html"]["path"], "r") as f: +# data = f.read() +# +# # format error response +# data = data.format(status_code=get_response_code(status_code).decode("ascii")) +# +# # send response to the client +# HTTPServer._send(client, status_code, data.encode("ascii")) +# +# @staticmethod +# def _send(client: ssl.SSLSocket, response: int, data: bytes = None, headers: dict[str, str] = None): +# """ +# Sends client response code + headers + data +# :param client: client +# :param response: response code +# :param data: data +# :param headers: headers to include +# """ +# +# # if data was not given +# if data is None: +# data = bytes() +# +# # if headers were not given +# if headers is None: +# headers = dict() +# +# # check for 'content-encoding' header +# if headers.get("Content-Encoding") == "br": +# data = brotli.compress(data) +# +# elif headers.get("Content-Encoding") == "gzip": +# data = gzip.compress(data) +# +# # add 'Content-Length' header if not present +# if headers.get("Content-Length") is None: +# headers["Content-Length"] = len(data) +# +# # format headers +# byte_header = bytearray() +# for key, value in headers.items(): +# byte_header += f"{key}: {value}\r\n".encode("ascii") +# +# # send response to the client +# client.sendall( +# b'HTTP/1.1 ' + +# get_response_code(response) + +# b'\r\n' + +# byte_header + # if empty, we'll just get b'\r\n\r\n' +# b'\r\n' + +# data +# ) +# +# def _recvall(self, client: ssl.SSLSocket) -> bytes: +# """ +# Receive All (just receives the whole message, instead of 1 packet at a time) +# """ +# +# # create message buffer +# buffer: bytearray = bytearray() +# +# # start fetching the message +# while True: +# try: +# # fetch packet +# message = ssl_sock_recv(client, self.packet_size) +# except OSError: +# break +# +# # that happens when user stops loading the page +# if message == b'': +# break +# +# # append fetched message to the buffer +# buffer += message +# +# # check for EoF +# if buffer[-4:] == b'\r\n\r\n': +# # return the received message +# return buffer +# +# # return empty buffer on error +# return b'' +# +# def _close_client(self, client: socket.socket): +# """ +# Closes a client +# """ +# +# client.close() +# if client in self.clients: +# self.clients.remove(client) def main(): |