Skip to content

wrapt + adapter which adds an argument, combined with a pytest yielding fixture #301

@jankatins

Description

@jankatins

I needed to wrap a pytest fixture (not a test!) so that it only runs once when running under xdist. For that I needed an adapter that added a required argument (so that pytest would fetch the required tmp_path_factory fixture value and pass it in) and made it a generator (so pytest would see it as a yielding fixture, not a returning one).

The particular problem is that if I just manipulate the argspec, I end up with a fixture which pytest identifies as "returning", not "yielding": pytest uses inspect.isgeneratorfunction() (which looks at code specifics, not return types) to decide whether a function is a generator, but wrapt creates a function with just pass (= return None), so it gets identified as a non-generator:

exec(f"def adapter{adapter}: pass", ns, ns)

It would be nice if wrapt would be able to generate a "proper" generator depending on some arguments (I guess async def will be equally problematic).

The code I came up with is below; it basically re-implements

# Quoted from wrapt's decorator machinery (indentation restored): when the
# adapter is given as a signature specification rather than a callable, wrapt
# manufactures a stand-in function object with that signature via exec().
# Note the "pass" body -- this is exactly why the stand-in is always a plain
# function and never a generator, which is the root of this issue.
if not callable(adapter):
    ns = {}
    # Check if the signature argument specification has annotations.  If it
    # does then we need to remember it but also drop it when attempting to
    # manufacture a stand-in adapter function.  This is necessary else it
    # will try and look up any types referenced in the annotations in the
    # empty namespace we use, which will fail.
    annotations = {}
    if not isinstance(adapter, str):
        if len(adapter) == 7:
            annotations = adapter[-1]
            adapter = adapter[:-1]
        adapter = formatargspec(*adapter)
    exec(f"def adapter{adapter}: pass", ns, ns)
    adapter = ns["adapter"]
    # Override the annotations for the manufactured adapter function so they
    # match the original adapter signature argument specification.
    if annotations:
        adapter.__annotations__ = annotations
with the required generator signature:

# Type variable for the pydantic model instance that the wrapped fixture
# yields; bound to BaseModel so it can be round-tripped through JSON.
RT = TypeVar("RT", bound=BaseModel)


def _additional_fixtures_protocol(tmp_path_factory: pytest.TempPathFactory):  # type: ignore[no-untyped-def]  # noqa: ANN202
    """Protocol to get access to the tmp_path_factory arg spec.

    Never called at runtime: it exists only to be introspected via
    inspect.getfullargspec() so the adapter factory can copy the extra
    argument name (and its annotation) that pytest must inject.
    """


def combine_args_with_protocol_adapter_factory(wrapped: Callable[..., Any]) -> Callable[..., Any]:
    """Build a generator-style adapter merging the wrapped fixture's signature with the protocol's.

    Args:
        wrapped: The original fixture function whose signature is extended.

    Returns:
        A manufactured generator function (body is a bare ``yield``) carrying
        the combined signature and annotations; wrapt uses it purely for
        introspection, so pytest sees both the extra fixture arguments and a
        yielding fixture (inspect.isgeneratorfunction() returns True).
    """
    # At this point we know that the wrapped function is a fixture, so should only contain args.
    # We also know that the protocol only contains args
    argspec_wrapped = inspect.getfullargspec(wrapped)
    argspec_protocol = inspect.getfullargspec(_additional_fixtures_protocol)
    # Deduplicate while preserving order: if the wrapped fixture already
    # requests one of the protocol arguments (e.g. tmp_path_factory), naively
    # concatenating would repeat the name and make the generated
    # "def adapter(...)" a SyntaxError (duplicate argument).
    combined_args = argspec_wrapped.args + [
        arg for arg in argspec_protocol.args if arg not in argspec_wrapped.args
    ]

    adapter_spec = formatargspec(
        args=combined_args,
        varkw=argspec_wrapped.varkw,
        defaults=argspec_wrapped.defaults,
        kwonlyargs=argspec_wrapped.kwonlyargs,
        kwonlydefaults=argspec_wrapped.kwonlydefaults,
        varargs=argspec_wrapped.varargs,
        # No annotations, it would fail to compile (names not resolvable in
        # the empty exec namespace); they are re-attached below instead.
    )
    # the current wrapt produces a normal function, no yielding generator,
    # so we have to create one here ourselves with exec :-(
    # We need it because pytest uses inspect.isgeneratorfunction() to decide between generator
    # fixtures and fixtures with return and if we would have a return function, it would never get the
    # data out of the generator :-(
    ns: dict[str, Any] = {}
    exec_(f"def adapter{adapter_spec}: yield", ns, ns)
    adapter = ns["adapter"]
    # the protocol only contains arguments, no return type, so we would not override that...
    annotations = argspec_protocol.annotations.copy()
    # We prefer the annotations from the wrapped function, including the annotation for the return type
    annotations.update(argspec_wrapped.annotations.copy())
    adapter.__annotations__ = annotations
    return adapter


# Decorator modeled after
# https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once
# Decorator modeled after
# https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once
def xdist_run_only_once(  # noqa: PLR0915
    *, return_type: type[RT]
) -> Callable[[Callable[..., Iterator[RT]]], Callable[[pytest.TempPathFactory, str], Iterator[RT]]]:
    """Call a fixture only once despite xdist.

    The first worker to acquire the file lock executes the wrapped fixture,
    serializes the yielded pydantic model to a shared JSON file, and — once all
    other workers have de-registered — runs the fixture's teardown. Every other
    worker just deserializes the shared JSON file.

    Args:
        return_type: pydantic model class used to (de)serialize the value that
            the wrapped fixture yields.

    Returns:
        A decorator adapting the fixture; a no-op passthrough when not running
        under xdist (master / no workers).
    """
    worker_id = os.environ.get("PYTEST_XDIST_WORKER", "master")
    if worker_id == "master":
        # not executing with multiple workers or without xdist
        # -> just make the decorator return the original functions
        return lambda x: x

    @wrapt.decorator(adapter=wrapt.adapter_factory(combine_args_with_protocol_adapter_factory))
    def adapted(  # type: ignore[no-untyped-def]  # noqa: PLR0915
        wrapped: Callable[..., Iterator[RT]],
        instance,  # noqa: ANN001
        # One of these already contains the new arguments from the protocol
        args,  # noqa: ANN001
        kwargs,  # noqa: ANN001
    ) -> Iterator[RT]:
        """Inner fixture with the interface of the combined arguments of the original fixture + the protocol."""
        lock_name = f"{wrapped.__module__}.{wrapped.__name__}"

        # The _executer function is a shorter way to pull out the named arguments from args/kwargs
        # no matter if these are in args or kwargs
        def _executer(  # type: ignore[no-untyped-def] # noqa: PLR0915
            tmp_path_factory: pytest.TempPathFactory,
            *_args,  # noqa: ANN002
            **_kwargs,  # noqa: ANN003
        ) -> Iterator[RT]:
            running_fixture: Iterator[Any] | None = None
            # get the temp directory shared by all workers
            # getbasetemp() is a worker specific directory under xdist, so go one down to get the shared one
            root_tmp_dir = tmp_path_factory.getbasetemp().parent

            lock_file = root_tmp_dir / f"{lock_name}.lock"
            info_file = root_tmp_dir / f"{lock_name}.json"
            worker_file = root_tmp_dir / f"{lock_name}.workers"

            def _load_worker_list() -> list[str]:
                """Return the sorted list of currently registered worker ids."""
                if not worker_file.is_file():
                    return []
                return sorted(json.loads(worker_file.read_text())["workers"])

            def _write_worker_list(workers: list[str]) -> None:
                """Persist the worker registry to the shared file."""
                worker_file.write_text(json.dumps({"workers": workers}))

            def _add_worker() -> None:
                """Register this worker in the shared registry (lock-protected)."""
                worker_id = os.environ["PYTEST_XDIST_WORKER"]
                with FileLock(str(lock_file)):
                    workers = _load_worker_list()
                    workers.append(worker_id)
                    _write_worker_list(workers)

            def _remove_worker() -> None:
                """De-register this worker from the shared registry (lock-protected)."""
                worker_id = os.environ["PYTEST_XDIST_WORKER"]
                with FileLock(str(lock_file)):
                    workers = _load_worker_list()
                    try:
                        workers.remove(worker_id)
                    except ValueError:
                        # BUG FIX: the original had an empty except body (a
                        # SyntaxError). Worker never registered / already
                        # removed -> nothing to do.
                        pass
                    _write_worker_list(workers)

            with FileLock(str(lock_file)):
                if info_file.is_file():
                    data = return_type.model_validate_json(info_file.read_text())
                else:
                    # The first one actually creates it
                    running_fixture = wrapped(*_args, **_kwargs)
                    data = next(running_fixture)
                    info_file.write_text(data.model_dump_json())

            _add_worker()
            # Only yield when out of the locks!
            yield data
            _remove_worker()

            # We have nothing to do anymore, shut down any resources, but only if
            # - we created them and
            # - only after we are the last worker

            if running_fixture is None:
                return

            start = time.monotonic()
            timeout = 20 * 60  # 20 min
            # Wait for workers to become empty as other worker shut down
            while start + timeout > time.monotonic():
                with FileLock(str(lock_file)):
                    workers = _load_worker_list()
                if len(workers) == 0:
                    break
                time.sleep(1)
            # And now we are the last and can run the fixture clean up and then our own cleanup
            with FileLock(str(lock_file)):
                # We expect that ends with a raised StopIteration
                # BUT we have to return normally as otherwise this gets turned into a RuntimeError
                with contextlib.suppress(StopIteration):
                    next(running_fixture)
                worker_file.unlink(missing_ok=True)
                info_file.unlink(missing_ok=True)

        yield from _executer(*args, **kwargs)

    return adapted

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions