diff --git a/singer_sdk/streams/rest.py b/singer_sdk/streams/rest.py
index 942b7294d..5074499d9 100644
--- a/singer_sdk/streams/rest.py
+++ b/singer_sdk/streams/rest.py
@@ -7,6 +7,7 @@
 import decimal
 import logging
 import typing as t
+from contextlib import contextmanager
 from functools import cached_property
 from http import HTTPStatus
 from urllib.parse import urlparse
@@ -105,8 +106,44 @@ def __init__(
         self._http_headers: dict[str, str] = {}
         self._http_method = http_method
         self._requests_session = requests.Session()
+        self._current_paginator: BaseAPIPaginator | None = None
         super().__init__(name=name, schema=schema, tap=tap)
 
+    @property
+    def paginator(self) -> BaseAPIPaginator:
+        """Get the current paginator instance.
+
+        Only available during active stream processing (within request_records).
+        This allows access to pagination configuration from methods like get_url_params.
+
+        Returns:
+            The current paginator instance.
+
+        Raises:
+            RuntimeError: If accessed outside of request_records context.
+        """
+        if self._current_paginator is None:
+            msg = (
+                "Paginator is only available during active stream processing. "
+                "Access it from methods called within the request_records lifecycle "
+                "(get_url_params, prepare_request_payload, etc.)."
+            )
+            raise RuntimeError(msg)
+        return self._current_paginator
+
+    @contextmanager
+    def _paginator_context(self) -> t.Iterator[BaseAPIPaginator]:
+        """Context manager for paginator lifecycle.
+
+        Yields:
+            A fresh paginator instance for this request context.
+        """
+        self._current_paginator = self.get_new_paginator() or SinglePagePaginator()
+        try:
+            yield self._current_paginator
+        finally:
+            self._current_paginator = None
+
     @staticmethod
     def _url_encode(val: str | datetime | bool | int | list[str]) -> str:  # noqa: FBT001
         """Encode the val argument as url-compatible string.
@@ -350,6 +387,10 @@ def get_url_params(  # noqa: PLR6301
 
         If paging is supported, developers may override with specific paging logic.
 
+        The paginator instance is available as ``self.paginator`` during request
+        processing, allowing access to pagination configuration without duplicating
+        logic across methods.
+
         If your source needs special handling and, for example, parentheses should
         not be encoded, you can return a string constructed with
         :py:func:`urllib.parse.urlencode`:
@@ -447,40 +488,40 @@ def request_records(self, context: Context | None) -> t.Iterable[dict]:
         Yields:
             An item for every record in the response.
         """
-        paginator = self.get_new_paginator() or SinglePagePaginator()
-        decorated_request = self.request_decorator(self._request)
-        pages = 0
-
-        with metrics.http_request_counter(self.name, self.path) as request_counter:
-            request_counter.context = context
-
-            while not paginator.finished:
-                prepared_request = self.prepare_request(
-                    context,
-                    next_page_token=paginator.current_value,
-                )
-                resp = decorated_request(prepared_request, context)
-                request_counter.increment()
-                self.update_sync_costs(prepared_request, resp, context)
-                records = iter(self.parse_response(resp))
-                try:
-                    first_record = next(records)
-                except StopIteration:
-                    if paginator.continue_if_empty(resp):
-                        paginator.advance(resp)
-                        continue
-
-                    self.logger.info(
-                        "Pagination stopped after %d pages because no records were "
-                        "found in the last response",
-                        pages,
-                    )
-                    break
-                yield first_record
-                yield from records
-                pages += 1
+        with self._paginator_context() as paginator:
+            decorated_request = self.request_decorator(self._request)
+            pages = 0
 
-                paginator.advance(resp)
+            with metrics.http_request_counter(self.name, self.path) as request_counter:
+                request_counter.context = context
+
+                while not paginator.finished:
+                    prepared_request = self.prepare_request(
+                        context,
+                        next_page_token=paginator.current_value,
+                    )
+                    resp = decorated_request(prepared_request, context)
+                    request_counter.increment()
+                    self.update_sync_costs(prepared_request, resp, context)
+                    records = iter(self.parse_response(resp))
+                    try:
+                        first_record = next(records)
+                    except StopIteration:
+                        if paginator.continue_if_empty(resp):
+                            paginator.advance(resp)
+                            continue
+
+                        self.logger.info(
+                            "Pagination stopped after %d pages because no records were "
+                            "found in the last response",
+                            pages,
+                        )
+                        break
+                    yield first_record
+                    yield from records
+                    pages += 1
+
+                    paginator.advance(resp)
 
     def _write_request_duration_log(
         self,
diff --git a/tests/core/rest/test_pagination.py b/tests/core/rest/test_pagination.py
index cec1690d9..4a5d3d974 100644
--- a/tests/core/rest/test_pagination.py
+++ b/tests/core/rest/test_pagination.py
@@ -519,3 +519,347 @@ def _request(
 
     with pytest.raises(StopIteration):
         next(records_iter)
+
+
+def test_paginator_access_from_methods(tap: Tap):
+    """Test that paginator is accessible from various stream methods."""
+
+    class TestStream(RESTStream[int]):
+        """Test stream that accesses paginator from various methods."""
+
+        name = "test-stream"
+        url_base = "https://api.test.com"
+        path = "/items"
+        records_jsonpath = "$.data[*]"
+        schema: t.ClassVar = {
+            "type": "object",
+            "properties": {"id": {"type": "integer"}},
+        }
+
+        # Track calls to verify paginator is accessible
+        paginator_access_calls: t.ClassVar = []
+
+        def get_new_paginator(self) -> BaseOffsetPaginator:
+            return BaseOffsetPaginator(start_value=0, page_size=2)
+
+        def get_url_params(self, context, next_page_token):  # noqa: ARG002
+            # Verify paginator is accessible during request processing
+            try:
+                paginator = self.paginator
+                self.paginator_access_calls.append(
+                    f"get_url_params: {type(paginator).__name__}"
+                )
+            except RuntimeError as e:
+                self.paginator_access_calls.append(f"get_url_params: ERROR - {e}")
+            return {"offset": next_page_token or 0}
+
+        def prepare_request_payload(self, context, next_page_token):  # noqa: ARG002
+            # Test paginator access from another method
+            try:
+                paginator = self.paginator
+                self.paginator_access_calls.append(
+                    f"prepare_request_payload: {type(paginator).__name__}"
+                )
+            except RuntimeError as e:
+                self.paginator_access_calls.append(
+                    f"prepare_request_payload: ERROR - {e}"
+                )
+
+        def _request(self, prepared_request, context):  # noqa: ARG002
+            """Mock request that returns test data."""
+            r = Response()
+            r.status_code = 200
+
+            parsed = urlparse(prepared_request.url)
+            query = parse_qs(parsed.query)
+            offset = int(query.get("offset", ["0"])[0])
+
+            if offset == 0:
+                r._content = json.dumps({"data": [{"id": 1}, {"id": 2}]}).encode()
+            elif offset == 2:
+                r._content = json.dumps({"data": [{"id": 3}, {"id": 4}]}).encode()
+            else:
+                r._content = json.dumps({"data": []}).encode()
+
+            return r
+
+    stream = TestStream(tap=tap)
+
+    # Test that paginator is NOT accessible outside of request_records
+    with pytest.raises(
+        RuntimeError, match="only available during active stream processing"
+    ):
+        _ = stream.paginator
+
+    # Clear any previous calls
+    stream.paginator_access_calls = []
+
+    # Process records which should make paginator accessible
+    records = list(stream.request_records(context=None))
+
+    # Verify records were processed correctly
+    assert len(records) == 4
+    assert records[0] == {"id": 1}
+    assert records[1] == {"id": 2}
+    assert records[2] == {"id": 3}
+    assert records[3] == {"id": 4}
+
+    # Verify paginator was accessible from both methods
+    assert len(stream.paginator_access_calls) >= 2
+    assert any(
+        "get_url_params: BaseOffsetPaginator" in call
+        for call in stream.paginator_access_calls
+    )
+    assert any(
+        "prepare_request_payload: BaseOffsetPaginator" in call
+        for call in stream.paginator_access_calls
+    )
+
+    # Verify no error calls
+    assert all("ERROR" not in call for call in stream.paginator_access_calls)
+
+    # Test that paginator is NOT accessible after request_records completes
+    with pytest.raises(
+        RuntimeError, match="only available during active stream processing"
+    ):
+        _ = stream.paginator
+
+
+def test_paginator_fresh_per_context(tap: Tap):
+    """Test that paginator is fresh for each request_records call."""
+
+    class TestStream(RESTStream[int]):
+        name = "test-stream"
+        url_base = "https://api.test.com"
+        path = "/items"
+        records_jsonpath = "$.data[*]"
+        schema: t.ClassVar = {
+            "type": "object",
+            "properties": {"id": {"type": "integer"}},
+        }
+
+        paginator_creation_count = 0
+
+        def get_new_paginator(self) -> BaseOffsetPaginator:
+            # Track how many times get_new_paginator is called
+            self.paginator_creation_count += 1
+            return BaseOffsetPaginator(start_value=0, page_size=1)
+
+        def get_url_params(self, context, next_page_token):  # noqa: ARG002
+            return {"offset": next_page_token or 0}
+
+        def _request(self, prepared_request, context):  # noqa: ARG002
+            r = Response()
+            r.status_code = 200
+
+            parsed = urlparse(prepared_request.url)
+            query = parse_qs(parsed.query)
+            offset = int(query.get("offset", ["0"])[0])
+
+            if offset == 0:
+                r._content = json.dumps({"data": [{"id": 1}]}).encode()
+            else:
+                # Return empty data to end pagination
+                r._content = json.dumps({"data": []}).encode()
+
+            return r
+
+    stream = TestStream(tap=tap)
+
+    # Verify counter starts at 0
+    assert stream.paginator_creation_count == 0
+
+    # First call to request_records
+    list(stream.request_records(context={"user": 1}))
+    assert stream.paginator_creation_count == 1
+
+    # Second call to request_records
+    list(stream.request_records(context={"user": 2}))
+    assert stream.paginator_creation_count == 2
+
+    # Verify that get_new_paginator was called for each request_records call
+    # This proves fresh paginators are created each time
+
+
+def test_paginator_lifecycle_cleanup(tap: Tap):
+    """Test that paginator is properly cleaned up after use."""
+
+    class TestStream(RESTStream[int]):
+        name = "test-stream"
+        url_base = "https://api.test.com"
+        path = "/items"
+        records_jsonpath = "$.data[*]"
+        schema: t.ClassVar = {
+            "type": "object",
+            "properties": {"id": {"type": "integer"}},
+        }
+
+        def get_new_paginator(self) -> BaseOffsetPaginator:
+            return BaseOffsetPaginator(start_value=0, page_size=1)
+
+        def get_url_params(self, context, next_page_token):  # noqa: ARG002
+            return {"offset": next_page_token or 0}
+
+        def _request(self, prepared_request, context):  # noqa: ARG002
+            r = Response()
+            r.status_code = 200
+
+            parsed = urlparse(prepared_request.url)
+            query = parse_qs(parsed.query)
+            offset = int(query.get("offset", ["0"])[0])
+
+            if offset == 0:
+                r._content = json.dumps({"data": [{"id": 1}]}).encode()
+            else:
+                # Return empty data to end pagination
+                r._content = json.dumps({"data": []}).encode()
+
+            return r
+
+    stream = TestStream(tap=tap)
+
+    # Verify paginator starts as None
+    assert stream._current_paginator is None
+
+    # Process records
+    list(stream.request_records(context=None))
+
+    # Verify paginator is cleaned up after processing
+    assert stream._current_paginator is None
+
+
+def test_paginator_access_with_no_paginator(tap: Tap):
+    """Test behavior when get_new_paginator returns None."""
+
+    class TestStream(RESTStream):
+        name = "test-stream"
+        url_base = "https://api.test.com"
+        path = "/items"
+        records_jsonpath = "$.data[*]"
+        schema: t.ClassVar = {
+            "type": "object",
+            "properties": {"id": {"type": "integer"}},
+        }
+
+        paginator_type_accessed = None
+
+        def get_new_paginator(self) -> None:
+            return None
+
+        def get_url_params(self, context, next_page_token):  # noqa: ARG002
+            # Should still be able to access paginator (will be SinglePagePaginator)
+            self.paginator_type_accessed = type(self.paginator).__name__
+            return {}
+
+        def _request(self, prepared_request, context):  # noqa: ARG002
+            r = Response()
+            r.status_code = 200
+            r._content = json.dumps({"data": [{"id": 1}]}).encode()
+            return r
+
+    stream = TestStream(tap=tap)
+
+    # Process records
+    list(stream.request_records(context=None))
+
+    # Verify that a SinglePagePaginator was used when get_new_paginator returned None
+    assert stream.paginator_type_accessed == "SinglePagePaginator"
+
+
+def test_backward_compatibility_existing_patterns(tap: Tap):
+    """Test that existing pagination patterns continue to work unchanged."""
+
+    # This test replicates the pattern used in tap-dummyjson
+    class LegacyStyleStream(RESTStream[int]):
+        name = "legacy-stream"
+        url_base = "https://api.test.com"
+        path = "/items"
+        records_jsonpath = "$.data[*]"
+        schema: t.ClassVar = {
+            "type": "object",
+            "properties": {"id": {"type": "integer"}},
+        }
+
+        PAGE_SIZE = 25
+
+        def get_new_paginator(self) -> BaseOffsetPaginator:
+            return BaseOffsetPaginator(start_value=0, page_size=self.PAGE_SIZE)
+
+        def get_url_params(self, context, next_page_token):  # noqa: ARG002
+            # Traditional pattern - no paginator access, just using next_page_token
+            return {
+                "skip": next_page_token or 0,
+                "limit": self.PAGE_SIZE,  # Duplicate constant
+            }
+
+        def _request(self, prepared_request, context):  # noqa: ARG002
+            r = Response()
+            r.status_code = 200
+
+            parsed = urlparse(prepared_request.url)
+            query = parse_qs(parsed.query)
+            skip = int(query.get("skip", ["0"])[0])
+
+            if skip == 0:
+                r._content = json.dumps({"data": [{"id": 1}, {"id": 2}]}).encode()
+            else:
+                r._content = json.dumps({"data": []}).encode()
+
+            return r
+
+    # This test replicates what could now be done with paginator access
+    class ModernStyleStream(RESTStream[int]):
+        name = "modern-stream"
+        url_base = "https://api.test.com"
+        path = "/items"
+        records_jsonpath = "$.data[*]"
+        schema: t.ClassVar = {
+            "type": "object",
+            "properties": {"id": {"type": "integer"}},
+        }
+
+        def get_new_paginator(self) -> BaseOffsetPaginator:
+            return BaseOffsetPaginator(start_value=0, page_size=25)
+
+        def get_url_params(self, context, next_page_token):  # noqa: ARG002
+            # Modern pattern - access paginator to verify it's available
+            # during request processing. For demonstration, we can at least
+            # verify the paginator is accessible
+            _ = self.paginator  # This would raise an error before our changes
+            return {
+                "skip": next_page_token or 0,
+                "limit": 25,  # Still need to use constant since _page_size is private
+            }
+
+        def _request(self, prepared_request, context):  # noqa: ARG002
+            r = Response()
+            r.status_code = 200
+
+            parsed = urlparse(prepared_request.url)
+            query = parse_qs(parsed.query)
+            skip = int(query.get("skip", ["0"])[0])
+
+            if skip == 0:
+                r._content = json.dumps({"data": [{"id": 1}, {"id": 2}]}).encode()
+            else:
+                r._content = json.dumps({"data": []}).encode()
+
+            return r
+
+    # Test both patterns work and produce the same results
+    legacy_stream = LegacyStyleStream(tap=tap)
+    modern_stream = ModernStyleStream(tap=tap)
+
+    legacy_records = list(legacy_stream.request_records(context=None))
+    modern_records = list(modern_stream.request_records(context=None))
+
+    # Both should produce identical results
+    assert legacy_records == modern_records
+    assert len(legacy_records) == 2
+    assert legacy_records[0] == {"id": 1}
+    assert legacy_records[1] == {"id": 2}
+
+    # The key achievement is that the modern stream successfully accessed
+    # self.paginator in get_url_params without errors, which demonstrates
+    # that the paginator is now available across method scopes
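
Illustrative sketch, not part of the diff: how a tap stream built on this branch could read pagination settings back off the active paginator via the new self.paginator property, rather than repeating the page-size constant in get_url_params. The names GroupsStream, PublicOffsetPaginator, and page_size below are hypothetical; only RESTStream, BaseOffsetPaginator, the private _page_size attribute, and the paginator property itself come from the code above.

import typing as t

from singer_sdk import RESTStream
from singer_sdk.pagination import BaseOffsetPaginator


class PublicOffsetPaginator(BaseOffsetPaginator):
    """Offset paginator that re-exposes its page size publicly."""

    @property
    def page_size(self) -> int:
        # BaseOffsetPaginator keeps the page size in the private
        # _page_size attribute, so surface it here for streams to read.
        return self._page_size


class GroupsStream(RESTStream[int]):
    """Hypothetical stream that reads settings off the active paginator."""

    name = "groups"
    url_base = "https://api.example.com"
    path = "/groups"
    records_jsonpath = "$.data[*]"
    schema: t.ClassVar = {
        "type": "object",
        "properties": {"id": {"type": "integer"}},
    }

    def get_new_paginator(self) -> PublicOffsetPaginator:
        return PublicOffsetPaginator(start_value=0, page_size=25)

    def get_url_params(self, context, next_page_token):  # noqa: ARG002
        # self.paginator is only set while request_records() is running;
        # outside that lifecycle it raises RuntimeError.
        return {
            "skip": next_page_token or 0,
            "limit": self.paginator.page_size,  # no duplicated constant
        }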