diff --git a/py3langid/langid.py b/py3langid/langid.py index 9f550ba..2358ea6 100755 --- a/py3langid/langid.py +++ b/py3langid/langid.py @@ -286,12 +286,19 @@ def rank_path(self, path): class NumpyEncoder(json.JSONEncoder): """ Custom encoder for numpy data types """ - def default(self, obj): - if isinstance(obj, np.float32): - return float(obj) # Convert float32 to native float - if isinstance(obj, np.ndarray): - return obj.tolist() # Convert arrays to list - return json.JSONEncoder.default(self, obj) + def default(self, o): + if isinstance(o, np.float32): + return float(o) # Convert float32 to native float + if isinstance(o, np.ndarray): + return o.tolist() # Convert arrays to list + return json.JSONEncoder.default(self, o) + + +METHODS = { + 'detect': lambda data: {'language': classify(data)[0], 'confidence': classify(data)[1]}, + 'rank': lambda data: rank(data) +} + def application(environ, start_response): """ @@ -304,59 +311,51 @@ def application(environ, start_response): # Catch shift_path_info's failure to handle empty paths properly path = '' - if path in {'detect', 'rank'}: - data = None + if path not in METHODS: + return _return_response(start_response, 404, None, 'Not found') - # Extract the data component from different access methods - if environ['REQUEST_METHOD'] == 'PUT': - data = environ['wsgi.input'].read(int(environ['CONTENT_LENGTH'])) - elif environ['REQUEST_METHOD'] == 'GET': - try: - data = parse_qs(environ['QUERY_STRING'])['q'][0] - except KeyError: - # No query, provide a null response. - status = '200 OK' # HTTP Status - response = { - 'responseData': None, - 'responseStatus': 200, - 'responseDetails': None, - } - elif environ['REQUEST_METHOD'] == 'POST': - input_string = environ['wsgi.input'].read(int(environ['CONTENT_LENGTH'])) + data = _get_data(environ) + if data is None: + if environ['REQUEST_METHOD'] == 'GET' and 'QUERY_STRING' not in environ: + return _return_response(start_response, 400, None, 'Missing query string') + return _return_response(start_response, 405, None, f"{environ['REQUEST_METHOD']} not allowed") + + response_data = METHODS[path](data) + return _return_response(start_response, 200, response_data, None) + + +def _get_data(environ): + if environ['REQUEST_METHOD'] in ['PUT', 'POST']: + data = environ['wsgi.input'].read(int(environ['CONTENT_LENGTH'])) + if environ['REQUEST_METHOD'] == 'POST': try: - data = parse_qs(input_string)['q'][0] + data = parse_qs(data)['q'][0] except KeyError: - # No key 'q', process the whole input instead - data = input_string - else: - # Unsupported method - status = '405 Method Not Allowed' # HTTP Status - response = { - 'responseData': None, - 'responseStatus': 405, - 'responseDetails': f"{environ['REQUEST_METHOD']} not allowed", - } - - if data is not None: - if path == 'detect': - pred, conf = classify(data) - response_data = {'language': pred, 'confidence': conf} - elif path == 'rank': - response_data = rank(data) - - status = '200 OK' # HTTP Status - response = { - 'responseData': response_data, - 'responseStatus': 200, - 'responseDetails': None, - } + pass + return data + if environ['REQUEST_METHOD'] == 'GET': + try: + return parse_qs(environ['QUERY_STRING'])['q'][0] + except KeyError: + pass + return None + + +STATUS_MESSAGES = { + 200: "OK", + 404: "Not Found", + 405: "Method Not Allowed" +} - else: - # Incorrect URL - status = '404 Not Found' - response = {'responseData': None, 'responseStatus': 404, 'responseDetails': 'Not found'} - headers = [('Content-type', 'text/javascript; charset=utf-8')] # HTTP Headers +def _return_response(start_response, status_code, response_data, response_details): + status = f"{status_code} {STATUS_MESSAGES.get(status_code, 'Unknown Status')}" + response = { + 'responseData': response_data, + 'responseStatus': status_code, + 'responseDetails': response_details, + } + headers = [('Content-type', 'text/javascript; charset=utf-8')] start_response(status, headers) return [json.dumps(response, cls=NumpyEncoder).encode('utf-8')] @@ -438,7 +437,7 @@ def _process(text): else: hostname = options.host - print("Listening on %s:%d" % (hostname, int(options.port))) + print(f"Listening on {hostname}:%{options.port}") print("Press Ctrl+C to exit") httpd = make_server(hostname, int(options.port), application) try: diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..81bca3b --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,109 @@ +import json + +from unittest.mock import MagicMock + +import pytest + +from py3langid.langid import application + + +@pytest.fixture +def mock_start_response(): + return MagicMock() + +def test_detect_put(mock_start_response): + environ = { + 'REQUEST_METHOD': 'PUT', + 'CONTENT_LENGTH': 10, + 'wsgi.input': MagicMock(read=lambda x: b'This is a test'), + 'PATH_INFO': '/detect' + } + response = application(environ, mock_start_response) + assert mock_start_response.call_args[0][0] == '200 OK' + assert json.loads(response[0].decode('utf-8'))['responseData']['language'] == 'en' + +def test_detect_get(mock_start_response): + environ = { + 'REQUEST_METHOD': 'GET', + 'QUERY_STRING': 'q=This+is+a+test', + 'PATH_INFO': '/detect' + } + response = application(environ, mock_start_response) + assert mock_start_response.call_args[0][0] == '200 OK' + assert json.loads(response[0].decode('utf-8'))['responseData']['language'] == 'en' + +def test_detect_post(mock_start_response): + environ = { + 'REQUEST_METHOD': 'POST', + 'CONTENT_LENGTH': 10, + 'wsgi.input': MagicMock(read=lambda x: b'q=Hello+World'), + 'PATH_INFO': '/detect' + } + response = application(environ, mock_start_response) + assert mock_start_response.call_args[0][0] == '200 OK' + assert json.loads(response[0].decode('utf-8'))['responseData']['language'] == 'en' + +def test_rank_put(mock_start_response): + environ = { + 'REQUEST_METHOD': 'PUT', + 'CONTENT_LENGTH': 10, + 'wsgi.input': MagicMock(read=lambda x: b'Hello World'), + 'PATH_INFO': '/rank' + } + response = application(environ, mock_start_response) + assert mock_start_response.call_args[0][0] == '200 OK' + assert json.loads(response[0].decode('utf-8'))['responseData'] is not None + +def test_rank_get(mock_start_response): + environ = { + 'REQUEST_METHOD': 'GET', + 'QUERY_STRING': 'q=Hello+World', + 'PATH_INFO': '/rank' + } + response = application(environ, mock_start_response) + assert mock_start_response.call_args[0][0] == '200 OK' + assert json.loads(response[0].decode('utf-8'))['responseData'] is not None + +def test_rank_post(mock_start_response): + environ = { + 'REQUEST_METHOD': 'POST', + 'CONTENT_LENGTH': 10, + 'wsgi.input': MagicMock(read=lambda x: b'q=Hello+World'), + 'PATH_INFO': '/rank' + } + response = application(environ, mock_start_response) + assert mock_start_response.call_args[0][0] == '200 OK' + assert json.loads(response[0].decode('utf-8'))['responseData'] is not None + +def test_invalid_method(mock_start_response): + environ = { + 'REQUEST_METHOD': 'DELETE', + 'PATH_INFO': '/detect' + } + response = application(environ, mock_start_response) + assert mock_start_response.call_args[0][0] == '405 Method Not Allowed' + +def test_invalid_path(mock_start_response): + environ = { + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': '/invalid' + } + response = application(environ, mock_start_response) + assert mock_start_response.call_args[0][0] == '404 Not Found' + +def test_empty_path(mock_start_response): + environ = { + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': '' + } + response = application(environ, mock_start_response) + assert mock_start_response.call_args[0][0] == '404 Not Found' + +def test_no_query_string(mock_start_response): + environ = { + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': '/detect' + } + response = application(environ, mock_start_response) + assert mock_start_response.call_args[0][0] == '400 Unknown Status' + assert json.loads(response[0].decode('utf-8'))['responseData'] is None