7 changes: 7 additions & 0 deletions triton_viz/core/config.py
@@ -26,6 +26,13 @@ def reset(self) -> None:
os.getenv("REPORT_GRID_EXECUTION_PROGRESS", "0") == "1"
) # verify using setter

        # --- Virtual memory flag ---
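        # When enabled, kernel arguments are wrapped as metadata-only
        # FakeTensors instead of being copied to the host (see core/patch.py).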
        self._virtual_memory = os.getenv("TRITON_VIZ_VIRTUAL_MEMORY", "0") == "1"

    @property
    def virtual_memory(self) -> bool:
        return self._virtual_memory

    # ---------- disable_sanitizer ----------
    @property
    def disable_sanitizer(self) -> bool:
75 changes: 73 additions & 2 deletions triton_viz/core/patch.py
@@ -5,7 +5,9 @@
from tqdm import tqdm

from .config import config as cfg
from dataclasses import dataclass
from .callbacks import OpCallbacks, ForLoopCallbacks

from .data import (
    Op,
    RawLoad,
@@ -53,6 +55,8 @@
)
from triton.runtime.interpreter import _patch_lang as triton_patch_lang
from triton.runtime.interpreter import ASTTransformer as _OrigASTTransformer
from triton.runtime.interpreter import _tuple_create, _unwrap_tensor, _rewrap_tensor
from triton.tools.tensor_descriptor import TensorDescriptor
from triton.runtime import JITFunction

op_list = [
@@ -476,6 +480,69 @@ def unpatch_lang():
    importlib.reload(tl)


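# A metadata-only stand-in for a torch.Tensor: it records the pointer, dtype,
# shape, strides, contiguity, and element size, so kernels can be interpreted
# without copying any device data to the host.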
@dataclass(frozen=True)
class FakeTensor:
    _data_ptr: int
    dtype: object  # original tensor's dtype (e.g. torch.dtype)
    shape: tuple[int, ...] = ()
    _stride: tuple[int, ...] = ()
    _is_contiguous: bool = True
    _element_size: int = 1

    def data_ptr(self) -> int:
        return self._data_ptr

    def stride(self) -> tuple[int, ...]:
        return self._stride

    def is_contiguous(self) -> bool:
        return self._is_contiguous

    def numel(self) -> int:
        size = 1
        for dim in self.shape:
            size *= dim
        return size

    def element_size(self) -> int:
        return self._element_size


def _init_args_hst(args_dev, kwargs):
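    """Build host-side arguments without materializing device data.

    Tuples and TensorDescriptors are rebuilt recursively; any argument
    exposing data_ptr is replaced by a metadata-only FakeTensor and passed
    back through _rewrap_tensor to preserve the original wrapping.
    """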
    def _to_cpu(arg):
        if isinstance(arg, tuple):
            return _tuple_create(arg, map(_to_cpu, arg))
        elif isinstance(arg, TensorDescriptor):
            return TensorDescriptor(
                _to_cpu(arg.base),
                arg.shape,
                arg.strides,
                arg.block_shape,
            )
        elif not hasattr(arg, "data_ptr"):
            return arg

        unwrapped_arg = _unwrap_tensor(arg)
        cpu_arg = FakeTensor(
            _data_ptr=unwrapped_arg.data_ptr(),
            dtype=unwrapped_arg.dtype,
            shape=unwrapped_arg.shape,
            _stride=unwrapped_arg.stride(),
            _is_contiguous=unwrapped_arg.is_contiguous(),
            _element_size=unwrapped_arg.element_size(),
        )
        cpu_arg = _rewrap_tensor(cpu_arg, original_tensor=arg)
        return cpu_arg

    args_hst = [_to_cpu(arg) for arg in args_dev]

    # Process keyword arguments
    kwargs_hst = {}
    for key, value in kwargs.items():
        kwargs_hst[key] = _to_cpu(value)
    return args_hst, kwargs_hst


def _grid_executor_call(self, *args_dev, **kwargs):
    if kwargs.pop("warmup", False):
        return
@@ -517,7 +584,10 @@ def run_grid_loops(grid):
    }
    client_manager = kwargs.pop("client_manager")
    kwargs.pop("jit_fn")
-    args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs)
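+    # Virtual-memory mode: hand the kernel metadata-only FakeTensors instead
+    # of host copies of the device tensors.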
+    if cfg.virtual_memory:
+        args_hst, kwargs_hst = _init_args_hst(args_dev, kwargs)
+    else:
+        args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs)
    # Prepare call arguments
    args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst)
    call_args = {}
@@ -547,7 +617,8 @@ def run_grid_loops(grid):
    name = self.fn.__name__
    print(f"Triton-Viz: execution time for {name}: {elapsed_time * 1000:.3f} ms")
    # Copy arguments back to propagate side-effects
-    self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst)
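+    # In virtual-memory mode nothing was copied to the host, so there is
+    # nothing to restore.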
+    if not cfg.virtual_memory:
+        self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst)


def _jit_function_call(self, *args, **kwargs):