diff --git a/flashinfer_bench/compile/builders/__init__.py b/flashinfer_bench/compile/builders/__init__.py index bd4cd21..756a62c 100644 --- a/flashinfer_bench/compile/builders/__init__.py +++ b/flashinfer_bench/compile/builders/__init__.py @@ -1,6 +1,6 @@ from .cuda_builder import CUDABuilder from .python_builder import PythonBuilder from .triton_builder import TritonBuilder -from .tvm_ffi_builder import TVMFFIBuilder +from .tvm_ffi_builder import TvmFfiBuilder -__all__ = ["CUDABuilder", "PythonBuilder", "TritonBuilder", "TVMFFIBuilder"] +__all__ = ["CUDABuilder", "PythonBuilder", "TritonBuilder", "TvmFfiBuilder"] diff --git a/flashinfer_bench/compile/builders/tvm_ffi_builder.py b/flashinfer_bench/compile/builders/tvm_ffi_builder.py index d258b89..8ad9df2 100644 --- a/flashinfer_bench/compile/builders/tvm_ffi_builder.py +++ b/flashinfer_bench/compile/builders/tvm_ffi_builder.py @@ -3,10 +3,12 @@ from __future__ import annotations import logging +from enum import Enum from pathlib import Path from typing import Any, Dict, List, Tuple import tvm_ffi +from tvm_ffi.utils import FileLock from flashinfer_bench.compile.builder import Builder, BuildError, create_pkg_name from flashinfer_bench.compile.runnable import Runnable, TVMFFIRunnable @@ -15,78 +17,302 @@ logger = logging.getLogger(__name__) -CUDA_EXTENSIONS = [".cu"] -CPP_EXTENSIONS = [".cpp", ".cc", ".cxx", ".c"] +# File extension mappings for source file classification +CUDA_EXTENSIONS = [".cu"] # CUDA source files +CPP_EXTENSIONS = [".cpp", ".cc", ".cxx", ".c"] # C/C++ source files -class TVMFFIBuilder(Builder): - """Builder using TVM-FFI with automatic caching and multi-process sharing. + +class Language(Enum): + """Enum representing source code languages supported by the builder.""" - Build strategy: - 1. Check if .so exists in cache (multi-process safe) - 2. If not, compile with tvm_ffi.cpp.build_inline() to cache - 3. Load with tvm_ffi.load_module() + CUDA = "cuda" + """The solution's language is CUDA""" + CPP = "cpp" + """The solution's language is C/C++""" - Benefits: - - Multi-process benchmark: Only first process compiles, others load from cache - - Cross-framework: Same .so works with PyTorch, JAX, CuPy (DLPack) - - No JIT/AOT distinction: Smart caching handles both cases + +class TvmFfiBuilder(Builder): + """Builder using TVM-FFI with automatic caching and support for multi-process and + multi-threaded compilation. The result is framework-agnostic and supports DLPack interop + with PyTorch, JAX, etc. + + Cache logic: if this builder is asked to build the same solution again, it returns the cached + result. If another builder instance requests the same solution, the cached build is reused as + long as the build directory still contains a matching compiled library. + + The solution to compile should be written in destination-passing style, i.e., the function + should take the input tensors and the output tensors as arguments.
+ + Examples + -------- + >>> builder = TvmFfiBuilder() + >>> runnable = builder.build(definition, solution) + >>> output = runnable(x=input_tensor) # Allocates and returns output + >>> runnable.call_dest(x=input_tensor, output=output_tensor) # Destination-passing style """ + _BUILD_DIR_NAME = "tvm_ffi" + """Subdirectory under FIB_CACHE_PATH where build artifacts are stored""" + + _LOCK_FILE_NAME = "flashinfer_bench_tvm_ffi_lock" + """File lock name for multi-process synchronization during compilation""" + + _KEY_PREFIX = "tvm_ffi_" + """Prefix for cache keys to avoid collisions with other builders""" + def __init__(self) -> None: + """Initialize the TvmFfiBuilder. + + Sets up empty dictionaries for future extensibility with extra + include paths and linker flags (currently unused). + """ super().__init__() self._extra_include_paths: Dict[str, str] = {} self._extra_ldflags: Dict[str, List[str]] = {} def can_build(self, sol: Solution) -> bool: + """Check if this builder can build the given solution. + + Parameters + ---------- + sol : Solution + Solution to check + + Returns + ------- + bool + True if the solution's language is CUDA (sources may include both .cu and .cpp files) + """ return sol.spec.language == SupportedLanguages.CUDA def _make_key(self, solution: Solution) -> str: - return f"tvm_ffi_{create_pkg_name(solution)}" + """Generate unique cache key for a solution. + + Parameters + ---------- + solution : Solution + Solution to generate key for + + Returns + ------- + str + Unique key combining builder name and solution package name + """ + return self._KEY_PREFIX + create_pkg_name(solution) def _make_closer(self): + """Create a closer function for resource cleanup. + + Returns + ------- + callable + No-op closer since TVM-FFI handles cleanup automatically + """ return lambda: None def _get_build_path(self, key: str) -> Path: - return get_fib_cache_path() / "tvm_ffi" / key + """Get the build directory path for a given cache key. + + Parameters + ---------- + key : str + Unique cache key for the solution + + Returns + ------- + Path + Directory path where build artifacts will be stored + """ + return get_fib_cache_path() / self._BUILD_DIR_NAME / key + + def _check_sources(self, path: Path, key: str, sol: Solution) -> bool: + """Check whether the source code is valid and whether the cached .so can be reused, by + comparing source files and checking that the .so exists. + + Returns True (can use cached .so) only if: + 1. The compiled .so file exists + 2. All source files exist with identical content + + Parameters + ---------- + path : Path + Build directory path + key : str + Unique key for this solution (used to find .so file) + sol : Solution + Solution containing source files + + Returns + ------- + can_use_cached : bool + True if the cached .so can be used, False if compilation is needed + """ + # Check if build directory exists + if not path.exists(): + return False + elif not path.is_dir(): + raise BuildError(f"Build directory exists but is not a directory: {path}") + + # Check if .so exists + so_path = path / f"{key}.so" + if not so_path.is_file(): + return False + + # Check if all files exist and content is identical + for src in sol.sources: + # Defensive assertion: paths are validated by the Solution model validator; + # assert again here to be safe. + src_path_obj = Path(src.path) + assert not src_path_obj.is_absolute(), f"Absolute path detected: {src.path}" + assert ".."
not in src_path_obj.parts, f"Path traversal detected: {src.path}" + + src_path = path / src.path + + if not src_path.exists(): + return False + elif not src_path.is_file(): + raise BuildError(f"Source path exists but is not a file: {src_path}") + + if src_path.read_text() != src.content: + return False + + # All checks passed: can use cached .so + return True + + def _detect_language(self, sol: Solution) -> Language: + """Detect source language based on file extensions. + + Parameters + ---------- + sol : Solution + Solution containing source files + + Returns + ------- + Language + CUDA if any .cu files present, otherwise CPP + + Raises + ------ + BuildError + If no valid source files found + """ + has_cuda = False + has_cpp = False + + for src in sol.sources: + path_str = str(src.path) + if path_str.endswith(tuple(CUDA_EXTENSIONS)): + has_cuda = True + elif path_str.endswith(tuple(CPP_EXTENSIONS)): + has_cpp = True + + if not has_cuda and not has_cpp: + raise BuildError("No CUDA or C++ sources found") + + return Language.CUDA if has_cuda else Language.CPP def _write_sources(self, path: Path, sol: Solution) -> Tuple[List[str], List[str]]: - """Extract and write all source files to the given path.""" + """Write all source files to build directory and collect file paths. + + Creates parent directories as needed for files in subdirectories. + Overwrites files unconditionally (caller already determined a full build is needed). + + Parameters + ---------- + path : Path + Build directory where source files will be written + sol : Solution + Solution containing source files to write + + Returns + ------- + cpp_files : List[str] + List of C++ source file paths + cuda_files : List[str] + List of CUDA source file paths + """ path.mkdir(parents=True, exist_ok=True) cpp_files: List[str] = [] cuda_files: List[str] = [] + for src in sol.sources: + # Defensive assertion: path should be validated at Solution creation time + src_path_obj = Path(src.path) + assert not src_path_obj.is_absolute(), f"Absolute path detected: {src.path}" + assert ".." not in src_path_obj.parts, f"Path traversal detected: {src.path}" + src_path = path / src.path - if src_path.is_dir(): - raise BuildError(f"Source path is a directory: {src_path}") + # Ensure parent directories exist + src_path.parent.mkdir(parents=True, exist_ok=True) + + # Write source file src_path.write_text(src.content) - if str(src_path).endswith(tuple(CPP_EXTENSIONS)): - cpp_files.append(str(src_path)) - elif str(src_path).endswith(tuple(CUDA_EXTENSIONS)): - cuda_files.append(str(src_path)) + # Collect file path by extension + path_str = str(src_path) + if path_str.endswith(tuple(CPP_EXTENSIONS)): + cpp_files.append(path_str) + elif path_str.endswith(tuple(CUDA_EXTENSIONS)): + cuda_files.append(path_str) - if len(cpp_files) == 0 and len(cuda_files) == 0: - raise BuildError("No sources found") return cpp_files, cuda_files - def _get_language(self, cpp_files: List[str], cuda_files: List[str]) -> str: - return "cuda" if len(cuda_files) > 0 else "cpp" - def _get_entry_symbol(self, sol: Solution) -> str: - """Extract function symbol from entry_point.""" + """Extract function symbol from entry_point. 
+ + Parameters + ---------- + sol : Solution + Solution with entry_point in format 'file.ext::symbol' + + Returns + ------- + str + The function symbol name to be loaded from the compiled module + + Raises + ------ + BuildError + If entry_point format is invalid (missing '::' separator) + """ entry_point = sol.spec.entry_point if "::" not in entry_point: raise BuildError( - f"Invalid entry_point format: {entry_point}. Expected 'file.cu::symbol'" + f"Invalid entry_point format: {entry_point}. Expected 'file.extension::symbol'" ) return entry_point.split("::")[-1] def _make_runnable( self, mod: tvm_ffi.Module, entry_symbol: str, defn: Definition, metadata: Dict[str, Any] - ) -> Runnable: - """Create Runnable from TVM-FFI module.""" + ) -> TVMFFIRunnable: + """Create Runnable from TVM-FFI module. + + Wraps the compiled function with a keyword argument adapter that matches + the definition's input/output interface. + + Parameters + ---------- + mod : tvm_ffi.Module + Loaded TVM-FFI module containing the compiled function + entry_symbol : str + Name of the function to extract from the module + defn : Definition + Definition specifying the function interface + metadata : Dict[str, Any] + Metadata about the build (language, paths, etc.) + + Returns + ------- + TVMFFIRunnable + Runnable wrapper that handles tensor allocation and keyword arguments + + Raises + ------ + BuildError + If the entry_symbol is not found in the module + """ try: fn = getattr(mod, entry_symbol) except AttributeError as e: @@ -104,25 +330,64 @@ def _kw_adapter(**kwargs): ) def _build(self, defn: Definition, sol: Solution) -> Runnable: - """Build with automatic caching - compile once, load from cache afterwards.""" + """Build with automatic caching - compile once, load from cache afterwards. + + This method implements intelligent caching: + 1. Checks if a compiled .so file already exists + 2. If not, writes source files and compiles them + 3. Loads the compiled module (from cache or fresh build) + 4. Returns a runnable wrapper + + The caching is multi-process safe, enabling efficient parallel benchmarking. 
+ + Parameters + ---------- + defn : Definition + Problem definition specifying inputs/outputs + sol : Solution + Solution containing source code and build specification + + Returns + ------- + Runnable + TVMFFIRunnable that can be called with input tensors + + Raises + ------ + BuildError + If compilation fails, module loading fails, or entry point is invalid + """ key = self._make_key(sol) build_path = self._get_build_path(key) entry_symbol = self._get_entry_symbol(sol) - cpp_files, cuda_files = self._write_sources(build_path, sol) - language = self._get_language(cpp_files, cuda_files) - extra_include_paths = [str(build_path)] + language = self._detect_language(sol) + can_use_cached = self._check_sources(build_path, key, sol) - try: - # Use build_inline instead of build to - output_lib_path = tvm_ffi.cpp.build( - name=key, - cpp_files=cpp_files, - cuda_files=cuda_files, - extra_include_paths=extra_include_paths, - build_directory=build_path, - ) - except Exception as e: - raise BuildError(f"TVM-FFI compilation failed for '{sol.name}': {e}") from e + # Fast path: reuse the cached .so when it is still valid. Rebuilds are + # serialized through the FileLock and double-checked after acquiring it. + if can_use_cached: + output_lib_path = str(build_path / f"{key}.so") + else: + # Ensure build directory exists before creating file lock + build_path.mkdir(parents=True, exist_ok=True) + with FileLock(build_path / self._LOCK_FILE_NAME): + # Double-check after acquiring lock (another process may have built it) + if self._check_sources(build_path, key, sol): + output_lib_path = str(build_path / f"{key}.so") + else: + cpp_files, cuda_files = self._write_sources(build_path, sol) + extra_include_paths = [str(build_path)] + try: + # Compile sources to shared library + output_lib_path = tvm_ffi.cpp.build( + name=key, + cpp_files=cpp_files, + cuda_files=cuda_files, + extra_include_paths=extra_include_paths, + build_directory=build_path, + ) + except Exception as e: + raise BuildError(f"TVM-FFI compilation failed for '{sol.name}': {e}") from e # Load the compiled module try: @@ -130,10 +395,11 @@ def _build(self, defn: Definition, sol: Solution) -> Runnable: except Exception as e: raise BuildError(f"Failed to load compiled module: {e}") from e + # Create metadata for the runnable metadata = { "definition": defn.name, "solution": sol.name, - "language": language, + "language": language.value, "binding": "tvm_ffi", "key": key, "symbol": entry_symbol, diff --git a/flashinfer_bench/compile/registry.py b/flashinfer_bench/compile/registry.py index 0e108e7..6636c1e 100644 --- a/flashinfer_bench/compile/registry.py +++ b/flashinfer_bench/compile/registry.py @@ -52,11 +52,11 @@ def build_reference(self, defn: Definition) -> Runnable: def get_builder_registry() -> BuilderRegistry: global _registry if _registry is None: - from .builders import CUDABuilder, PythonBuilder, TritonBuilder, TVMFFIBuilder + from .builders import CUDABuilder, PythonBuilder, TritonBuilder, TvmFfiBuilder py = PythonBuilder() triton = TritonBuilder(py_builder=py) - tvm_ffi = TVMFFIBuilder() + tvm_ffi = TvmFfiBuilder() cuda = CUDABuilder() # Fallback for backward compatibility # Priority: Python > Triton > TVM-FFI > CUDA (pybind11) diff --git a/flashinfer_bench/compile/runnable.py b/flashinfer_bench/compile/runnable.py index 5eb26be..fe309df 100644 --- a/flashinfer_bench/compile/runnable.py +++ b/flashinfer_bench/compile/runnable.py @@ -59,7 +59,13 @@ def __call__(self, **kwargs: Any) -> Any: ) output_shapes = self._definition.get_output_shapes(var_values) output_tensors: Dict[str,
torch.Tensor] = {} - device = next(iter(kwargs.values())).device if len(kwargs) > 0 else "cpu" + + # Determine device from input tensors + devices = {v.device for v in kwargs.values() if hasattr(v, "device")} + if len(devices) > 1: + raise ValueError("All input tensors must be on the same device") + device = devices.pop() if devices else "cpu" + for name, shape in output_shapes.items(): output_tensors[name] = torch.empty( shape, dtype=dtype_str_to_torch_dtype(self._definition.outputs[name].dtype) @@ -75,8 +81,3 @@ def __call__(self, **kwargs: Any) -> Any: def call_dest(self, **kwargs: Any) -> None: """Call the underlying function with destination passing style.""" self._fn(**kwargs) - - def close(self) -> None: - if self._closer: - self._closer() - self._closer = None diff --git a/flashinfer_bench/data/solution.py b/flashinfer_bench/data/solution.py index 1b54706..0c3d3fc 100644 --- a/flashinfer_bench/data/solution.py +++ b/flashinfer_bench/data/solution.py @@ -1,7 +1,7 @@ """Strong-typed data definitions for solution implementations.""" -import ast from enum import Enum +from pathlib import Path from typing import List, Optional from pydantic import Field, model_validator @@ -34,26 +34,27 @@ class SourceFile(BaseModelWithDocstrings): path: NonEmptyString """The relative path of the file, including its name and extension (e.g., 'src/kernel.cu', 'main.py'). When compiling the solution, a temporary solution source directory will be - created, and the file will be placed according to this path.""" + created, and the file will be placed according to this path. The path should not contain + parent directory traversal ("..").""" content: NonEmptyString """The complete text content of the source file.""" @model_validator(mode="after") - def _validate_python_syntax(self) -> "SourceFile": - """Validate Python syntax for .py files. + def _validate_source_path(self) -> "SourceFile": + """Validate source path for security. Raises ------ ValueError - If the file is a Python file and contains invalid syntax. + If the path contains security issues (absolute paths or path traversal). """ - if self.path.endswith(".py"): - try: - ast.parse(self.content, mode="exec") - except SyntaxError as e: - raise ValueError(f"SourceFile content must be valid Python code: {e}") from e - - # TODO(shanli): syntax validation for other languages + src_path = Path(self.path) + if src_path.is_absolute(): + raise ValueError(f"Invalid source path (absolute path not allowed): {self.path}") + if ".." in src_path.parts: + raise ValueError( + f"Invalid source path (parent directory traversal not allowed): {self.path}" + ) return self @@ -85,9 +86,11 @@ def _validate_entry_point(self) -> "BuildSpec": ValueError If entry_point doesn't follow the required format. """ - if "::" not in self.entry_point: - raise ValueError("spec.entry_point must be '<file>::<symbol>'") - # TODO(shanli): validations against entry file existence and function existence + if self.entry_point.count("::") != 1: + raise ValueError( + f"Invalid entry point format: {self.entry_point}. Expected " + '"<file>::<symbol>".' + ) return self @@ -115,16 +118,16 @@ class Solution(BaseModelWithDocstrings): @model_validator(mode="after") def _validate_source_path_entry_point(self) -> "Solution": - """Validate that all source file paths are unique. + """Validate source file paths for uniqueness and entry file existence. Raises ------ ValueError - If duplicate source file paths are found, or the entry point file is not found in the - sources.
+ If duplicate source file paths are found or the entry file is not found in the sources. """ seen_paths = set() for source in self.sources: + # Check for duplicates if source.path in seen_paths: raise ValueError(f"Duplicate source path '{source.path}'") seen_paths.add(source.path) @@ -133,7 +136,6 @@ def _validate_source_path_entry_point(self) -> "Solution": if entry_file not in seen_paths: raise ValueError(f"Entry source file '{entry_file}' not found in sources") - # TODO(shanli): stronger validation for entry file and function return self diff --git a/pyproject.toml b/pyproject.toml index 0bfe5a1..b68fe09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,7 +93,7 @@ strict = true [tool.pytest.ini_options] testpaths = ["tests"] python_files = ["test_*.py"] -addopts = "-ra" +addopts = "-rA --durations=0 --ignore=3rdparty" [tool.setuptools_scm] version_scheme = "python-simplified-semver" diff --git a/tests/compile/test_tvm_ffi_builder.py b/tests/compile/test_tvm_ffi_builder.py index f84a936..e0316a7 100644 --- a/tests/compile/test_tvm_ffi_builder.py +++ b/tests/compile/test_tvm_ffi_builder.py @@ -7,18 +7,28 @@ import torch from flashinfer_bench.compile.builder import BuildError -from flashinfer_bench.compile.builders.tvm_ffi_builder import TVMFFIBuilder +from flashinfer_bench.compile.builders.tvm_ffi_builder import TvmFfiBuilder from flashinfer_bench.data import BuildSpec, Definition, Solution, SourceFile, SupportedLanguages + +@pytest.fixture(autouse=True) +def isolated_cache(tmp_path, monkeypatch): + """Use isolated temporary directory for cache in all tests. + + This fixture automatically sets FIB_CACHE_PATH to a unique temporary + directory for each test, preventing cache pollution between tests. + """ + cache_dir = tmp_path / "cache" + cache_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setenv("FIB_CACHE_PATH", str(cache_dir)) + return cache_dir + + # ============================================================================ # CPU Tests # ============================================================================ - -def test_build_cpp_cpu() -> None: - """Test building and running a simple CPU kernel.""" - # CPU kernel source - destination passing style - cpp_source = """ +CPP_SOURCE = """ #include #include #include @@ -44,6 +54,44 @@ def test_build_cpp_cpu() -> None: TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cpu, add_one_cpu); """ +CUDA_SOURCE = """ +#include +#include +#include +#include +#include +#include + +__global__ void add_one_kernel(const float* input, float* output, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + output[idx] = input[idx] + 1.0f; + } +} + +void add_one_cuda(tvm::ffi::TensorView x, tvm::ffi::TensorView output) { + TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor"; + TVM_FFI_ICHECK(output.ndim() == 1) << "output must be a 1D tensor"; + TVM_FFI_ICHECK(x.size(0) == output.size(0)) << "x and output must have the same size"; + + int n = x.size(0); + int threads = 256; + int blocks = (n + threads - 1) / threads; + + add_one_kernel<<<blocks, threads>>>( + static_cast<const float*>(x.data_ptr()), + static_cast<float*>(output.data_ptr()), + n + ); + cudaDeviceSynchronize(); } + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cuda, add_one_cuda); +""" + + +def test_build_cpp_cpu() -> None: + """Test building and running a simple CPU kernel.""" # Create definition and solution n = 5 definition = Definition( @@ -66,12 +114,12 @@ def test_build_cpp_cpu() -> None: target_hardware=["cpu"], entry_point="kernel.cpp::add_one_cpu", ), - sources=[SourceFile(path="kernel.cpp",
content=cpp_source)], + sources=[SourceFile(path="kernel.cpp", content=CPP_SOURCE)], description="Simple CPU add kernel", ) # Build and run - builder = TVMFFIBuilder() + builder = TvmFfiBuilder() runnable = builder.build(definition, solution) # Test execution with torch tensors - runnable returns output @@ -93,42 +141,6 @@ def test_build_cuda_gpu() -> None: """Test building and running a simple CUDA kernel.""" import torch - # CUDA kernel source - destination passing style - cuda_source = """ -#include -#include -#include -#include -#include -#include - -__global__ void add_one_kernel(const float* input, float* output, int n) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n) { - output[idx] = input[idx] + 1.0f; - } -} - -void add_one_cuda(tvm::ffi::TensorView x, tvm::ffi::TensorView output) { - TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor"; - TVM_FFI_ICHECK(output.ndim() == 1) << "output must be a 1D tensor"; - TVM_FFI_ICHECK(x.size(0) == output.size(0)) << "x and output must have the same size"; - - int n = x.size(0); - int threads = 256; - int blocks = (n + threads - 1) / threads; - - add_one_kernel<<<blocks, threads>>>( - static_cast<const float*>(x.data_ptr()), - static_cast<float*>(output.data_ptr()), - n - ); - cudaDeviceSynchronize(); -} - -TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cuda, add_one_cuda); -""" - definition = Definition( name="test_add_one_cuda", op_type="test", @@ -149,12 +161,12 @@ def test_build_cuda_gpu() -> None: target_hardware=["gpu"], entry_point="kernel.cu::add_one_cuda", ), - sources=[SourceFile(path="kernel.cu", content=cuda_source)], + sources=[SourceFile(path="kernel.cu", content=CUDA_SOURCE)], description="Simple CUDA add kernel", ) # Build and run - builder = TVMFFIBuilder() + builder = TvmFfiBuilder() runnable = builder.build(definition, solution) # Test execution with torch tensors - runnable returns output @@ -174,7 +186,7 @@ def test_build_cuda_gpu() -> None: def test_can_build_cuda() -> None: - """Test that TVMFFIBuilder can build CUDA solutions.""" + """Test that TvmFfiBuilder can build CUDA solutions.""" - builder = TVMFFIBuilder() + builder = TvmFfiBuilder() cuda_solution = Solution( name="test_cuda", @@ -194,7 +206,7 @@ def test_can_build_cuda() -> None: def test_can_build_non_cuda() -> None: - """Test that TVMFFIBuilder rejects non-CUDA solutions.""" + """Test that TvmFfiBuilder rejects non-CUDA solutions.""" - builder = TVMFFIBuilder() + builder = TvmFfiBuilder() python_solution = Solution( name="test_python", @@ -210,49 +222,34 @@ def test_can_build_non_cuda() -> None: assert not builder.can_build(python_solution) -def test_caching_cpu() -> None: +def test_caching_builder_level() -> None: """Test that compiled .so is cached and reused for CPU kernels.""" - cpp_source = """ -#include -#include -#include - -void add_two_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView output) { - for (int i = 0; i < x.size(0); ++i) { - static_cast<float*>(output.data_ptr())[i] = - static_cast<const float*>(x.data_ptr())[i] + 2.0f; - } -} - -TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_two_cpu, add_two_cpu); -""" - definition = Definition( - name="test_add_two_cpu", + name="test_add_one_cpu_cached", op_type="test", description="Test CPU caching", axes={"n": {"type": "const", "value": 5}}, constraints=[], inputs={"x": {"shape": ["n"], "dtype": "float32"}}, outputs={"output": {"shape": ["n"], "dtype": "float32"}}, - reference="def run(x): return x + 2", + reference="def run(x): return x + 1", ) solution = Solution( - name="test_add_two_cpu_cached", - definition="test_add_two_cpu", + name="test_add_one_cpu_cached", + definition="test_add_one_cpu_cached", author="test", spec=BuildSpec( language=SupportedLanguages.CUDA,
target_hardware=["cpu"], - entry_point="kernel.cpp::add_two_cpu", + entry_point="kernel.cpp::add_one_cpu", ), - sources=[SourceFile(path="kernel.cpp", content=cpp_source)], + sources=[SourceFile(path="kernel.cpp", content=CPP_SOURCE)], description="CPU caching test", ) # First build - builder = TVMFFIBuilder() + builder = TvmFfiBuilder() time_start = time.monotonic() runnable1 = builder.build(definition, solution) time_end = time.monotonic() @@ -271,53 +268,88 @@ def test_caching_cpu() -> None: output2 = runnable2(x=input_tensor) torch.testing.assert_close(output1, output2, rtol=1e-5, atol=1e-5) - torch.testing.assert_close(output1, input_tensor + 2.0, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(output1, input_tensor + 1.0, rtol=1e-5, atol=1e-5) -def test_call_dest_cpu() -> None: - """Test calling call_dest directly with pre-allocated output tensors.""" - cpp_source = """ -#include -#include -#include +def test_caching_cross_builder() -> None: + """Test that a new builder instance reuses the .so cached by an earlier builder.""" + definition = Definition( + name="test_add_one_cpu_cached", + op_type="test", + description="Test CPU caching", + axes={"n": {"type": "const", "value": 5}}, + constraints=[], + inputs={"x": {"shape": ["n"], "dtype": "float32"}}, + outputs={"output": {"shape": ["n"], "dtype": "float32"}}, + reference="def run(x): return x + 1", + ) -void multiply_by_two(tvm::ffi::TensorView x, tvm::ffi::TensorView output) { - for (int i = 0; i < x.size(0); ++i) { - static_cast<float*>(output.data_ptr())[i] = - static_cast<const float*>(x.data_ptr())[i] * 2.0f; - } -} + solution = Solution( + name="test_add_one_cpu_cached", + definition="test_add_one_cpu_cached", + author="test", + spec=BuildSpec( + language=SupportedLanguages.CUDA, + target_hardware=["cpu"], + entry_point="kernel.cpp::add_one_cpu", + ), + sources=[SourceFile(path="kernel.cpp", content=CPP_SOURCE)], + description="CPU caching test", + ) -TVM_FFI_DLL_EXPORT_TYPED_FUNC(multiply_by_two, multiply_by_two); -""" + # First build + builder1 = TvmFfiBuilder() + time_start = time.monotonic() + runnable1 = builder1.build(definition, solution) + time_end = time.monotonic() + print(f"Time taken to build: {(time_end - time_start) * 1000} ms") + + # Second build should load from cache + builder2 = TvmFfiBuilder() + time_start = time.monotonic() + runnable2 = builder2.build(definition, solution) + time_end = time.monotonic() + print(f"Time taken to load from cache: {(time_end - time_start) * 1000} ms") + + # Both should produce the same result + input_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cpu", dtype=torch.float32) + + output1 = runnable1(x=input_tensor) + output2 = runnable2(x=input_tensor) + + torch.testing.assert_close(output1, output2, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(output1, input_tensor + 1.0, rtol=1e-5, atol=1e-5) + +def test_call_dest_cpu() -> None: + """Test calling call_dest directly with pre-allocated output tensors.""" n = 4 definition = Definition( - name="test_multiply_by_two", + name="test_add_one_cpu", op_type="test", - description="Test multiply by two", + description="Test add one", axes={"n": {"type": "const", "value": n}}, constraints=[], inputs={"x": {"shape": ["n"], "dtype": "float32"}}, outputs={"output": {"shape": ["n"], "dtype": "float32"}}, - reference="def run(x): return x * 2", + reference="def run(x): return x + 1", ) solution = Solution( - name="test_multiply_by_two_impl", - definition="test_multiply_by_two", + name="test_add_one_cpu_impl", + definition="test_add_one_cpu", author="test",
spec=BuildSpec( language=SupportedLanguages.CUDA, target_hardware=["cpu"], - entry_point="kernel.cpp::multiply_by_two", + entry_point="kernel.cpp::add_one_cpu", ), - sources=[SourceFile(path="kernel.cpp", content=cpp_source)], - description="Multiply by two kernel", + sources=[SourceFile(path="kernel.cpp", content=CPP_SOURCE)], + description="Add one kernel", ) # Build - builder = TVMFFIBuilder() + builder = TvmFfiBuilder() runnable = builder.build(definition, solution) # Manually allocate input and output tensors @@ -328,7 +360,7 @@ def test_call_dest_cpu() -> None: runnable.call_dest(x=input_tensor, output=output_tensor) # Verify the output tensor was filled correctly - expected = input_tensor * 2.0 + expected = input_tensor + 1.0 torch.testing.assert_close(output_tensor, expected, rtol=1e-5, atol=1e-5) @@ -367,7 +399,7 @@ def test_invalid_entry_point() -> None: description="Invalid entry point test", ) - builder = TVMFFIBuilder() + builder = TvmFfiBuilder() with pytest.raises(BuildError): builder.build(definition, invalid_solution) @@ -399,10 +431,51 @@ def test_no_sources() -> None: description="No sources test", ) - builder = TVMFFIBuilder() - with pytest.raises(BuildError, match="No sources"): + builder = TvmFfiBuilder() + with pytest.raises(BuildError, match="No CUDA or C\\+\\+ sources"): builder.build(definition, no_sources_solution) +def test_source_in_subdirectory() -> None: + """Test that source files in subdirectories are handled correctly.""" + n = 5 + definition = Definition( + name="test_subdirectory", + op_type="test", + description="Test source in subdirectory", + axes={"n": {"type": "const", "value": n}}, + constraints=[], + inputs={"x": {"shape": ["n"], "dtype": "float32"}}, + outputs={"output": {"shape": ["n"], "dtype": "float32"}}, + reference="def run(x): return x + 1", + ) + + # Place kernel in a subdirectory + solution = Solution( + name="test_subdirectory_impl", + definition="test_subdirectory", + author="test", + spec=BuildSpec( + language=SupportedLanguages.CUDA, + target_hardware=["cpu"], + entry_point="subdir/kernel.cpp::add_one_cpu", + ), + sources=[SourceFile(path="subdir/kernel.cpp", content=CPP_SOURCE)], + description="Test subdirectory handling", + ) + + # Build and run + builder = TvmFfiBuilder() + runnable = builder.build(definition, solution) + + # Test execution + input_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cpu", dtype=torch.float32) + output_tensor = runnable(x=input_tensor) + + # Verify result + expected = input_tensor + 1.0 + torch.testing.assert_close(output_tensor, expected, rtol=1e-5, atol=1e-5) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/data/test_solution.py b/tests/data/test_solution.py index ff8c3a7..8cf4152 100644 --- a/tests/data/test_solution.py +++ b/tests/data/test_solution.py @@ -14,9 +14,6 @@ def test_sourcefile_validation_python(): # Non-string content with pytest.raises(ValueError): SourceFile(path="main.py", content=123) # type: ignore[arg-type] - # Invalid python - with pytest.raises(ValueError): - SourceFile(path="main.py", content="def run(: pass") def test_buildspec_validation(): @@ -34,6 +31,12 @@ def test_buildspec_validation(): target_hardware=["cuda"], entry_point="main.py", # missing :: ) + with pytest.raises(ValueError): + BuildSpec( + language=SupportedLanguages.PYTHON, + target_hardware=["cuda"], + entry_point="main.py::run::add", # too many :: + ) # Invalid target_hardware list and dependencies types with pytest.raises(ValueError): BuildSpec( @@ -73,5 +76,43 @@ def 
test_solution_validation_and_helpers(): Solution(name="missing_entry", definition="def1", author="x", spec=spec, sources=[s2]) +def test_path_traversal_attack(): + """Test that path traversal attacks using '..' are blocked.""" + spec = BuildSpec( + language=SupportedLanguages.CUDA, + target_hardware=["cpu"], + entry_point="../../kernel.cpp::add_one_cpu", + ) + # Should fail at Solution creation time with path traversal error + with pytest.raises( + ValueError, match="Invalid source path \\(parent directory traversal not allowed\\)" + ): + Solution( + name="malicious", + definition="def1", + author="attacker", + spec=spec, + sources=[SourceFile(path="../../kernel.cpp", content="int main() {}")], + ) + + +def test_absolute_path_attack(): + """Test that absolute paths are blocked.""" + spec = BuildSpec( + language=SupportedLanguages.CUDA, + target_hardware=["cpu"], + entry_point="/tmp/kernel.cpp::add_one_cpu", + ) + # Should fail at Solution creation time with absolute path error + with pytest.raises(ValueError, match="absolute path not allowed"): + Solution( + name="malicious", + definition="def1", + author="attacker", + spec=spec, + sources=[SourceFile(path="/tmp/kernel.cpp", content="int main() {}")], + ) + + if __name__ == "__main__": pytest.main(sys.argv)
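
Implementation notes. The rebuild path in TvmFfiBuilder._build follows a double-checked locking pattern: the cache is checked once without the lock, then re-checked after acquiring it, and only compiled if still stale. A minimal sketch of that control flow, where is_cache_valid and compile_sources are hypothetical stand-ins for _check_sources and tvm_ffi.cpp.build:

    from pathlib import Path

    from tvm_ffi.utils import FileLock

    def build_cached(build_path: Path, lib_name: str, lock_name: str) -> Path:
        lib = build_path / lib_name
        # Fast path: no lock needed when a valid cached library already exists.
        if is_cache_valid(build_path, lib):
            return lib
        build_path.mkdir(parents=True, exist_ok=True)
        with FileLock(build_path / lock_name):
            # Re-check after acquiring the lock: another process may have
            # finished compiling while this one was waiting.
            if not is_cache_valid(build_path, lib):
                compile_sources(build_path, lib)
        return lib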
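
The new SourceFile validator rejects unsafe paths at model-construction time, before any builder touches the filesystem. A small illustration of the expected behavior (a sketch based on the validator above, not code copied from the test suite):

    import pytest

    from flashinfer_bench.data import SourceFile

    # Relative subdirectory paths are accepted.
    SourceFile(path="src/kernel.cu", content="// ok")

    with pytest.raises(ValueError):  # absolute path: rejected
        SourceFile(path="/tmp/kernel.cu", content="// rejected")

    with pytest.raises(ValueError):  # parent directory traversal: rejected
        SourceFile(path="../kernel.cu", content="// rejected")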
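
Likewise, BuildSpec now requires exactly one "::" in entry_point, separating the entry file from the exported symbol; for example:

    from flashinfer_bench.data import BuildSpec, SupportedLanguages

    # Valid: one "::" between the entry file and the symbol.
    spec = BuildSpec(
        language=SupportedLanguages.CUDA,
        target_hardware=["cpu"],
        entry_point="kernel.cpp::add_one_cpu",
    )

    # Invalid, raises ValueError at construction time:
    # "kernel.cpp" (no "::") or "kernel.cpp::run::add" (two "::").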