1- import asyncio
21import typing as t
32from datetime import timedelta
43
5- from atproto_server .auth .jwt import get_jwt_payload
6-
74from atproto_client .client .methods_mixin .time import TimeMethodsMixin
85from atproto_client .client .session import (
96 AsyncSessionChangeCallback ,
107 Session ,
118 SessionChangeCallback ,
9+ SessionDispatcher ,
1210 SessionEvent ,
1311 SessionResponse ,
1412 get_session_pds_endpoint ,
1513)
1614from atproto_client .exceptions import LoginRequiredError
1715
18- if t .TYPE_CHECKING :
19- from atproto_server .auth .jwt import JwtPayload
20-
2116
2217class SessionDispatchMixin :
23- def __init__ (self , * args : t .Any , ** kwargs : t .Any ) -> None :
24- super ().__init__ (* args , ** kwargs )
25-
26- self ._on_session_change_callbacks : t .List [SessionChangeCallback ] = []
27- self ._on_session_change_async_callbacks : t .List [AsyncSessionChangeCallback ] = []
28-
2918 def on_session_change (self , callback : SessionChangeCallback ) -> None :
3019 """Register a callback for session change event.
3120
3221 Args:
3322 callback: A callback to be called when the session changes.
3423 The callback must accept two arguments: event and session.
3524
25+ Note:
26+ Possible events: `SessionEvent.IMPORT`, `SessionEvent.CREATE`, `SessionEvent.REFRESH`.
27+
28+ Tip:
29+ You should save the session string to persistent storage
30+ on `SessionEvent.CREATE` and `SessionEvent.REFRESH` event.
31+
3632 Example:
3733 >>> from atproto import Client, SessionEvent, Session
3834 >>>
3935 >>> client = Client()
4036 >>>
37+ >>> @client.on_session_change
4138 >>> def on_session_change(event: SessionEvent, session: Session):
4239 >>> print(event, session)
4340 >>>
44- >>> client.on_session_change(on_session_change)
41+ >>> # or you can use this syntax:
42+ >>> # client.on_session_change(on_session_change)
4543
4644 Returns:
4745 :obj:`None`
4846 """
49- self ._on_session_change_callbacks . append (callback )
47+ self ._session_dispatcher . on_session_change (callback )
5048
51- def _call_on_session_change_callbacks (self , event : SessionEvent , session : Session ) -> None :
52- for on_session_change_callback in self ._on_session_change_callbacks :
53- on_session_change_callback (event , session )
49+ def _call_on_session_change_callbacks (self , event : SessionEvent ) -> None :
50+ self ._session_dispatcher .dispatch_session_change (event )
5451
5552
5653class AsyncSessionDispatchMixin :
57- def __init__ (self , * args : t .Any , ** kwargs : t .Any ) -> None :
58- super ().__init__ (* args , ** kwargs )
59-
60- self ._on_session_change_async_callbacks : t .List [AsyncSessionChangeCallback ] = []
61-
62- def on_session_change (self , callback : AsyncSessionChangeCallback ) -> None :
54+ def on_session_change (self , callback : t .Union ['AsyncSessionChangeCallback' , 'SessionChangeCallback' ]) -> None :
6355 """Register a callback for session change event.
6456
6557 Args:
@@ -69,6 +61,9 @@ def on_session_change(self, callback: AsyncSessionChangeCallback) -> None:
6961 Note:
7062 Possible events: `SessionEvent.IMPORT`, `SessionEvent.CREATE`, `SessionEvent.REFRESH`.
7163
64+ Note:
65+ You can register both synchronous and asynchronous callbacks.
66+
7267 Tip:
7368 You should save the session string to persistent storage
7469 on `SessionEvent.CREATE` and `SessionEvent.REFRESH` event.
@@ -78,78 +73,81 @@ def on_session_change(self, callback: AsyncSessionChangeCallback) -> None:
7873 >>>
7974 >>> client = AsyncClient()
8075 >>>
76+ >>> @client.on_session_change
8177 >>> async def on_session_change(event: SessionEvent, session: Session):
8278 >>> print(event, session)
8379 >>>
84- >>> client.on_session_change(on_session_change)
80+ >>> # or you can use this syntax:
81+ >>> # client.on_session_change(on_session_change)
8582
8683 Returns:
8784 :obj:`None`
8885 """
89- self ._on_session_change_async_callbacks . append (callback )
86+ self ._session_dispatcher . on_session_change (callback )
9087
91- async def _call_on_session_change_callbacks (self , event : SessionEvent , session : Session ) -> None :
92- coroutines : t .List [t .Coroutine [t .Any , t .Any , None ]] = []
93- for on_session_change_async_callback in self ._on_session_change_async_callbacks :
94- coroutines .append (on_session_change_async_callback (event , session ))
95-
96- await asyncio .gather (* coroutines )
88+ async def _call_on_session_change_callbacks (self , event : SessionEvent ) -> None :
89+ await self ._session_dispatcher .dispatch_session_change_async (event )
9790
9891
9992class SessionMethodsMixin (TimeMethodsMixin ):
10093 def __init__ (self , * args : t .Any , ** kwargs : t .Any ) -> None :
10194 super ().__init__ (* args , ** kwargs )
102-
103- self ._access_jwt : t .Optional [str ] = None
104- self ._access_jwt_payload : t .Optional ['JwtPayload' ] = None
105-
106- self ._refresh_jwt : t .Optional [str ] = None
107- self ._refresh_jwt_payload : t .Optional ['JwtPayload' ] = None
108-
10995 self ._session : t .Optional [Session ] = None
96+ self ._session_dispatcher = SessionDispatcher ()
97+
98+ def _register_auth_headers_source (self ) -> None :
99+ self .request .add_additional_headers_source (self ._get_access_auth_headers )
110100
111101 def _should_refresh_session (self ) -> bool :
112- if not self ._access_jwt_payload or not self ._access_jwt_payload .exp :
102+ if not self ._session or not self ._session . access_jwt_payload or not self . _session . access_jwt_payload .exp :
113103 raise LoginRequiredError
114104
115- expired_at = self .get_time_from_timestamp (self ._access_jwt_payload .exp )
105+ expired_at = self .get_time_from_timestamp (self ._session . access_jwt_payload .exp )
116106 expired_at = expired_at - timedelta (minutes = 15 ) # let's update the token a bit earlier than required
117107
118108 return self .get_current_time () > expired_at
119109
120- def _set_session_common (self , session : SessionResponse , current_pds : str ) -> Session :
121- self ._access_jwt = session .access_jwt
122- self ._access_jwt_payload = get_jwt_payload (session .access_jwt )
110+ def _set_or_update_session (self , session : SessionResponse , pds_endpoint : str ) -> 'Session' :
111+ if not self ._session :
112+ self ._session = Session (
113+ access_jwt = session .access_jwt ,
114+ refresh_jwt = session .refresh_jwt ,
115+ did = session .did ,
116+ handle = session .handle ,
117+ pds_endpoint = pds_endpoint ,
118+ )
119+ self ._session_dispatcher .set_session (self ._session )
120+ self ._register_auth_headers_source ()
121+ else :
122+ self ._session .access_jwt = session .access_jwt
123+ self ._session .refresh_jwt = session .refresh_jwt
124+ self ._session .did = session .did
125+ self ._session .handle = session .handle
126+ self ._session .pds_endpoint = pds_endpoint
123127
124- self ._refresh_jwt = session .refresh_jwt
125- self ._refresh_jwt_payload = get_jwt_payload (session .refresh_jwt )
128+ return self ._session
126129
130+ def _set_session_common (self , session : SessionResponse , current_pds : str ) -> Session :
127131 pds_endpoint = get_session_pds_endpoint (session )
128132 if not pds_endpoint :
129133 # current_pds ends with xrpc endpoint, but this is not a problem
130134 # overhead is only 4-5 symbols in the exported session string
131135 pds_endpoint = current_pds
132136
133- self ._session = Session (
134- access_jwt = session .access_jwt ,
135- refresh_jwt = session .refresh_jwt ,
136- did = session .did ,
137- handle = session .handle ,
138- pds_endpoint = pds_endpoint ,
139- )
140-
141- self ._set_auth_headers (session .access_jwt )
142137 self ._update_pds_endpoint (pds_endpoint )
138+ return self ._set_or_update_session (session , pds_endpoint )
143139
144- return self ._session
140+ def _get_access_auth_headers (self ) -> t .Dict [str , str ]:
141+ if not self ._session :
142+ return {}
145143
146- @staticmethod
147- def _get_auth_headers (token : str ) -> t .Dict [str , str ]:
148- return {'Authorization' : f'Bearer { token } ' }
144+ return {'Authorization' : f'Bearer { self ._session .access_jwt } ' }
145+
146+ def _get_refresh_auth_headers (self ) -> t .Dict [str , str ]:
147+ if not self ._session :
148+ return {}
149149
150- def _set_auth_headers (self , token : str ) -> None :
151- for header_name , header_value in self ._get_auth_headers (token ).items ():
152- self .request .add_additional_header (header_name , header_value )
150+ return {'Authorization' : f'Bearer { self ._session .refresh_jwt } ' }
153151
154152 def _update_pds_endpoint (self , pds_endpoint : str ) -> None :
155153 self .update_base_url (pds_endpoint )
0 commit comments