Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 263 additions & 1 deletion src/kernels/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
from pathlib import Path

from huggingface_hub import create_repo, upload_folder, create_branch
from huggingface_hub import create_branch, create_repo, upload_folder

from kernels.compat import tomllib
from kernels.lockfile import KernelLock, get_kernel_locks
Expand All @@ -15,6 +15,139 @@

BUILD_VARIANT_REGEX = re.compile(r"^(torch\d+\d+|torch-)")

# Default templates for kernel project initialization
DEFAULT_FLAKE_NIX = """\
{
description = "Flake for %(kernel_name)s kernel";
inputs = {
kernel-builder.url = "github:huggingface/kernel-builder";
};
outputs = { self, kernel-builder }:
kernel-builder.lib.genFlakeOutputs {
path = ./.;
rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
};
}
"""

DEFAULT_BUILD_TOML = """\
[general]
name = "%(kernel_name)s"
backends = ["cuda", "rocm", "metal", "xpu"]

[torch]
src = [
"torch-ext/torch_binding.cpp",
"torch-ext/torch_binding.h",
]

[kernel.activation_cuda]
backend = "cuda"
# cuda-capabilities = ["9.0", "10.0", "12.0"] # if not specified, all capabilities will be used
depends = ["torch"]
src = [
"%(kernel_name)s_cuda/kernel.cu",
"%(kernel_name)s_cuda/kernel.h"
]

[kernel.activation_rocm]
backend = "rocm"
# rocm-archs = ["gfx906", "gfx908", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx1030", "gfx1100", "gfx1101"] # if not specified, all architectures will be used
depends = ["torch"]
src = [
"%(kernel_name)s_cuda/kernel.cu",
"%(kernel_name)s_cuda/kernel.h",
]

[kernel.activation_xpu]
backend = "xpu"
depends = ["torch"]
src = [
"%(kernel_name)s_xpu/kernel.cpp",
"%(kernel_name)s_xpu/kernel.hpp",
]

[kernel.activation_metal]
backend = "metal"
depends = ["torch"]
src = [
"%(kernel_name)s_metal/kernel.mm",
"%(kernel_name)s_metal/kernel.metal",
]
"""

DEFAULT_GITATTRIBUTES = """\
*.so filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
"""

DEFAULT_README = """\
# %(kernel_name)s

A custom kernel for PyTorch.

## Installation

```bash
pip install kernels
```

## Usage

```python
from kernels import get_kernel

kernel = get_kernel("%(repo_id)s")


## License

Apache-2.0
"""

DEFAULT_INIT_PY = """\
# %(kernel_name)s kernel
# This file exports the kernel's public API

import torch
from ._ops import ops

def exported_kernel_function(x: torch.Tensor) -> torch.Tensor:
return ops.kernel_function(x)

__all__ = ["exported_kernel_function"]
"""

DEFAULT_TORCH_BINDING_CPP = """\
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("kernel_function(Tensor input) -> (Tensor)");
#if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
ops.impl("kernel_function", torch::kCUDA, &kernel_function);
#elif defined(METAL_KERNEL)
ops.impl("kernel_function", torch::kMPS, kernel_function);
#elif defined(XPU_KERNEL)
ops.impl("kernel_function", torch::kXPU, &kernel_function);
#endif
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

"""

DEFAULT_TORCH_BINDING_H = """\
#pragma once

#include <torch/torch.h>

torch::Tensor kernel_function(torch::Tensor input);
"""


def main():
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -113,6 +246,31 @@ def main():
)
)

# Add init subcommand parser
init_parser = subparsers.add_parser(
"init",
help="Initialize a new kernel project structure",
)
init_parser.add_argument(
"kernel_name",
type=str,
help="Name of the kernel (e.g., 'my-kernel')",
)
init_parser.add_argument(
"--output-dir",
type=Path,
default=None,
help="Output directory for the kernel project (defaults to current directory)",
)
init_parser.add_argument(
"--repo-id",
type=str,
default=None,
help="Repository ID for the kernel (e.g., 'kernels-community/my-kernel')",
)

init_parser.set_defaults(func=init_kernel_project)

args = parser.parse_args()
args.func(args)

Expand Down Expand Up @@ -237,3 +395,107 @@ def check_kernel(
repo_id=repo_id,
revision=revision,
)


def init_kernel_project(args):
"""Initialize a new kernel project with the standard structure."""
kernel_name = args.kernel_name
if "/" in kernel_name or "\\" in kernel_name:
raise ValueError(
"Kernel name cannot contain path separators, to specify an output directory use the --output-dir argument"
)
# Normalize kernel name (replace hyphens with underscores for Python compatibility)
kernel_name_normalized = kernel_name.replace("-", "_")

# Determine output directory
if args.output_dir is not None:
output_dir = (Path(args.output_dir) / kernel_name).resolve()
else:
output_dir = Path.cwd() / kernel_name

# Determine repo_id
repo_id = args.repo_id if args.repo_id else f"your-username/{kernel_name}"

# Check if directory already exists
if output_dir.exists() and any(output_dir.iterdir()):
print(
f"Error: Directory '{output_dir}' already exists and is not empty.",
file=sys.stderr,
)
sys.exit(1)

# Create directory structure
dirs_to_create = [
output_dir,
output_dir / f"{kernel_name_normalized}_cuda",
output_dir / f"{kernel_name_normalized}_rocm",
output_dir / f"{kernel_name_normalized}_metal",
output_dir / f"{kernel_name_normalized}_xpu",
output_dir / "torch-ext",
output_dir / "torch-ext" / kernel_name_normalized,
]

for dir_path in dirs_to_create:
dir_path.mkdir(parents=True, exist_ok=True)

# Template substitution values
template_values = {
"kernel_name": kernel_name_normalized,
"repo_id": repo_id,
}

# Create files
files_to_create = {
"flake.nix": DEFAULT_FLAKE_NIX % template_values,
"build.toml": DEFAULT_BUILD_TOML % template_values,
".gitattributes": DEFAULT_GITATTRIBUTES,
"README.md": DEFAULT_README % template_values,
f"torch-ext/{kernel_name_normalized}/__init__.py": DEFAULT_INIT_PY
% template_values,
"torch-ext/torch_binding.cpp": DEFAULT_TORCH_BINDING_CPP % template_values,
"torch-ext/torch_binding.h": DEFAULT_TORCH_BINDING_H % template_values,
}

for file_path, content in files_to_create.items():
full_path = output_dir / file_path
full_path.parent.mkdir(parents=True, exist_ok=True)
with open(full_path, "w") as f:
f.write(content)

# Print success message
print(
f"✅ Kernel project '{kernel_name}' initialized successfully at: {output_dir}"
)
print()
print("Project structure:")
_print_tree(output_dir, prefix="")
print()
print("Next steps:")
print(f" 1. cd {output_dir}")
print(f" 2. Add your kernel implementation in {kernel_name_normalized}/")
print(
" 3. Update torch-ext/{kernel_name}/__init__.py to export your functions".format(
kernel_name=kernel_name_normalized
)
)
print(" 4. Build with: nix run .#build-and-copy ")
print(f" 5. Upload with: kernels upload . --repo-id {repo_id}")


def _print_tree(directory: Path, prefix: str = ""):
"""Print a directory tree structure."""
entries = sorted(directory.iterdir(), key=lambda x: (x.is_file(), x.name))
entries = [
e
for e in entries
if not e.name.startswith(".git") or e.name == ".gitattributes"
]

for i, entry in enumerate(entries):
is_last = i == len(entries) - 1
connector = "└── " if is_last else "├── "
print(f"{prefix}{connector}{entry.name}")

if entry.is_dir():
extension = " " if is_last else "│ "
_print_tree(entry, prefix + extension)
Loading