
Conversation

@latentCall145 (Collaborator):

Sorry for the big diff :(

Adds the NKI frontend to Triton-Viz:

  • Added an NKI interpreter to apply language-level patching (e.g., it allows NKI kernels to be executed line by line with NumPy; modifies nki.py)
  • Added AST rewriting to convert NKI loads/stores into masked loads/stores (modifies nki_extract_slice.py)
  • Added a masked load/store implementation (modifies nki_masked_load.py; a sketch follows the test list below)
  • Added a backend argument (one of "nki" or "triton") for patching (modifies patch.py, client.py)
  • Added tracer callbacks for masked loads/stores (modifies data.py, tracer.py)
  • Split the Trace objects into separate Triton and NKI variants (modifies trace.py)
  • Added a lot of visualization work (credit: @gujialiang123)

Added Tests

  • nki-examples/: demo programs to try out the visualizer (currently, only nki-examples/matmul.py is supported)
  • tests/test_nki.py: basic NDArray slicing
  • tests/test_masked_load.py: verifies that masked load/store works correctly
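
For reference, here is a minimal sketch of what a NumPy-backed masked load can look like (illustrative names only, not the actual nki_masked_load.py API):

import numpy as np

def masked_load(array, keys, mask=None, other=0):
    # Plain load when no mask is given.
    if mask is None:
        return array[keys]
    # Read the slice, then replace masked-off lanes with `other`.
    # (Assumes `keys` stays in bounds wherever mask is True; the real
    # implementation also has to guard the indexing itself.)
    return np.where(mask, array[keys], other)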

@@ -0,0 +1,58 @@
import os
Member:

What is this for?

Collaborator:

Sorry, this script was used for debugging. I'll remove it.

)

def pre_store_callback(ptr, value, mask, cache_modifier, eviction_policy):
def pre_store_callback(ptr, value, mask, *ignore_args, **ignore_kwargs):
Member:

Why do we need these args and kwargs here?

@latentCall145 (Collaborator, Author), Oct 30, 2025:

They're no longer necessary (they were an artifact of when I tried to use that callback to handle masked loads/stores). But in general, the *args/**kwargs are in the callbacks because:

  1. Triton-Viz callbacks take in all of the arguments passed to the functions they wrap (i.e., adding a callback to fn(a, ..., z) means the callback needs to take in args a through z; see the sketch below).
  2. Masked loads/stores have a different function signature than Triton loads/stores.
  3. Thus, using the same callback for different patched operations would cause issues (this is important for post_dot_callback, which both Triton and NKI use).
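
To illustrate point 1, a minimal sketch of the patching pattern (patch_op is a hypothetical helper, not the actual patch.py code): the wrapper forwards everything it receives, so the callback must accept the wrapped function's full signature.

def patch_op(fn, before_callback):
    def wrapper(*args, **kwargs):
        # Forward every argument of the patched function to the callback.
        before_callback(*args, **kwargs)
        return fn(*args, **kwargs)
    return wrapper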

Member:

OK, but I would prefer to just name them args and kwargs.

@Jokeren (Member), Oct 30, 2025:

I also think the client manager needs to standardize ops across different backends before passing them to clients

@latentCall145 (Collaborator, Author):

We already do that? In patch.py, OPERATION_REGISTRY maps backends and standardized ops to the interpreter methods to patch, e.g.:

OPERATION_REGISTRY['triton']['original_ops'][Dot] # equals interpreter_builder.create_dot
OPERATION_REGISTRY['nki']['original_ops'][Dot] # equals nki_builder.matmul

When patching an op like interpreter_builder.create_dot or nki_builder.matmul, both are converted to a PatchOp(op_type=Dot) object, so they're standardized in that sense.

Do you want the triton-viz callbacks to disallow args/kwargs, with the client manager responsible for supplying the right args to the callbacks? Something conceptually like this:

# some_placeholder_client.py
def pre_load_callback(a, b, c): # no args/kwargs, correct arguments selected at client manager level
    # ...

# client.py
def triton_patched_load(a, b, c, ...): # tl.load parameters
    pre_load_callback(a, b, c)
    tl.load(a, b, c, ...)
    ...
def nki_patched_load(b, a, c, ...): # nl.load parameters, may be different values or in different order from tl.load parameters
    pre_load_callback(a, b, c)
    nl.load(b, a, c, ...)
    ...

if backend == 'triton':
    patched_load = triton_patched_load
else:
    patched_load = nki_patched_load

@Jokeren (Member), Nov 2, 2025:

Do you want the triton-viz callbacks to disallow args/kwargs, with the client manager responsible for supplying the right args to the callbacks?

Yes. The code can live in client.py, patch.py, or another file, but not in the concrete client. Each client itself only sees "normalized" callback parameters (e.g., a, b, c in your case).

def _extract_user_frames() -> list[traceback.FrameSummary]:
    stack: list[traceback.FrameSummary] = list(traceback.extract_stack())
    # drop current frames (this function and callers)
    stack = stack[:-2]
Member:

I feel like this duplicates the __post_init__(self) function in data.py.



@dataclass
class MaskedLoad(Op):
Member:

I probably missed something: I think RawStore/RawLoad shouldn't contain a mask, while MaskedStore/MaskedLoad should. Also, why do we have three different load/store ops?

@latentCall145 (Collaborator, Author):

Yeah, this was pretty hacky; I wanted to differentiate when to use the Triton load vs. masked load (and store) callbacks when patching:

elif op_type is Load:
    return OpCallbacks(before_callback=pre_load_callback)
elif op_type is MaskedLoad:
    return OpCallbacks(before_callback=pre_masked_load_callback)

I was thinking of providing the DSL as an input to this function to allow something like this:

elif op_type is Load:
    if backend == 'triton':
        return OpCallbacks(before_callback=pre_load_callback)
    elif backend == 'nki':
        return OpCallbacks(before_callback=pre_masked_load_callback)

but it felt a bit weird that we would have to deal with backend-specific logic at the triton-viz level rather than at the language-level patching/AST rewriting stage.

Hence, I made empty Op classes for MaskedLoad and MaskedStore, roughly as sketched below.
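
A sketch of the marker-class approach (Op here stands in for the actual base class defined in data.py):

from dataclasses import dataclass

@dataclass
class Op:  # stand-in for the real base class in data.py
    pass

@dataclass
class MaskedLoad(Op):
    """Empty marker type: routes NKI loads to pre_masked_load_callback."""

@dataclass
class MaskedStore(Op):
    """Empty marker type: routes NKI stores to pre_masked_store_callback."""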

Member:

Yeah, this was pretty hacky; I wanted to differentiate when to use the Triton load vs. masked load (and store) callbacks when patching:

Can't both ops be modeled as a masked load? When you have a non-masked load/store, the mask field would just be None?

@latentCall145 (Collaborator, Author), Oct 31, 2025:

In reality, the Load/MaskedLoad naming is a misnomer; more accurate names would probably be TritonLoad/NKILoad, or PointerPlusOffsetLoad/NumpyIndexedLoad if you want to signal that these implementations can be reused for different DSLs.

So yes, Load/MaskedLoad are both technically masked loads but for different backend implementations.

Though you could model the Triton and NKI loads in the same callback, you'd still need to handle the different input arguments Triton and NKI provide. Something like:

def pre_load_callback(array, offsets=None, mask=None):
    if offsets is None:  # only true for triton tensors
        offsets = array.data - array.data_ptr()
    # rest of logic

The above solution seems like it would be a bit confusing to read, though.

Member:

Though you could model the Triton and NKI loads in the same callback, you'd still need to handle the different input arguments Triton and NKI provide.

Like I mentioned before, we can handle the differences before delivering the data to the concrete analysis client. In that case, do we still need to separate TritonLoad and NKILoad? Probably not?

if "backend" == "nki":
  a, b, c = handle_nki_load
elif "backend" == "triton":
  a, b, c = handle_triton_load

pre_load_callback(a, b, c)

@latentCall145 (Collaborator, Author), Nov 3, 2025:

Seems tricky even with normalized callbacks. NKI and Triton loads have different arguments, so it'd be more like:

if backend == 'nki':
    array, keys, mask = handle_nki_load(...)
elif backend == 'triton':
    array_plus_offsets, mask = handle_triton_load(...)

I haven't been able to find a perfect way to make both backends use the same args.

  • I can't extract array from array_plus_offsets by itself (the tracer does this by storing tensor pointers collected from arg callbacks, but that isn't something we can do in a stateless function).
  • We could have handle_triton_load return None for the keys arg so the normalized args have the same length. Then we'd need to handle keys is None in the client callback, but that means clients implicitly handle DSL-specific logic, which I don't think we want? Though this is probably the easiest option to implement if we're OK with that.

Member:

What does "keys" mean here?

@latentCall145 (Collaborator, Author):

keys is the actual slice itself, i.e., for x[arange(B), 3] the keys would be (arange(B), 3).

Member:

I think it's kind of general. Let's keep it and make it optional (defaulting to None).
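
So the normalized callback could end up looking something like this (a sketch with illustrative names):

from typing import Optional

def pre_load_callback(array, keys: Optional[tuple] = None, mask=None):
    # keys is None for Triton's pointer-plus-offset loads; for NKI it
    # carries the index tuple, e.g. (arange(B), 3).
    ...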


t = op_data["global_tensor"]
# Normalize to numpy array for robust indexing
if hasattr(t, "cpu"):
Member:

I think the data needs to be normalized outside this file.

@mark14wu (Collaborator), Nov 4, 2025:

@latentCall145 I just resolved the conflicts. Please pull before committing.

@Jokeren (Member), Nov 6, 2025:

Hi @mark14wu, can you also review the PR?

@mark14wu (Collaborator), Nov 7, 2025:

Hi @mark14wu, can you also review the PR?

Yes. I have dropped some comments, but haven't finished reviewing all files.

@mark14wu (Collaborator), Nov 7, 2025:

Some tests in CI are not working.

Traceback:
/opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
tests/test_wrapper.py:8: in <module>
    from triton_viz.wrapper import SANITIZER_COMMAND, PROFILER_COMMAND
triton_viz/__init__.py:1: in <module>
    from .core import trace, clear, config
triton_viz/core/__init__.py:1: in <module>
    from .trace import trace, clear
triton_viz/core/trace.py:7: in <module>
    from ..clients import Sanitizer, Profiler, Tracer
triton_viz/clients/__init__.py:1: in <module>
    from .profiler.profiler import Profiler
triton_viz/clients/profiler/profiler.py:1: in <module>
    from ...core.client import Client
triton_viz/core/client.py:8: in <module>
    from .patch import (
triton_viz/core/patch.py:61: in <module>
    from triton_viz.core.nki import nki_builder
triton_viz/core/nki.py:3: in <module>
    import neuronxcc.nki.language as nl
E   ModuleNotFoundError: No module named 'neuronxcc'

Please fix.
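
One common way to address this (a sketch; exact placement is up to the author) is to guard the module-level import so neuronxcc stays an optional dependency, e.g. in triton_viz/core/nki.py:

try:
    import neuronxcc.nki.language as nl
    HAS_NKI = True
except ImportError:
    # neuronxcc isn't installed (e.g., in CI), so disable the NKI
    # backend instead of failing at import time.
    nl = None
    HAS_NKI = False

patch.py could then register the NKI entries in OPERATION_REGISTRY only when HAS_NKI is True.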
