Skip to content

Commit df93518

Browse files
authored
Fix session sharing with all cloned client instances (#531)
1 parent 850f872 commit df93518

File tree

11 files changed

+295
-99
lines changed

11 files changed

+295
-99
lines changed

examples/advanced_usage/add_user_to_list.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def main() -> None:
3434
mod_list = client.app.bsky.graph.get_list(models.AppBskyGraphGetList.Params(list=mod_list_uri))
3535
mod_list_users = [item.subject.did for item in mod_list.items]
3636
print(f'List users: {mod_list_users}')
37-
assert user_to_add in mod_list_users, f'User {user_to_add} not found in the list {mod_list_uri}' # noqa: S101
37+
assert user_to_add in mod_list_users, f'User {user_to_add} not found in the list {mod_list_uri}'
3838

3939
deleted_success = client.app.bsky.graph.listitem.delete(mod_list_owner, AtUri.from_str(created_list_item.uri).rkey)
4040
print(f'Deleted list item: {deleted_success}')

examples/advanced_usage/validate_string_formats.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,19 @@
77
strict_validation_context = {'strict_string_format': True}
88
HandleTypeAdapter = TypeAdapter(string_formats.Handle)
99

10-
assert string_formats._OPT_IN_KEY == 'strict_string_format' # noqa: S101
10+
assert string_formats._OPT_IN_KEY == 'strict_string_format'
1111

1212
# values will not be validated if not opting in
1313
sneaky_bad_handle = HandleTypeAdapter.validate_python(some_bad_handle)
1414

15-
assert sneaky_bad_handle == some_bad_handle # noqa: S101
15+
assert sneaky_bad_handle == some_bad_handle
1616

1717
print(f'{sneaky_bad_handle=}\n\n')
1818

1919
# values will be validated if opting in
2020
validated_good_handle = HandleTypeAdapter.validate_python(some_good_handle, context=strict_validation_context)
2121

22-
assert validated_good_handle == some_good_handle # noqa: S101
22+
assert validated_good_handle == some_good_handle
2323

2424
print(f'{validated_good_handle=}\n\n')
2525

packages/atproto_client/client/async_client.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,14 @@ async def _invoke(self, invoke_type: 'InvokeType', **kwargs: t.Any) -> 'Response
4343
return await super()._invoke(invoke_type, **kwargs)
4444

4545
async with self._refresh_lock:
46-
if self._access_jwt and self._should_refresh_session():
46+
if self._session and self._session.access_jwt and self._should_refresh_session():
4747
await self._refresh_and_set_session()
4848

4949
return await super()._invoke(invoke_type, **kwargs)
5050

5151
async def _set_session(self, event: SessionEvent, session: SessionResponse) -> None:
52-
session = self._set_session_common(session, self._base_url)
53-
await self._call_on_session_change_callbacks(event, session.copy())
52+
self._set_session_common(session, self._base_url)
53+
await self._call_on_session_change_callbacks(event)
5454

5555
async def _get_and_set_session(self, login: str, password: str) -> 'models.ComAtprotoServerCreateSession.Response':
5656
session = await self.com.atproto.server.create_session(
@@ -60,11 +60,11 @@ async def _get_and_set_session(self, login: str, password: str) -> 'models.ComAt
6060
return session
6161

6262
async def _refresh_and_set_session(self) -> 'models.ComAtprotoServerRefreshSession.Response':
63-
if not self._refresh_jwt:
63+
if not self._session or not self._session.refresh_jwt:
6464
raise LoginRequiredError
6565

6666
refresh_session = await self.com.atproto.server.refresh_session(
67-
headers=self._get_auth_headers(self._refresh_jwt), session_refreshing=True
67+
headers=self._get_refresh_auth_headers(), session_refreshing=True
6868
)
6969
await self._set_session(SessionEvent.REFRESH, refresh_session)
7070

@@ -225,16 +225,16 @@ async def send_images(
225225
image_alts = image_alts + [''] * diff # [''] * (minus) => []
226226

227227
if image_aspect_ratios is None:
228-
image_aspect_ratios = [None] * len(images)
228+
aligned_image_aspect_ratios = [None] * len(images)
229229
else:
230230
# padding with None if len is insufficient
231231
diff = len(images) - len(image_aspect_ratios)
232-
image_aspect_ratios = image_aspect_ratios + [None] * diff
232+
aligned_image_aspect_ratios = image_aspect_ratios + [None] * diff
233233

234234
uploads = await asyncio.gather(*[self.upload_blob(image) for image in images])
235235
embed_images = [
236236
models.AppBskyEmbedImages.Image(alt=alt, image=upload.blob, aspect_ratio=aspect_ratio)
237-
for alt, upload, aspect_ratio in zip(image_alts, uploads, image_aspect_ratios)
237+
for alt, upload, aspect_ratio in zip(image_alts, uploads, aligned_image_aspect_ratios)
238238
]
239239

240240
return await self.send_post(
@@ -278,6 +278,10 @@ async def send_image(
278278
Raises:
279279
:class:`atproto.exceptions.AtProtocolError`: Base exception.
280280
"""
281+
image_aspect_ratios = None
282+
if image_aspect_ratio:
283+
image_aspect_ratios = [image_aspect_ratio]
284+
281285
return await self.send_images(
282286
text,
283287
images=[image],
@@ -286,7 +290,7 @@ async def send_image(
286290
reply_to=reply_to,
287291
langs=langs,
288292
facets=facets,
289-
image_aspect_ratios=[image_aspect_ratio],
293+
image_aspect_ratios=image_aspect_ratios,
290294
)
291295

292296
async def send_video(

packages/atproto_client/client/client.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,14 @@ def _invoke(self, invoke_type: 'InvokeType', **kwargs: t.Any) -> 'Response':
3434
return super()._invoke(invoke_type, **kwargs)
3535

3636
with self._refresh_lock:
37-
if self._access_jwt and self._should_refresh_session():
37+
if self._session and self._session.access_jwt and self._should_refresh_session():
3838
self._refresh_and_set_session()
3939

4040
return super()._invoke(invoke_type, **kwargs)
4141

4242
def _set_session(self, event: SessionEvent, session: SessionResponse) -> None:
43-
session = self._set_session_common(session, self._base_url)
44-
self._call_on_session_change_callbacks(event, session.copy())
43+
self._set_session_common(session, self._base_url)
44+
self._call_on_session_change_callbacks(event)
4545

4646
def _get_and_set_session(self, login: str, password: str) -> 'models.ComAtprotoServerCreateSession.Response':
4747
session = self.com.atproto.server.create_session(
@@ -51,11 +51,11 @@ def _get_and_set_session(self, login: str, password: str) -> 'models.ComAtprotoS
5151
return session
5252

5353
def _refresh_and_set_session(self) -> 'models.ComAtprotoServerRefreshSession.Response':
54-
if not self._refresh_jwt:
54+
if not self._session or not self._session.refresh_jwt:
5555
raise LoginRequiredError
5656

5757
refresh_session = self.com.atproto.server.refresh_session(
58-
headers=self._get_auth_headers(self._refresh_jwt), session_refreshing=True
58+
headers=self._get_refresh_auth_headers(), session_refreshing=True
5959
)
6060
self._set_session(SessionEvent.REFRESH, refresh_session)
6161

@@ -216,16 +216,16 @@ def send_images(
216216
image_alts = image_alts + [''] * diff # [''] * (minus) => []
217217

218218
if image_aspect_ratios is None:
219-
image_aspect_ratios = [None] * len(images)
219+
aligned_image_aspect_ratios = [None] * len(images)
220220
else:
221221
# padding with None if len is insufficient
222222
diff = len(images) - len(image_aspect_ratios)
223-
image_aspect_ratios = image_aspect_ratios + [None] * diff
223+
aligned_image_aspect_ratios = image_aspect_ratios + [None] * diff
224224

225225
uploads = [self.upload_blob(image) for image in images]
226226
embed_images = [
227227
models.AppBskyEmbedImages.Image(alt=alt, image=upload.blob, aspect_ratio=aspect_ratio)
228-
for alt, upload, aspect_ratio in zip(image_alts, uploads, image_aspect_ratios)
228+
for alt, upload, aspect_ratio in zip(image_alts, uploads, aligned_image_aspect_ratios)
229229
]
230230

231231
return self.send_post(
@@ -269,6 +269,10 @@ def send_image(
269269
Raises:
270270
:class:`atproto.exceptions.AtProtocolError`: Base exception.
271271
"""
272+
image_aspect_ratios = None
273+
if image_aspect_ratio:
274+
image_aspect_ratios = [image_aspect_ratio]
275+
272276
return self.send_images(
273277
text,
274278
images=[image],
@@ -277,7 +281,7 @@ def send_image(
277281
reply_to=reply_to,
278282
langs=langs,
279283
facets=facets,
280-
image_aspect_ratios=[image_aspect_ratio],
284+
image_aspect_ratios=image_aspect_ratios,
281285
)
282286

283287
def send_video(

packages/atproto_client/client/methods_mixin/headers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,12 @@ def clone(self) -> te.Self:
2929
Cloned client instance.
3030
"""
3131
cloned_client = super().clone()
32+
33+
# share the same objects to avoid conflicts with session changes
3234
cloned_client.me = self.me
35+
cloned_client._session = self._session
36+
cloned_client._session_dispatcher = self._session_dispatcher
37+
3338
return cloned_client
3439

3540
def with_proxy(self, service_type: t.Union[AtprotoServiceType, str], did: str) -> te.Self:

packages/atproto_client/client/methods_mixin/session.py

Lines changed: 59 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,57 @@
1-
import asyncio
21
import typing as t
32
from datetime import timedelta
43

5-
from atproto_server.auth.jwt import get_jwt_payload
6-
74
from atproto_client.client.methods_mixin.time import TimeMethodsMixin
85
from atproto_client.client.session import (
96
AsyncSessionChangeCallback,
107
Session,
118
SessionChangeCallback,
9+
SessionDispatcher,
1210
SessionEvent,
1311
SessionResponse,
1412
get_session_pds_endpoint,
1513
)
1614
from atproto_client.exceptions import LoginRequiredError
1715

18-
if t.TYPE_CHECKING:
19-
from atproto_server.auth.jwt import JwtPayload
20-
2116

2217
class SessionDispatchMixin:
23-
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
24-
super().__init__(*args, **kwargs)
25-
26-
self._on_session_change_callbacks: t.List[SessionChangeCallback] = []
27-
self._on_session_change_async_callbacks: t.List[AsyncSessionChangeCallback] = []
28-
2918
def on_session_change(self, callback: SessionChangeCallback) -> None:
3019
"""Register a callback for session change event.
3120
3221
Args:
3322
callback: A callback to be called when the session changes.
3423
The callback must accept two arguments: event and session.
3524
25+
Note:
26+
Possible events: `SessionEvent.IMPORT`, `SessionEvent.CREATE`, `SessionEvent.REFRESH`.
27+
28+
Tip:
29+
You should save the session string to persistent storage
30+
on `SessionEvent.CREATE` and `SessionEvent.REFRESH` event.
31+
3632
Example:
3733
>>> from atproto import Client, SessionEvent, Session
3834
>>>
3935
>>> client = Client()
4036
>>>
37+
>>> @client.on_session_change
4138
>>> def on_session_change(event: SessionEvent, session: Session):
4239
>>> print(event, session)
4340
>>>
44-
>>> client.on_session_change(on_session_change)
41+
>>> # or you can use this syntax:
42+
>>> # client.on_session_change(on_session_change)
4543
4644
Returns:
4745
:obj:`None`
4846
"""
49-
self._on_session_change_callbacks.append(callback)
47+
self._session_dispatcher.on_session_change(callback)
5048

51-
def _call_on_session_change_callbacks(self, event: SessionEvent, session: Session) -> None:
52-
for on_session_change_callback in self._on_session_change_callbacks:
53-
on_session_change_callback(event, session)
49+
def _call_on_session_change_callbacks(self, event: SessionEvent) -> None:
50+
self._session_dispatcher.dispatch_session_change(event)
5451

5552

5653
class AsyncSessionDispatchMixin:
57-
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
58-
super().__init__(*args, **kwargs)
59-
60-
self._on_session_change_async_callbacks: t.List[AsyncSessionChangeCallback] = []
61-
62-
def on_session_change(self, callback: AsyncSessionChangeCallback) -> None:
54+
def on_session_change(self, callback: t.Union['AsyncSessionChangeCallback', 'SessionChangeCallback']) -> None:
6355
"""Register a callback for session change event.
6456
6557
Args:
@@ -69,6 +61,9 @@ def on_session_change(self, callback: AsyncSessionChangeCallback) -> None:
6961
Note:
7062
Possible events: `SessionEvent.IMPORT`, `SessionEvent.CREATE`, `SessionEvent.REFRESH`.
7163
64+
Note:
65+
You can register both synchronous and asynchronous callbacks.
66+
7267
Tip:
7368
You should save the session string to persistent storage
7469
on `SessionEvent.CREATE` and `SessionEvent.REFRESH` event.
@@ -78,78 +73,81 @@ def on_session_change(self, callback: AsyncSessionChangeCallback) -> None:
7873
>>>
7974
>>> client = AsyncClient()
8075
>>>
76+
>>> @client.on_session_change
8177
>>> async def on_session_change(event: SessionEvent, session: Session):
8278
>>> print(event, session)
8379
>>>
84-
>>> client.on_session_change(on_session_change)
80+
>>> # or you can use this syntax:
81+
>>> # client.on_session_change(on_session_change)
8582
8683
Returns:
8784
:obj:`None`
8885
"""
89-
self._on_session_change_async_callbacks.append(callback)
86+
self._session_dispatcher.on_session_change(callback)
9087

91-
async def _call_on_session_change_callbacks(self, event: SessionEvent, session: Session) -> None:
92-
coroutines: t.List[t.Coroutine[t.Any, t.Any, None]] = []
93-
for on_session_change_async_callback in self._on_session_change_async_callbacks:
94-
coroutines.append(on_session_change_async_callback(event, session))
95-
96-
await asyncio.gather(*coroutines)
88+
async def _call_on_session_change_callbacks(self, event: SessionEvent) -> None:
89+
await self._session_dispatcher.dispatch_session_change_async(event)
9790

9891

9992
class SessionMethodsMixin(TimeMethodsMixin):
10093
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
10194
super().__init__(*args, **kwargs)
102-
103-
self._access_jwt: t.Optional[str] = None
104-
self._access_jwt_payload: t.Optional['JwtPayload'] = None
105-
106-
self._refresh_jwt: t.Optional[str] = None
107-
self._refresh_jwt_payload: t.Optional['JwtPayload'] = None
108-
10995
self._session: t.Optional[Session] = None
96+
self._session_dispatcher = SessionDispatcher()
97+
98+
def _register_auth_headers_source(self) -> None:
99+
self.request.add_additional_headers_source(self._get_access_auth_headers)
110100

111101
def _should_refresh_session(self) -> bool:
112-
if not self._access_jwt_payload or not self._access_jwt_payload.exp:
102+
if not self._session or not self._session.access_jwt_payload or not self._session.access_jwt_payload.exp:
113103
raise LoginRequiredError
114104

115-
expired_at = self.get_time_from_timestamp(self._access_jwt_payload.exp)
105+
expired_at = self.get_time_from_timestamp(self._session.access_jwt_payload.exp)
116106
expired_at = expired_at - timedelta(minutes=15) # let's update the token a bit earlier than required
117107

118108
return self.get_current_time() > expired_at
119109

120-
def _set_session_common(self, session: SessionResponse, current_pds: str) -> Session:
121-
self._access_jwt = session.access_jwt
122-
self._access_jwt_payload = get_jwt_payload(session.access_jwt)
110+
def _set_or_update_session(self, session: SessionResponse, pds_endpoint: str) -> 'Session':
111+
if not self._session:
112+
self._session = Session(
113+
access_jwt=session.access_jwt,
114+
refresh_jwt=session.refresh_jwt,
115+
did=session.did,
116+
handle=session.handle,
117+
pds_endpoint=pds_endpoint,
118+
)
119+
self._session_dispatcher.set_session(self._session)
120+
self._register_auth_headers_source()
121+
else:
122+
self._session.access_jwt = session.access_jwt
123+
self._session.refresh_jwt = session.refresh_jwt
124+
self._session.did = session.did
125+
self._session.handle = session.handle
126+
self._session.pds_endpoint = pds_endpoint
123127

124-
self._refresh_jwt = session.refresh_jwt
125-
self._refresh_jwt_payload = get_jwt_payload(session.refresh_jwt)
128+
return self._session
126129

130+
def _set_session_common(self, session: SessionResponse, current_pds: str) -> Session:
127131
pds_endpoint = get_session_pds_endpoint(session)
128132
if not pds_endpoint:
129133
# current_pds ends with xrpc endpoint, but this is not a problem
130134
# overhead is only 4-5 symbols in the exported session string
131135
pds_endpoint = current_pds
132136

133-
self._session = Session(
134-
access_jwt=session.access_jwt,
135-
refresh_jwt=session.refresh_jwt,
136-
did=session.did,
137-
handle=session.handle,
138-
pds_endpoint=pds_endpoint,
139-
)
140-
141-
self._set_auth_headers(session.access_jwt)
142137
self._update_pds_endpoint(pds_endpoint)
138+
return self._set_or_update_session(session, pds_endpoint)
143139

144-
return self._session
140+
def _get_access_auth_headers(self) -> t.Dict[str, str]:
141+
if not self._session:
142+
return {}
145143

146-
@staticmethod
147-
def _get_auth_headers(token: str) -> t.Dict[str, str]:
148-
return {'Authorization': f'Bearer {token}'}
144+
return {'Authorization': f'Bearer {self._session.access_jwt}'}
145+
146+
def _get_refresh_auth_headers(self) -> t.Dict[str, str]:
147+
if not self._session:
148+
return {}
149149

150-
def _set_auth_headers(self, token: str) -> None:
151-
for header_name, header_value in self._get_auth_headers(token).items():
152-
self.request.add_additional_header(header_name, header_value)
150+
return {'Authorization': f'Bearer {self._session.refresh_jwt}'}
153151

154152
def _update_pds_endpoint(self, pds_endpoint: str) -> None:
155153
self.update_base_url(pds_endpoint)

0 commit comments

Comments
 (0)