Skip to content

Commit bedd158

Browse files
authored
Fix headers overriding in RequestBase -> get_headers() (#601)
1 parent ffcd7dd commit bedd158

File tree

2 files changed

+147
-6
lines changed

2 files changed

+147
-6
lines changed

packages/atproto_client/request.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,40 @@ class Response:
2626
headers: t.Dict[str, t.Any]
2727

2828

29+
_NormalizedHeaders = t.Dict[str, t.Tuple[str, str]]
30+
31+
32+
def _normalize_headers(headers: t.Dict[str, str]) -> _NormalizedHeaders:
33+
"""Normalize headers by converting keys to lowercase.
34+
35+
Args:
36+
headers: Headers dictionary with the original case.
37+
38+
Returns:
39+
Dictionary with lowercase keys mapped to tuples of (original_key, value).
40+
"""
41+
return {key.lower(): (key, value) for key, value in headers.items()}
42+
43+
44+
def _denormalize_headers(headers: _NormalizedHeaders) -> t.Dict[str, str]:
45+
"""Denormalize headers by converting keys back to their original case.
46+
47+
Args:
48+
headers: Headers dictionary with lowercase keys mapped to tuples of (original_key, value).
49+
50+
Returns:
51+
Dictionary with original keys and values.
52+
"""
53+
return {original_key: value for _, (original_key, value) in headers.items()}
54+
55+
2956
def _convert_headers_to_dict(headers: httpx.Headers) -> t.Dict[str, str]:
30-
"""Convert custom case-insensitive multi-dict of HTTPX to pure dict with lowercased keys.
57+
"""Convert custom case-insensitive multi-dict of HTTPX to pure dict with lowercase keys.
3158
3259
Note:
3360
Concatenate headers into a single comma separated value when a key occurs multiple times.
3461
"""
35-
return dict(headers.items())
62+
return {key.lower(): value for key, value in headers.items()}
3663

3764

3865
def _parse_response(response: httpx.Response) -> Response:
@@ -100,15 +127,19 @@ def get_headers(self, additional_headers: t.Optional[t.Dict[str, str]] = None) -
100127
Returns:
101128
Headers for the request.
102129
"""
103-
headers = {**RequestBase._MANDATORY_HEADERS, **self._additional_headers}
130+
# The order of calling `.update()` matters. It defines the priority of overriding.
131+
headers_lower = {
132+
**_normalize_headers(RequestBase._MANDATORY_HEADERS),
133+
**_normalize_headers(self._additional_headers),
134+
}
104135

105136
for header_source in self._additional_header_sources:
106-
headers.update(header_source())
137+
headers_lower.update(_normalize_headers(header_source()))
107138

108139
if additional_headers:
109-
headers.update(additional_headers)
140+
headers_lower.update(_normalize_headers(additional_headers))
110141

111-
return headers
142+
return _denormalize_headers(headers_lower)
112143

113144
def set_additional_headers(self, headers: t.Dict[str, str]) -> None:
114145
"""Set additional headers for the request.
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from atproto_client.request import RequestBase
2+
3+
4+
def test_get_headers_case_insensitivity() -> None:
5+
"""Test that get_headers handles case-insensitive header names correctly."""
6+
req = RequestBase()
7+
8+
# Add a header with mixed case
9+
req.add_additional_header('Content-Type', 'application/json')
10+
11+
# Try to override with a different case
12+
headers = req.get_headers({'content-type': 'text/plain'})
13+
14+
# Check that the header was properly overridden
15+
assert 'content-type' in headers
16+
assert headers['content-type'] == 'text/plain'
17+
assert 'Content-Type' not in headers # No mixed case keys
18+
19+
# Check that there's only one content-type header (case-insensitive)
20+
content_type_headers = [k for k in headers if k.lower() == 'content-type']
21+
assert len(content_type_headers) == 1
22+
23+
24+
def test_add_additional_header_case_insensitivity() -> None:
25+
"""Test that add_additional_header handles case-insensitive header names correctly."""
26+
req = RequestBase()
27+
28+
# Add a header with mixed case
29+
req.add_additional_header('Content-Type', 'application/json')
30+
31+
# Add the same header with a different case
32+
req.add_additional_header('content-type', 'text/plain')
33+
34+
# Get the headers
35+
headers = req.get_headers()
36+
37+
# Check that the header was properly overridden
38+
assert 'content-type' in headers
39+
assert headers['content-type'] == 'text/plain'
40+
assert 'Content-Type' not in headers # No mixed case keys
41+
42+
# Check that there's only one content-type header (case-insensitive)
43+
content_type_headers = [k for k in headers if k.lower() == 'content-type']
44+
assert len(content_type_headers) == 1
45+
46+
47+
def test_set_additional_headers_case_insensitivity() -> None:
48+
"""Test set_additional_headers."""
49+
req = RequestBase()
50+
51+
# Set headers with a mixed case
52+
req.set_additional_headers(
53+
{'Content-Type': 'application/json', 'AUTHORIZATION': 'Bearer token', 'accept': 'application/json'}
54+
)
55+
56+
# Get the headers
57+
headers = req.get_headers()
58+
59+
# Check that all headers are present
60+
assert 'Content-Type' in headers
61+
assert 'AUTHORIZATION' in headers
62+
assert 'accept' in headers
63+
64+
# Check values
65+
assert headers['Content-Type'] == 'application/json'
66+
assert headers['AUTHORIZATION'] == 'Bearer token'
67+
assert headers['accept'] == 'application/json'
68+
69+
70+
def test_headers_override_with_additional_headers() -> None:
71+
"""Test that additional headers properly override existing headers."""
72+
req = RequestBase()
73+
74+
# Add some headers
75+
req.add_additional_header('content-type', 'application/json')
76+
req.add_additional_header('authorization', 'Bearer token1')
77+
78+
# Override with additional headers
79+
headers = req.get_headers({'Content-Type': 'text/plain', 'AUTHORIZATION': 'Bearer token2'})
80+
81+
# Check that headers were properly overridden
82+
assert headers['Content-Type'] == 'text/plain'
83+
assert headers['AUTHORIZATION'] == 'Bearer token2'
84+
85+
# Check that there are no duplicate headers with different cases
86+
assert len([k for k in headers if k.lower() == 'content-type']) == 1
87+
assert len([k for k in headers if k.lower() == 'authorization']) == 1
88+
89+
90+
def test_headers_from_sources() -> None:
91+
"""Test that headers from sources are properly handled."""
92+
req = RequestBase()
93+
94+
# Add a header source
95+
req.add_additional_headers_source(lambda: {'Content-Type': 'application/json'})
96+
97+
# Add another header source with a different case
98+
req.add_additional_headers_source(lambda: {'content-type': 'text/plain'})
99+
100+
# Get the headers
101+
headers = req.get_headers()
102+
103+
# Check that the last source's value is used
104+
assert headers['content-type'] == 'text/plain'
105+
106+
# Check that the first source's value is not used
107+
assert 'Content-Type' not in headers
108+
109+
# Check that there are no duplicate headers with different cases
110+
assert len([k for k in headers if k.lower() == 'content-type']) == 1

0 commit comments

Comments
 (0)