diff --git a/pyk/src/pyk/rpc/rpc.py b/pyk/src/pyk/rpc/rpc.py index fa0f5c65e76..7a11b6bcc80 100644 --- a/pyk/src/pyk/rpc/rpc.py +++ b/pyk/src/pyk/rpc/rpc.py @@ -3,10 +3,11 @@ import json import logging from abc import ABC, abstractmethod +from collections.abc import Iterator from dataclasses import dataclass from functools import partial from http.server import BaseHTTPRequestHandler, HTTPServer -from typing import TYPE_CHECKING, Any, Final, NamedTuple +from typing import TYPE_CHECKING, NamedTuple from typing_extensions import Protocol @@ -15,6 +16,7 @@ if TYPE_CHECKING: from collections.abc import Callable from pathlib import Path + from typing import Any, Final _LOGGER: Final = logging.getLogger(__name__) @@ -86,7 +88,7 @@ class JsonRpcBatchRequest(NamedTuple): class JsonRpcResult(ABC): @abstractmethod - def encode(self) -> bytes: ... + def encode(self) -> Iterator[bytes]: ... @dataclass(frozen=True) @@ -96,7 +98,7 @@ class JsonRpcError(JsonRpcResult): message: str id: str | int | None - def to_json(self) -> dict[str, Any]: + def wrap_response(self) -> dict[str, Any]: return { 'jsonrpc': JsonRpcServer.JSONRPC_VERSION, 'error': { @@ -106,8 +108,8 @@ def to_json(self) -> dict[str, Any]: 'id': self.id, } - def encode(self) -> bytes: - return json.dumps(self.to_json()).encode('ascii') + def encode(self) -> Iterator[bytes]: + yield json.dumps(self.wrap_response()).encode('utf-8') @dataclass(frozen=True) @@ -115,23 +117,31 @@ class JsonRpcSuccess(JsonRpcResult): payload: Any id: Any - def to_json(self) -> dict[str, Any]: - return { - 'jsonrpc': JsonRpcServer.JSONRPC_VERSION, - 'result': self.payload, - 'id': self.id, - } - - def encode(self) -> bytes: - return json.dumps(self.to_json()).encode('ascii') + def encode(self) -> Iterator[bytes]: + id_encoded = json.dumps(self.id) + version_encoded = json.dumps(JsonRpcServer.JSONRPC_VERSION) + yield f'{{"jsonrpc": {version_encoded}, "id": {id_encoded}, "result": '.encode() + if isinstance(self.payload, Iterator): + yield from self.payload + else: + yield json.dumps(self.payload).encode('utf-8') + yield b'}' @dataclass(frozen=True) class JsonRpcBatchResult(JsonRpcResult): results: tuple[JsonRpcError | JsonRpcSuccess, ...] - def encode(self) -> bytes: - return json.dumps([result.to_json() for result in self.results]).encode('ascii') + def encode(self) -> Iterator[bytes]: + yield b'[' + first = True + for result in self.results: + if not first: + yield b',' + else: + first = False + yield from result.encode() + yield b']' class JsonRpcRequestHandler(BaseHTTPRequestHandler): @@ -143,8 +153,10 @@ def __init__(self, methods: dict[str, JsonRpcMethod], *args: Any, **kwargs: Any) def _send_response(self, response: JsonRpcResult) -> None: self.send_response_headers() - response_bytes = response.encode() - self.wfile.write(response_bytes) + response_body = response.encode() + for chunk in response_body: + self.wfile.write(chunk) + self.wfile.flush() def send_response_headers(self) -> None: self.send_response(200) diff --git a/pyk/src/tests/integration/test_json_rpc.py b/pyk/src/tests/integration/test_json_rpc.py index 61a5367df1b..63cb3798587 100644 --- a/pyk/src/tests/integration/test_json_rpc.py +++ b/pyk/src/tests/integration/test_json_rpc.py @@ -15,6 +15,7 @@ from pyk.testing import KRunTest if TYPE_CHECKING: + from collections.abc import Iterator from typing import Any @@ -154,6 +155,7 @@ def __init__(self, options: ServeRpcOptions) -> None: self.register_method('set_x', self.exec_set_x) self.register_method('set_y', self.exec_set_y) self.register_method('add', self.exec_add) + self.register_method('streaming', self.exec_streaming) def exec_get_x(self) -> int: return self.x @@ -170,6 +172,11 @@ def exec_set_y(self, n: int) -> None: def exec_add(self) -> int: return self.x + self.y + def exec_streaming(self) -> Iterator[bytes]: + yield b'{' + yield b'"foo": "bar"' + yield b'}' + class TestJsonRPCServer(KRunTest): @@ -221,6 +228,15 @@ def wait_until_ready() -> None: assert len(res) == 3 assert res[2]['result'] == 1 + 2 + res = rpc_client.request('streaming', []) + assert res == {'foo': 'bar'} + + res = rpc_client.batch_request(('streaming', []), ('set_x', [10]), ('streaming', [])) + assert len(res) == 3 + assert res[0]['result'] == {'foo': 'bar'} + assert res[1]['result'] == None + assert res[2]['result'] == {'foo': 'bar'} + server.shutdown() thread.join()