diff --git a/tools/wptserve/wptserve/server.py b/tools/wptserve/wptserve/server.py index 82464e5e8c7347..1bd9709b23c277 100644 --- a/tools/wptserve/wptserve/server.py +++ b/tools/wptserve/wptserve/server.py @@ -322,33 +322,13 @@ def handle_error(self, request, client_address): self.logger.info(msg) -class BaseWebTestRequestHandler(http.server.BaseHTTPRequestHandler): +class BaseWebTestRequestHandler(socketserver.StreamRequestHandler): """RequestHandler for WebTestHttpd""" def __init__(self, *args, **kwargs): self.logger = get_logger() super().__init__(*args, **kwargs) - def finish_handling_h1(self, request_line_is_valid): - - self.server.rewriter.rewrite(self) - - with Request(self) as request: - response = Response(self, request) - - if request.method == "CONNECT": - self.handle_connect(response) - return - - if not request_line_is_valid: - response.set_error(414) - response.write() - return - - self.logger.debug(f"{request.method} {request.request_path}") - handler = self.server.router.get_handler(request) - self.finish_handling(request, response, handler) - def finish_handling(self, request, response, handler): # If the handler we used for the request had a non-default base path # set update the doc_root of the request to reflect this @@ -401,52 +381,9 @@ def finish_handling(self, request, response, handler): # Ensure that the whole request has been read from the socket request.raw_input.read() - def handle_connect(self, response): - self.logger.debug("Got CONNECT") - response.status = 200 - response.write() - if self.server.encrypt_after_connect: - self.logger.debug("Enabling SSL for connection") - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - ssl_context.load_cert_chain(keyfile=self.server.key_file, certfile=self.server.certificate) - self.request = ssl_context.wrap_socket(self.connection, - server_side=True) - self.setup() - return - - def log_request(self, code="-", size="-"): - if isinstance(code, http.HTTPStatus): - code = code.value - - self.logger.debug( - "{} - - [{}] {!r} {!s} {!s}".format( - self.address_string(), - self.log_date_time_string(), - self.requestline, - code, - size, - ) - ) - - def log_error(self, format, *args): - self.logger.error( - "{} - - [{}] {}".format( - self.address_string(), self.log_date_time_string(), format % args - ) - ) - - def log_message(self, format, *args): - self.logger.info( - "{} - - [{}] {}".format( - self.address_string(), self.log_date_time_string(), format % args - ) - ) - class Http2WebTestRequestHandler(BaseWebTestRequestHandler): - protocol_version = "HTTP/2.0" - - def handle_one_request(self): + def handle(self): """ This is the main HTTP/2 Handler. @@ -806,7 +743,8 @@ def __init__(self, handler, req_frame, rfile): self.request = handler.request self.conn = handler.conn -class Http1WebTestRequestHandler(BaseWebTestRequestHandler): + +class Http1WebTestRequestHandler(BaseWebTestRequestHandler, http.server.BaseHTTPRequestHandler): protocol_version = "HTTP/1.1" def handle_one_request(self): @@ -852,6 +790,68 @@ def get_request_line(self): self.close_connection = True return True + def finish_handling_h1(self, request_line_is_valid): + + self.server.rewriter.rewrite(self) + + with Request(self) as request: + response = Response(self, request) + + if request.method == "CONNECT": + self.handle_connect(response) + return + + if not request_line_is_valid: + response.set_error(414) + response.write() + return + + self.logger.debug(f"{request.method} {request.request_path}") + handler = self.server.router.get_handler(request) + self.finish_handling(request, response, handler) + + def handle_connect(self, response): + self.logger.debug("Got CONNECT") + response.status = 200 + response.write() + if self.server.encrypt_after_connect: + self.logger.debug("Enabling SSL for connection") + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ssl_context.load_cert_chain(keyfile=self.server.key_file, certfile=self.server.certificate) + self.request = ssl_context.wrap_socket(self.connection, + server_side=True) + self.setup() + return + + def log_request(self, code="-", size="-"): + if isinstance(code, http.HTTPStatus): + code = code.value + + self.logger.debug( + "{} - - [{}] {!r} {!s} {!s}".format( + self.address_string(), + self.log_date_time_string(), + self.requestline, + code, + size, + ) + ) + + def log_error(self, format, *args): + self.logger.error( + "{} - - [{}] {}".format( + self.address_string(), self.log_date_time_string(), format % args + ) + ) + + def log_message(self, format, *args): + self.logger.info( + "{} - - [{}] {}".format( + self.address_string(), self.log_date_time_string(), format % args + ) + ) + + class WebTestHttpd: """ :param host: Host from which to serve (default: 127.0.0.1)