|
25 | 25 | from robyn.robyn import FunctionInfo, Headers, HttpMethod, Request, Response, WebSocketConnector, get_version |
26 | 26 | from robyn.router import MiddlewareRouter, MiddlewareType, Router, WebSocketRouter |
27 | 27 | from robyn.types import Directory |
28 | | -# WebSocket functionality is now handled directly in this module |
| 28 | +from robyn.ws import WebSocketDisconnect, WebSocketAdapter, create_websocket_decorator |
29 | 29 |
|
30 | 30 | __version__ = get_version() |
31 | 31 |
|
32 | 32 |
|
33 | | -class WebSocketDisconnect(Exception): |
34 | | - """Exception raised when a WebSocket connection is disconnected.""" |
35 | | - |
36 | | - def __init__(self, code: int = 1000, reason: str = ""): |
37 | | - self.code = code |
38 | | - self.reason = reason |
39 | | - super().__init__(f"WebSocket disconnected with code {code}: {reason}") |
40 | | - |
41 | | - |
42 | | -class WebSocketAdapter: |
43 | | - """ |
44 | | - Adapter class that provides a modern WebSocket interface |
45 | | - wrapping Robyn's WebSocketConnector for compatibility. |
46 | | - """ |
47 | | - |
48 | | - def __init__(self, websocket_connector: WebSocketConnector, message: str = None): |
49 | | - self._connector = websocket_connector |
50 | | - self._message = message |
51 | | - self._accepted = False |
52 | | - |
53 | | - async def accept(self): |
54 | | - """Accept the WebSocket connection (no-op in Robyn as it's auto-accepted)""" |
55 | | - self._accepted = True |
56 | | - |
57 | | - async def close(self, code: int = 1000): |
58 | | - """Close the WebSocket connection""" |
59 | | - self._connector.close() |
60 | | - |
61 | | - async def send_text(self, data: str): |
62 | | - """Send text data to the WebSocket""" |
63 | | - await self._connector.async_send_to(self._connector.id, data) |
64 | | - |
65 | | - async def send_bytes(self, data: bytes): |
66 | | - """Send binary data to the WebSocket""" |
67 | | - await self._connector.async_send_to(self._connector.id, data.decode('utf-8')) |
68 | | - |
69 | | - async def receive_text(self) -> str: |
70 | | - """Receive text data from the WebSocket""" |
71 | | - if self._message is not None: |
72 | | - msg = self._message |
73 | | - self._message = None # Consume the message |
74 | | - return msg |
75 | | - # Note: In a real implementation, this would need to handle the message queue |
76 | | - # For now, we return the current message if available |
77 | | - return "" |
78 | | - |
79 | | - async def receive_bytes(self) -> bytes: |
80 | | - """Receive binary data from the WebSocket""" |
81 | | - text = await self.receive_text() |
82 | | - return text.encode('utf-8') |
83 | | - |
84 | | - async def send_json(self, data): |
85 | | - """Send JSON data to the WebSocket""" |
86 | | - import json |
87 | | - await self.send_text(json.dumps(data)) |
88 | | - |
89 | | - async def receive_json(self): |
90 | | - """Receive JSON data from the WebSocket""" |
91 | | - import json |
92 | | - text = await self.receive_text() |
93 | | - return json.loads(text) if text else None |
94 | | - |
95 | | - @property |
96 | | - def query_params(self): |
97 | | - """Access query parameters""" |
98 | | - return self._connector.query_params |
99 | | - |
100 | | - @property |
101 | | - def path_params(self): |
102 | | - """Access path parameters""" |
103 | | - return getattr(self._connector, 'path_params', {}) |
104 | | - |
105 | | - @property |
106 | | - def headers(self): |
107 | | - """Access request headers""" |
108 | | - return getattr(self._connector, 'headers', {}) |
109 | | - |
110 | | - @property |
111 | | - def client(self): |
112 | | - """Client information""" |
113 | | - return getattr(self._connector, 'client', None) |
114 | 33 |
|
115 | 34 |
|
116 | 35 | def _normalize_endpoint(endpoint: str) -> str: |
@@ -385,74 +304,7 @@ async def on_connect(websocket): |
385 | 304 | async def on_close(websocket): |
386 | 305 | print("Disconnected") |
387 | 306 | """ |
388 | | - def decorator(handler): |
389 | | - # Dictionary to store handlers for this WebSocket endpoint |
390 | | - handlers = {} |
391 | | - |
392 | | - # Create the main message handler |
393 | | - async def message_handler(websocket_connector, msg, *args, **kwargs): |
394 | | - # Convert WebSocketConnector to modern WebSocket interface |
395 | | - websocket_adapter = WebSocketAdapter(websocket_connector, msg) |
396 | | - try: |
397 | | - # Call the user's handler |
398 | | - result = await handler(websocket_adapter) |
399 | | - return result if result is not None else "" |
400 | | - except WebSocketDisconnect: |
401 | | - # Handle disconnections gracefully |
402 | | - return "" |
403 | | - except Exception as e: |
404 | | - # Handle other connection errors gracefully |
405 | | - if "connection closed" in str(e).lower() or "websocket" in str(e).lower(): |
406 | | - return "" |
407 | | - raise e |
408 | | - |
409 | | - # Create FunctionInfo for the message handler |
410 | | - params = dict(inspect.signature(message_handler).parameters) |
411 | | - num_params = len(params) |
412 | | - is_async = asyncio.iscoroutinefunction(message_handler) |
413 | | - injected_dependencies = self.dependencies.get_dependency_map(self) |
414 | | - |
415 | | - handlers["message"] = FunctionInfo(message_handler, is_async, num_params, params, kwargs=injected_dependencies) |
416 | | - |
417 | | - # Add methods to the handler to allow attaching on_connect and on_close |
418 | | - def add_on_connect(connect_handler): |
419 | | - def connect_wrapper(websocket_connector, *args, **kwargs): |
420 | | - websocket_adapter = WebSocketAdapter(websocket_connector) |
421 | | - if asyncio.iscoroutinefunction(connect_handler): |
422 | | - return asyncio.create_task(connect_handler(websocket_adapter)) |
423 | | - return connect_handler(websocket_adapter) |
424 | | - |
425 | | - # Create FunctionInfo for connect handler |
426 | | - connect_params = dict(inspect.signature(connect_wrapper).parameters) |
427 | | - connect_num_params = len(connect_params) |
428 | | - connect_is_async = asyncio.iscoroutinefunction(connect_wrapper) |
429 | | - handlers["connect"] = FunctionInfo(connect_wrapper, connect_is_async, connect_num_params, connect_params, kwargs=injected_dependencies) |
430 | | - return connect_handler |
431 | | - |
432 | | - def add_on_close(close_handler): |
433 | | - def close_wrapper(websocket_connector, *args, **kwargs): |
434 | | - websocket_adapter = WebSocketAdapter(websocket_connector) |
435 | | - if asyncio.iscoroutinefunction(close_handler): |
436 | | - return asyncio.create_task(close_handler(websocket_adapter)) |
437 | | - return close_handler(websocket_adapter) |
438 | | - |
439 | | - # Create FunctionInfo for close handler |
440 | | - close_params = dict(inspect.signature(close_wrapper).parameters) |
441 | | - close_num_params = len(close_params) |
442 | | - close_is_async = asyncio.iscoroutinefunction(close_wrapper) |
443 | | - handlers["close"] = FunctionInfo(close_wrapper, close_is_async, close_num_params, close_params, kwargs=injected_dependencies) |
444 | | - return close_handler |
445 | | - |
446 | | - # Attach methods to the handler function |
447 | | - handler.on_connect = add_on_connect |
448 | | - handler.on_close = add_on_close |
449 | | - handler._ws_handlers = handlers # Store reference to handlers dict |
450 | | - |
451 | | - # Add the WebSocket to the router |
452 | | - self.add_web_socket(endpoint, handlers) |
453 | | - return handler |
454 | | - |
455 | | - return decorator |
| 307 | + return create_websocket_decorator(self)(endpoint) |
456 | 308 |
|
457 | 309 | def _add_event_handler(self, event_type: Events, handler: Callable) -> None: |
458 | 310 | logger.info("Added event %s handler", event_type) |
@@ -874,7 +726,7 @@ async def on_connect(websocket): |
874 | 726 | await websocket.send_text("Connected!") |
875 | 727 | """ |
876 | 728 | prefixed_endpoint = self.__add_prefix(endpoint) |
877 | | - return super().websocket(prefixed_endpoint) |
| 729 | + return create_websocket_decorator(self)(prefixed_endpoint) |
878 | 730 |
|
879 | 731 |
|
880 | 732 | def ALLOW_CORS(app: Robyn, origins: Union[List[str], str], headers: Union[List[str], str] = None): |
|
0 commit comments