1+ import asyncio
2+ import socket
3+ from typing import Optional , Protocol
4+ from starlette .websockets import WebSocket , WebSocketDisconnect
5+
6+
7+ class AsyncDuplex (Protocol ):
8+ async def read (self , n : int = - 1 ) -> bytes : ...
9+ async def write (self , data : bytes ): ...
10+ async def close (self ): ...
11+
12+
13+ async def pipe_duplex (a : AsyncDuplex , b : AsyncDuplex , label_a = "A" , label_b = "B" ):
14+ """双向管道:a <-> b"""
15+ task_ab = asyncio .create_task (_pipe_oneway (a , b , f"{ label_a } ->{ label_b } " ))
16+ task_ba = asyncio .create_task (_pipe_oneway (b , a , f"{ label_b } ->{ label_a } " ))
17+ done , pending = await asyncio .wait (
18+ [task_ab , task_ba ],
19+ return_when = asyncio .FIRST_COMPLETED ,
20+ )
21+ for t in pending :
22+ t .cancel ()
23+ if pending :
24+ await asyncio .gather (* pending , return_exceptions = True )
25+
26+
27+ async def _pipe_oneway (src : AsyncDuplex , dst : AsyncDuplex , name : str ):
28+ try :
29+ while True :
30+ data = await src .read (4096 )
31+ if not data :
32+ break
33+ await dst .write (data )
34+ except asyncio .CancelledError :
35+ pass
36+ except Exception as e :
37+ print (f"[{ name } ] error:" , e )
38+ finally :
39+ await dst .close ()
40+
41+ class RWSocketDuplex :
42+ def __init__ (self , rsock : socket .socket , wsock : socket .socket , loop = None ):
43+ self .rsock = rsock
44+ self .wsock = wsock
45+ self ._same = rsock is wsock
46+ self .loop = loop or asyncio .get_running_loop ()
47+ self ._closed = False
48+
49+ self .rsock .setblocking (False )
50+ if not self ._same :
51+ self .wsock .setblocking (False )
52+
53+ async def read (self , n : int = 4096 ) -> bytes :
54+ if self ._closed :
55+ return b''
56+ try :
57+ data = await self .loop .sock_recv (self .rsock , n )
58+ if not data :
59+ await self .close ()
60+ return b''
61+ return data
62+ except (ConnectionResetError , OSError ):
63+ await self .close ()
64+ return b''
65+
66+ async def write (self , data : bytes ):
67+ if not data or self ._closed :
68+ return
69+ try :
70+ await self .loop .sock_sendall (self .wsock , data )
71+ except (ConnectionResetError , OSError ):
72+ await self .close ()
73+
74+ async def close (self ):
75+ if self ._closed :
76+ return
77+ self ._closed = True
78+ try :
79+ self .rsock .close ()
80+ except Exception :
81+ pass
82+ if not self ._same :
83+ try :
84+ self .wsock .close ()
85+ except Exception :
86+ pass
87+
88+ def is_closed (self ):
89+ return self ._closed
90+
91+ class SocketDuplex (RWSocketDuplex ):
92+ """封装 socket.socket 为 AsyncDuplex 接口"""
93+ def __init__ (self , sock : socket .socket , loop : Optional [asyncio .AbstractEventLoop ] = None ):
94+ super ().__init__ (sock , sock , loop )
95+
96+
97+ class WebSocketDuplex :
98+ """将 starlette.websockets.WebSocket 封装为 AsyncDuplex"""
99+ def __init__ (self , ws : WebSocket ):
100+ self .ws = ws
101+ self ._closed = False
102+
103+ async def read (self , n : int = - 1 ) -> bytes :
104+ """读取二进制消息,如果是文本则自动转 bytes"""
105+ if self ._closed :
106+ return b''
107+ try :
108+ msg = await self .ws .receive ()
109+ except WebSocketDisconnect :
110+ self ._closed = True
111+ return b''
112+ except Exception :
113+ self ._closed = True
114+ return b''
115+
116+ if msg ["type" ] == "websocket.disconnect" :
117+ self ._closed = True
118+ return b''
119+ elif msg ["type" ] == "websocket.receive" :
120+ data = msg .get ("bytes" )
121+ if data is not None :
122+ return data
123+ text = msg .get ("text" )
124+ return text .encode ("utf-8" ) if text else b''
125+ return b''
126+
127+ async def write (self , data : bytes ):
128+ if self ._closed :
129+ return
130+ try :
131+ await self .ws .send_bytes (data )
132+ except Exception :
133+ self ._closed = True
134+
135+ async def close (self ):
136+ if not self ._closed :
137+ self ._closed = True
138+ try :
139+ await self .ws .close ()
140+ except Exception :
141+ pass
0 commit comments