[FEATURE] NKI interpreter #206
base: main
Conversation
scripts/print_tracebacks.py
Outdated
@@ -0,0 +1,58 @@
import os
What is this for?
Sorry, this script was only used for debugging. I'll remove it.
triton_viz/clients/tracer/tracer.py
Outdated
)

- def pre_store_callback(ptr, value, mask, cache_modifier, eviction_policy):
+ def pre_store_callback(ptr, value, mask, *ignore_args, **ignore_kwargs):
Why do we need these args and kwargs here?
They're no longer necessary (they were an artifact of when I tried to use that callback to handle masked loads/stores). But in general, the *args/**kwargs are in the callbacks because:
- Triton-viz callbacks take in all arguments passed to the functions they wrap (i.e. adding a callback to fn(a, ..., z) means the callback needs to take in args a through z).
- Masked loads/stores have a different function signature compared to Triton loads/stores.
- Thus, using the same callback for different patched operations would cause issues (this is important for the post_dot_callback, which both Triton and NKI use); see the sketch after this list.
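For illustration, a minimal runnable sketch of that pattern, with hypothetical function and parameter names (not the real triton-viz, Triton, or NKI signatures): one variadic callback attached to two patched ops whose wrapped signatures differ.

```python
# Hypothetical sketch: names and signatures below are illustrative only.
def post_dot_callback(a, b, result, *ignore_args, **ignore_kwargs):
    # Only the first three arguments are consumed; backend-specific extras
    # (accumulators, flags, cache hints, ...) are absorbed by *ignore_args/**ignore_kwargs.
    print(f"dot produced: {result}")

def patched_triton_dot(a, b, acc, some_triton_only_flag=None):
    result = sum(x * y for x, y in zip(a, b)) + acc   # stand-in for the real dot
    post_dot_callback(a, b, result, acc, some_triton_only_flag)
    return result

def patched_nki_matmul(x, y, some_nki_only_flag=False):
    result = sum(p * q for p, q in zip(x, y))         # stand-in for an NKI matmul
    post_dot_callback(x, y, result, some_nki_only_flag=some_nki_only_flag)
    return result

patched_triton_dot([1, 2], [3, 4], acc=0)
patched_nki_matmul([1, 2], [3, 4])
```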
OK, I would prefer to just name them args and kwargs.
I also think the client manager needs to standardize ops across different backends before passing them to clients
We already do that? In patch.py, OPERATION_REGISTRY maps backends and standardized ops to the interpreter methods to patch, such as:

OPERATION_REGISTRY['triton']['original_ops'][Dot]  # equals interpreter_builder.create_dot
OPERATION_REGISTRY['nki']['original_ops'][Dot]     # equals nki_builder.matmul

When patching an op like interpreter_builder.create_dot or nki_builder.matmul, they're both converted to a PatchOp(op_type=Dot) object, so it's standardized in that way.
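For concreteness, a simplified, runnable sketch of that lookup shape, using stand-in builders and tag classes (the real OPERATION_REGISTRY in patch.py has more entries and the real interpreter/NKI builders):

```python
# Hypothetical, simplified stand-ins; not the real triton-viz objects.
from dataclasses import dataclass

class Dot: ...  # standardized op type used as a dispatch tag

class FakeInterpreterBuilder:
    @staticmethod
    def create_dot(a, b): ...   # stand-in for interpreter_builder.create_dot

class FakeNKIBuilder:
    @staticmethod
    def matmul(x, y): ...       # stand-in for nki_builder.matmul

OPERATION_REGISTRY = {
    "triton": {"original_ops": {Dot: FakeInterpreterBuilder.create_dot}},
    "nki":    {"original_ops": {Dot: FakeNKIBuilder.matmul}},
}

@dataclass
class PatchOp:
    op_type: type

# Both backends' dot implementations map to the same standardized op type,
# so downstream patching code only ever sees PatchOp(op_type=Dot).
for backend, entry in OPERATION_REGISTRY.items():
    for op_type in entry["original_ops"]:
        assert PatchOp(op_type=op_type).op_type is Dot
```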
Do you want the triton-viz callbacks to not allow args/kwargs, with the client manager responsible for supplying the right args to the callbacks? Something conceptually like this:
# some_placeholder_client.py
def pre_load_callback(a, b, c):  # no args/kwargs; correct arguments selected at client manager level
    ...

# client.py
def triton_patched_load(a, b, c, ...):  # tl.load parameters
    pre_load_callback(a, b, c)
    tl.load(a, b, c, ...)
    ...

def nki_patched_load(b, a, c, ...):  # nl.load parameters; may be different values or in a different order from tl.load parameters
    pre_load_callback(a, b, c)
    nl.load(b, a, c, ...)
    ...

if backend == 'triton':
    patched_load = triton_patched_load
else:
    patched_load = nki_patched_load
> Do you want the triton-viz callbacks to not allow args/kwargs, with the client manager responsible for supplying the right args to the callbacks?

Yes. The code can live in client.py, patch.py, or another file, but not in the concrete client. Each client itself only sees "normalized" callback parameters (e.g., a, b, c in your case).
def _extract_user_frames() -> list[traceback.FrameSummary]:
    stack: list[traceback.FrameSummary] = list(traceback.extract_stack())
    # drop current frames (this function and callers)
    stack = stack[:-2]
I feel like this duplicates the __post_init__(self) function in data.py.
triton_viz/core/data.py
Outdated
@dataclass
class MaskedLoad(Op):
I probably missed something, but I think RawStore or RawLoad shouldn't contain a mask, while MaskedStore or MaskedLoad should. Also, why do we have three different load/store ops?
Yeah, this was pretty hacky; I wanted to differentiate when to use the Triton load/masked load (and store) callbacks when patching:
triton-viz/triton_viz/clients/tracer/tracer.py
Lines 245 to 248 in 1910b78
elif op_type is Load:
    return OpCallbacks(before_callback=pre_load_callback)
elif op_type is MaskedLoad:
    return OpCallbacks(before_callback=pre_masked_load_callback)
I was thinking of providing the DSL as an input to this function to allow something like this:
elif op_type is Load:
    if backend == 'triton':
        return OpCallbacks(before_callback=pre_load_callback)
    elif backend == 'nki':
        return OpCallbacks(before_callback=pre_masked_load_callback)

But it felt a bit weird that we would have to deal with backend-specific logic at the triton-viz level rather than at the language-level patching/AST rewriting stage.
Hence I made empty Op classes for MaskedLoad and MaskedStore
> Yeah, this was pretty hacky; I wanted to differentiate when to use the Triton load/masked load (and store) callbacks when patching:

Can't both ops be modeled as a masked load? When you have a non-masked load/store, the mask field would just be None?
In reality, the Load/MaskedLoad class naming is a misnomer; more accurate names would probably be TritonLoad/NKILoad, or PointerPlusOffsetLoad/NumpyIndexedLoad if you want to signal that these implementations can be reused for different DSLs.
So yes, Load and MaskedLoad are both technically masked loads, but for different backend implementations.
Though, you could model the triton and NKI load in the same callback. But you'd still need to handle the different input arguments provided in Triton vs. NKI. Something like:
def pre_load_callback(array, offsets=None, mask=None):
    if offsets is None:  # only true for triton tensors
        offsets = array.data - ptr.data_ptr()
    # rest of logic

The above solution seems like it would be a bit confusing to read, though.
> Though, you could model the triton and NKI load in the same callback. But you'd still need to handle the different input arguments provided in Triton vs. NKI.
Like I mentioned before, we can handle the differences before delivering the data to the concrete analysis client. In that case, do we still need to separate TritonLoad and NKILoad? Probably not?
if "backend" == "nki":
a, b, c = handle_nki_load
elif "backend" == "triton":
a, b, c = handle_triton_load
pre_load_callback(a, b, c)
Seems tricky even with normalized callbacks. NKI and Triton loads have different arguments, so it'd be more like:
if backend == 'nki':
    array, keys, mask = handle_nki_load
elif backend == 'triton':
    array_plus_offsets, mask = handle_triton_load

I haven't been able to find a perfect way to make both backends use the same args.
- I can't extract array from array_plus_offsets by itself (the tracer does this by storing tensor pointers collected from arg callbacks, but that isn't something we can do in a stateless function).
- We could have handle_triton_load return None for the keys arg so the normalized args have the same length. Then we'd need to handle keys being None in the client callback, but doing this means clients implicitly handle DSL-specific logic, which I don't think we want? Though this is probably the easiest option to implement if we're OK with that (see the sketch below).
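As a concrete illustration of that second option, here is a minimal, hypothetical sketch (the handler and parameter names are made up, not actual client-manager code): the manager normalizes both backends into (array, keys, mask), filling keys with None on the Triton side, and the client callback still has to branch on keys being None.

```python
# Hypothetical sketch of option 2; none of these names are real triton-viz APIs.
def handle_triton_load(array_plus_offsets, mask):
    # Triton side: no slice keys are available, so pad with None.
    return array_plus_offsets, None, mask

def handle_nki_load(array, keys, mask):
    # NKI side: already in (array, keys, mask) form.
    return array, keys, mask

def pre_load_callback(array, keys, mask):
    # Concrete client callback: still needs DSL-specific branching on keys.
    if keys is None:
        return "triton-style pointer+offset load"
    return "nki-style numpy-indexed load"

def dispatch_pre_load(backend, *raw_args):
    # Client-manager level: normalize before calling the client callback.
    if backend == "nki":
        array, keys, mask = handle_nki_load(*raw_args)
    else:
        array, keys, mask = handle_triton_load(*raw_args)
    return pre_load_callback(array, keys, mask)

print(dispatch_pre_load("triton", [10, 11, 12], [True, True, False]))
print(dispatch_pre_load("nki", [10, 11, 12], (slice(0, 2),), [True, True, False]))
```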
What does "keys" mean here?
keys is the actual slice itself, i.e. for x[arange(B), 3] the keys would be (arange(B), 3)
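For anyone following along, a small self-contained example (plain numpy, not triton-viz code) showing that the keys here are exactly the tuple that __getitem__ receives:

```python
import numpy as np

class RecordingArray:
    """Wraps an ndarray and records the keys passed to __getitem__."""
    def __init__(self, data):
        self.data = data
    def __getitem__(self, keys):
        print("keys:", keys)       # e.g. (array([0, 1, 2, 3]), 3)
        return self.data[keys]

B = 4
x = RecordingArray(np.arange(B * 8).reshape(B, 8))
x[np.arange(B), 3]                 # keys == (np.arange(B), 3)
```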
I think it's kind of general. Let's keep it and make it optional (None when not present).
triton_viz/visualizer/interface.py
Outdated
t = op_data["global_tensor"]
# Normalize to numpy array for robust indexing
if hasattr(t, "cpu"):
I think the data needs to be normalized outside this file.
@latentCall145 I just resolved the conflicts. Please pull before committing.
…ck args for all clients)
…ols/triton-viz into nki/patch-dsl
Hi @mark14wu, can you also review the PR?

Yes. I have dropped some comments, but haven't finished reviewing all the files.

Some tests in CI are not working. Please fix.
Sorry for the big diff :(
Adds the NKI frontend to Triton-Viz:
- nki.py
- nki_extract_slice.py
- nki_masked_load.py
- backend (one of "nki" or "triton") as an argument for patching (modifies patch.py, client.py)
- data.py, tracer.py
- trace.py

Added Tests

- nki-examples/: demo programs to try out the visualizer (currently, only nki-examples/matmul.py is supported)
- tests/test_nki.py: basic NDArray slicing
- tests/test_masked_load.py: make sure masked load/store works correctly