diff options
| -rw-r--r-- | main.py | 287 |
1 files changed, 96 insertions, 191 deletions
diff --git a/main.py b/main.py index 4078d99..7889ef8 100644 --- a/main.py +++ b/main.py @@ -13,250 +13,155 @@ import aiofiles import threading -# some constants -PACKET_SIZE = 2048 -PORT = 13700 # using random port cuz why not - - -# response status codes -RESPONSE = { - 200: b'OK', - 400: b'Bad Request', - 401: b'Unauthorized', - 403: b'Forbidden', - 404: b'Not Found', - 6969: b'UwU' -} - - -def get_response(code: int) -> bytes: - return str(code).encode("ascii") + RESPONSE.get(code, b':(') - - -def is_alive(sock: socket.socket) -> bool: +class Request: """ - Checks if the socket is still alive - :param sock: socket - :return: boolean (true if socket is alive, false otherwise) + Just a request """ - return getattr(sock, '_closed', False) + def __init__(self): + self.type: str = "" + self.path: str = "" + + @staticmethod + def create(raw_request: bytes): + """ + Creates self class from raw request + :param raw_request: bytes + :return: self + """ + + # new request + request = Request() -def decode_request(req: str) -> dict[str, str | list | None]: - # request dictionary - request = dict() + # fix type and path + request.type = raw_request[:raw_request.find(b' ')].decode("ascii") + request.path = raw_request[len(request.type)+1:raw_request.find(b' ', len(request.type)+1)].decode("ascii") - # request type and path - request["type"] = req[:6].split(" ")[0] - request["path"] = req[len(request["type"]) + 1:req.find("\r\n")].split(" ")[0] + # decode headers + for raw_header in raw_request.split(b'\r\n'): + if len(pair := raw_header.decode("ascii").split(":")) == 2: + key, val = pair + val = val.strip() - # decode other headers - for line in req.split("\r\n")[1:]: - if len(split := line.split(":")) == 2: - key, value = split - value = value.lstrip(" ") + # set attribute to key value pair + setattr(request, key, val) - # write key value pair - request[key] = value + # return request + return request - return request + def __str__(self): + return '\n'.join([f"{key}: {val}" for key, val in self.__dict__.items()]) -class HTMLServer: +class HTTPServer: """ - The very cool webserver + The mighty HTTP server """ - def __init__(self): - self.sock: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.clients: list[socket.socket] = [] + def __init__(self, *, port: int, packet_size: int = 2048): + self.bind_port: int = port + self.packet_size: int = packet_size + self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - # list of allowed paths - self.allowed_paths: dict[str, dict] = { - "/": {"path": "www/index.html", "encoding": "css/html"}, - "/robots.txt": {"path": "www/robots.txt", "encoding": "text"}, - "/favicon.ico": {"path": "www/favicon.ico", "encoding": "bin"}, - "/css/styles.css": {"path": None, "encoding": "css/html"}, - } + self.clients: list[socket.socket] = [] - def run(self): + def start(self): """ - Function that starts the webserver + Method to start the web server """ - # bind the server to port and start listening - self.sock.bind(('', PORT)) - self.sock.listen() + # bind and start listening to port + self.socket.bind(('', self.bind_port)) + self.socket.listen() - # start running thread - t = threading.Thread(target=self._run, daemon=True) - t.start() + # start the listening thread + threading.Thread(target=self._listen_thread, daemon=True).start() # keep alive try: while True: + # sleep 100 ms, otherwise the while true will 100% one of your cores time.sleep(0.1) + + # shutdown on keyboard interrupt except KeyboardInterrupt: - self.sock.close() + self.socket.close() print("Closed.") - def _run(self): + def _listen_thread(self): """ - Run function for threads - :return: + Listening for new connections """ - asyncio.run(self.server_listener()) + # run the coroutine + asyncio.run(self._thread_listen_coro()) + + async def _thread_listen_coro(self): + while True: + # accept new connection, add to client list and start listening to it + client, _ = self.socket.accept() + self.clients.append(client) + await self.client_handle(client) - async def server_listener(self): + async def client_handle(self, client: socket.socket): """ - Listens for new connections, and handles them + Handles client's connection """ while True: - client, address = self.sock.accept() - self.clients.append(client) - await self.server_handle(client) + # receive request from client + raw_request = self._recvall(client) + + # decode request + request: Request = Request.create(raw_request) + + self._close_client(client) + break - async def server_handle(self, client: socket.socket): + def _close_client(self, client: socket.socket): """ - Handles the actual connections (clients) - :param client: connection socket + Closes a client """ - # message buffer - buffer = bytearray() + client.close() + if client in self.clients: + self.clients.remove(client) + + def _recvall(self, client: socket.socket) -> bytes: + """ + Receive All (just receives the whole message, instead of 1 packet at a time) + """ + + # create message buffer + buffer: bytearray = bytearray() + + # start fetching the message while True: - # try to fetch a message - # die otherwise try: - message = client.recv(PACKET_SIZE) + # fetch packet + message = client.recv(self.packet_size) except OSError: break + + # that happens when user stops loading the page if message == b'': break - # append packet to buffer + # append fetched message to the buffer buffer += message - # check EoF (2 blank lines) + # check for EoF if buffer[-4:] == b'\r\n\r\n': - # text buffer - text_buffer = buffer.decode("ascii") - - # decode request - request = decode_request(text_buffer) - - print(f"[{request['type']}] Request from client '{client.getpeername()[0]}'") - - # log that request - async with aiofiles.open("logs.log", "a") as f: - await f.write( - json.dumps( - { - "client": client.getpeername()[0], - "request": request - }, - indent=2 - ) + "\n" - ) - - # handle the request - if request["type"] == "GET": - await self.handle_get_request(client, request) - else: - await self.handle_other_request(client) - - # clear buffer - buffer.clear() - client.close() - self.clients.remove(client) - - async def handle_get_request(self, client: socket.socket, req: dict[str, str | None]): - # check if the path is too long - if len(req["path"]) > 255: - response = get_response(400) - data = b'' - - # if it's yandex - elif req.get("from") == "support@search.yandex.ru": - response = get_response(404) - data = b'Nothing...' - - # check UwU path - elif req["path"] == "/UwU": - response = get_response(6969) - data = b'<h1>' + b'UwU ' * 2000 + b'</h1>' - - # otherwise check access - elif req["path"] in self.allowed_paths: - # get path - path = self.allowed_paths[req["path"]]["path"] - - # if path is None, return generic filepath - if path is None: - path = req["path"][1:] - - # check encoding - if self.allowed_paths[req["path"]]["encoding"] == "css/html": - # return text data - async with aiofiles.open(path, "r") as f: - data = htmlmin.minify(await f.read()).encode("ascii") - else: - # return binary / text data - async with aiofiles.open(path, "rb") as f: - data = await f.read() - response = get_response(200) - - # in any other case - else: - response = get_response(403) - data = b'Idk what you are trying to do here :/' - - # make headers - headers = {} - - # check if compression is supported - if req.get("Accept-Encoding"): - encoding_list = [enc.lstrip(" ") for enc in req["Accept-Encoding"].split(",")] - - # check for gzip, and add to headers if present - if "gzip" in encoding_list: - headers["Content-Encoding"] = "gzip" - - # send response - await self.send(client, response, data, headers) - client.close() - - async def handle_other_request(self, client: socket.socket): - # just say 'no' - await self.send(client, get_response(403), b'No. Don\'t do that, that\'s cringe') - client.close() - - async def send(self, client: socket.socket, response: bytes, data: bytes, headers: dict[str, str] | None = None): - # construct headers - formatted_headers = b'' - if headers is not None: - formatted_headers = "".join([f"{key}: {val}\r\n" for key, val in headers.items()]).encode("ascii") - - # check for compression - if headers.get("Content-Encoding") == "gzip": - # compress data - data = gzip.compress(data) - - # construct message - if formatted_headers == b'': - message = b'HTTP/1.1 ' + response + b'\r\n\r\n' + data - else: - message = b'HTTP/1.1 ' + response + b'\r\n' + formatted_headers + b'\r\n' + data + # return the received message + return buffer - # send message to client - client.sendall(message) + # return empty buffer on error + return b'' def main(): - server = HTMLServer() - server.run() + server = HTTPServer(port=13701) + server.start() if __name__ == '__main__': |