Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions src/editables/redirector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,57 @@
import importlib.util
import sys
from types import ModuleType
from typing import Dict, Optional, Sequence, Union
from typing import Dict, Optional, Sequence, Set, Union

ModulePath = Optional[Sequence[Union[bytes, str]]]


class RedirectingFinder(importlib.abc.MetaPathFinder):
_redirections: Dict[str, str] = {}
_parents: Set[str] = set()

@classmethod
def map_module(cls, name: str, path: str) -> None:
cls._redirections[name] = path
cls._parents.update(cls.parents(name))

@classmethod
def parents(cls, name):
"""
Given a full name, generate all parents.

>>> list(RedirectingFinder.parents('a.b.c.d'))
['a.b.c', 'a.b', 'a']
"""
base, sep, name = name.rpartition('.')
if base:
yield base
yield from cls.parents(base)

@classmethod
def find_spec(
cls, fullname: str, path: ModulePath = None, target: Optional[ModuleType] = None
) -> Optional[importlib.machinery.ModuleSpec]:
if "." in fullname:
return None
if path is not None:
return None
return cls.spec_from_parent(fullname) or cls.spec_from_redirect(fullname)

@classmethod
def spec_from_parent(
cls, fullname: str
) -> Optional[importlib.machinery.ModuleSpec]:
if fullname in cls._parents:
return importlib.util.spec_from_loader(
fullname,
importlib.machinery.NamespaceLoader(
fullname,
path=[],
path_finder=cls.find_spec,
),
)

@classmethod
def spec_from_redirect(
cls, fullname: str
) -> Optional[importlib.machinery.ModuleSpec]:
try:
redir = cls._redirections[fullname]
except KeyError:
Expand Down
20 changes: 20 additions & 0 deletions tests/test_redirects.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,26 @@ def test_redirects(tmp_path):
assert pkg.sub.val == 42


def test_namespace_redirects(tmp_path):
project = tmp_path / "project"
project_files = {
"ns.pkg": {
"__init__.py": "val = 42",
"sub.py": "val = 42",
}
}
build(project, project_files)

with save_import_state():
F.install()
F.map_module("ns.pkg", project / "ns.pkg" / "__init__.py")

import ns.pkg.sub

assert ns.pkg.val == 42
assert ns.pkg.sub.val == 42


def test_cache_invalidation():
F.install()
# assert that the finder matches importlib's expectations
Expand Down