diff --git a/src/kernels/utils.py b/src/kernels/utils.py index f71855f..3a0e89b 100644 --- a/src/kernels/utils.py +++ b/src/kernels/utils.py @@ -71,6 +71,24 @@ def universal_build_variant() -> str: return "torch-universal" +# Metaclass to allow overriding the `__repr__` method for kernel modules. +class _KernelModuleMeta(type): + def __repr__(self): + return "" + + +# Custom module type to identify dynamically loaded kernel modules. +# Using a subclass lets us distinguish these from regular imports. +class _KernelModuleType(ModuleType, metaclass=_KernelModuleMeta): + """Marker class for modules loaded dynamically from a path.""" + + module_name: str + is_kernel: bool = True + + def __repr__(self): + return f"" + + def import_from_path(module_name: str, file_path: Path) -> ModuleType: # We cannot use the module name as-is, after adding it to `sys.modules`, # it would also be used for other imports. So, we make a module name that @@ -84,6 +102,9 @@ def import_from_path(module_name: str, file_path: Path) -> ModuleType: module = importlib.util.module_from_spec(spec) if module is None: raise ImportError(f"Cannot load module {module_name} from spec") + module.__class__ = _KernelModuleType + assert isinstance(module, _KernelModuleType) # for mypy type checking + module.module_name = module_name sys.modules[module_name] = module spec.loader.exec_module(module) # type: ignore return module