Skip to content
7 changes: 7 additions & 0 deletions singer_sdk/helpers/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,13 @@
description="Maximum number of rows in each batch.",
),
).to_dict()
# JSON-schema fragment advertising the optional `max_parallelism` tap setting.
# Merged into every tap's config JSON schema by `Tap.append_builtin_config`,
# and read back in `Tap.__init__` via `self.config.get("max_parallelism")`.
TAP_MAX_PARALLELISM_CONFIG = PropertiesList(
    Property(
        "max_parallelism",
        IntegerType,
        description="Max number of streams that can sync in parallel.",
    ),
).to_dict()


class TargetLoadMethods(str, Enum):
Expand Down
112 changes: 95 additions & 17 deletions singer_sdk/tap_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@

import abc
import contextlib
import logging
import sys
import typing as t
from enum import Enum
from logging.handlers import QueueHandler, QueueListener
from multiprocessing import Manager, Queue

import click
from joblib import Parallel, delayed, parallel_config

from singer_sdk._singerlib import Catalog, StateMessage
from singer_sdk.configuration._dict_config import merge_missing_config_jsonschema
Expand All @@ -22,6 +27,7 @@
from singer_sdk.helpers._util import dump_json, read_json_file
from singer_sdk.helpers.capabilities import (
BATCH_CONFIG,
TAP_MAX_PARALLELISM_CONFIG,
CapabilitiesEnum,
PluginCapabilities,
TapCapabilities,
Expand Down Expand Up @@ -94,6 +100,7 @@ def __init__(
self._input_catalog: Catalog | None = None
self._state: dict[str, Stream] = {}
self._catalog: Catalog | None = None # Tap's working catalog
self._max_parallelism: int | None = self.config.get("max_parallelism")

# Process input catalog
if isinstance(catalog, Catalog):
Expand Down Expand Up @@ -178,6 +185,20 @@ def setup_mapper(self) -> None:
super().setup_mapper()
self.mapper.register_raw_streams_from_catalog(self.catalog)

@property
def max_parallelism(self) -> int:
"""Get max parallel sinks.

The default is None if not overridden.

Returns:
Max number of streams that can be synced in parallel.
"""
if self._max_parallelism in {0, 1}:
self._max_parallelism = None

return self._max_parallelism

@classproperty
def capabilities(self) -> list[CapabilitiesEnum]: # noqa: PLR6301
"""Get tap capabilities.
Expand Down Expand Up @@ -216,6 +237,9 @@ def append_builtin_config(cls: type[PluginBase], config_jsonschema: dict) -> Non
capabilities = cls.capabilities
if PluginCapabilities.BATCH in capabilities:
merge_missing_config_jsonschema(BATCH_CONFIG, config_jsonschema)
merge_missing_config_jsonschema(
TAP_MAX_PARALLELISM_CONFIG, config_jsonschema
)

# Connection and sync tests:

Expand Down Expand Up @@ -440,31 +464,85 @@ def _set_compatible_replication_methods(self) -> None:

# Sync methods

@t.final
def sync_one(
    self,
    stream: Stream,
    log_level: int | None = None,
    log_queue: Queue | None = None,
) -> None:
    """Sync a single stream.

    When streams sync in parallel, this method runs inside a joblib worker
    process whose logger has no handlers; ``log_queue`` and ``log_level``
    let the worker forward its log records to a listener in the parent
    process. See https://github.com/joblib/joblib/issues/1017 for the
    logging pattern used here.

    Args:
        stream: The stream that you would like to sync.
        log_level: Effective logging level to apply in the worker, as
            returned by ``Tap.logger.getEffectiveLevel()`` (an int).
        log_queue: Queue drained by the parent's ``QueueListener``; in
            practice a ``multiprocessing.Manager().Queue()`` proxy — see
            ``sync_all``.
    """
    # In a parallel worker the logger starts handler-less: attach a
    # QueueHandler so records flow back to the parent's QueueListener.
    # Skipped entirely when parallelism is disabled (max_parallelism is
    # None) or when handlers are already present.
    if self.max_parallelism is not None and not self.logger.hasHandlers():
        queue_handler = QueueHandler(log_queue)
        self.logger.addHandler(queue_handler)
        self.logger.setLevel(log_level)
        self.metrics_logger.addHandler(queue_handler)
        self.metrics_logger.setLevel(log_level)

    # Respect catalog selection: skip a deselected stream unless one of
    # its descendants is selected and needs this stream to run.
    if not stream.selected and not stream.has_selected_descendents:
        self.logger.info("Skipping deselected stream '%s'.", stream.name)
        return

    # Child streams are invoked by their parent stream's sync, never
    # directly; syncing them here would duplicate records.
    if stream.parent_stream_type:
        self.logger.debug(
            "Child stream '%s' is expected to be called "
            "by parent stream '%s'. "
            "Skipping direct invocation.",
            type(stream).__name__,
            stream.parent_stream_type.__name__,
        )
        return

    stream.sync()
    stream.finalize_state_progress_markers()

@t.final
def sync_all(self) -> None:
"""Sync all streams."""
self._reset_state_progress_markers()
self._set_compatible_replication_methods()
self.write_message(StateMessage(value=self.state))

stream: Stream
for stream in self.streams.values():
if not stream.selected and not stream.has_selected_descendents:
self.logger.info("Skipping deselected stream '%s'.", stream.name)
continue

if stream.parent_stream_type:
self.logger.debug(
"Child stream '%s' is expected to be called "
"by parent stream '%s'. "
"Skipping direct invocation.",
type(stream).__name__,
stream.parent_stream_type.__name__,
if self.max_parallelism is None:
stream: Stream
for stream in self.streams.values():
self.sync_one(stream=stream)
else:
with Manager() as manager:
# Prepare logger for parallel processes
console_handler = logging.StreamHandler(sys.stderr)
console_formatter = logging.Formatter(
fmt="{asctime:23s} | {levelname:8s} | {name:20s} | {message}",
style="{",
)
continue

stream.sync()
stream.finalize_state_progress_markers()
console_handler.setFormatter(console_formatter)
self.logger.addHandler(console_handler)
log_queue = manager.Queue()
listener = QueueListener(log_queue, *self.logger.handlers)
listener.start()
with parallel_config(
backend="loky",
prefer="processes",
n_jobs=self.max_parallelism,
), Parallel() as parallel:
parallel(
delayed(self.sync_one)(
stream,
log_queue=log_queue,
log_level=self.logger.getEffectiveLevel(),
)
for stream in self.streams.values()
)
listener.stop()

# this second loop is needed for all streams to print out their costs
# including child streams which are otherwise skipped in the loop above
Expand Down