-
Notifications
You must be signed in to change notification settings - Fork 21
[DEV][Sanitizer]Implement max and min reduce operations in _to_z3 #208
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[DEV][Sanitizer]Implement max and min reduce operations in _to_z3 #208
Conversation
Co-authored-by: Jokeren <[email protected]>
Co-authored-by: Jokeren <[email protected]>
Co-authored-by: Jokeren <[email protected]>
max in _to_z3…e-4c8b-87e3-dc9ecbed582f
…e-4c8b-87e3-dc9ecbed582f
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| if self.op == "max": | ||
| raise NotImplementedError("_to_z3 of max is not implemented yet") | ||
| arr, self._constraints = self.input._to_z3() | ||
| if not arr: | ||
| raise ValueError("Cannot compute max of empty array") | ||
| self._z3 = reduce(lambda a, b: If(a >= b, a, b), arr) | ||
|
|
||
| if self.op == "min": | ||
| raise NotImplementedError("_to_z3 of min is not implemented yet") | ||
| arr, self._constraints = self.input._to_z3() | ||
| if not arr: | ||
| raise ValueError("Cannot compute min of empty array") | ||
| self._z3 = reduce(lambda a, b: If(a <= b, a, b), arr) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Handle scalar inputs in max/min reduction
The new implementation assumes self.input._to_z3() returns an iterable, but many expressions (e.g., a reduction over a scalar tensor or the result of a symbolic load) yield a single ArithRef. In those cases if not arr: raises TypeError: Symbolic expressions cannot be cast to bool and the subsequent reduce(...) call raises TypeError: 'ArithRef' object is not iterable, so max/min still crash. sum works for both lists and scalars because Sum handles either case, but the max/min branches now fail for valid scalar inputs. Consider wrapping non-iterables in a list or skipping the emptiness check unless the result is a sequence.
Useful? React with 👍 / 👎.
| # This test verifies that the operation doesn't crash | ||
|
|
||
|
|
||
| # ======== Reduce Operations (max, min) ========= |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the tests are not as good as the previous tests. Should be refactored
The sanitizer's symbolic execution engine was raising
NotImplementedErrorwhen encounteringmaxorminreduce operations, preventing analysis of kernels using these operations.Changes
triton_viz/clients/sanitizer/sanitizer.pyreduceto functools importsmaxoperation:reduce(lambda a, b: If(a >= b, a, b), arr)minoperation:reduce(lambda a, b: If(a <= b, a, b), arr)tests/test_sanitizer.pyImplementation
The implementation follows the existing
sumpattern, using z3'sIffor element-wise comparison:This resolves the traceback where
_to_z3encounters max/min operations in kernel memory access patterns.Original prompt
This section details on the original issue you should resolve
<issue_title>[FEATURE] Support
maxin_to_z3</issue_title><issue_description>```
Traceback (most recent call last):
File "/home/hwu27/workspace/triton-viz/.venv/bin/triton-sanitizer", line 10, in
sys.exit(apply())
^^^^^^^
File "/home/hwu27/workspace/triton-viz/triton_viz/wrapper.py", line 58, in apply
runpy.run_path(script, run_name="main")
File "", line 286, in run_path
File "", line 98, in _run_module_code
File "", line 88, in _run_code
File "cache_transform_triton.py", line 165, in
result_gold = test_get_xine_cache()
^^^^^^^^^^^^^^^^^^^^^
File "cache_transform_triton.py", line 160, in test_get_xine_cache
cos_output, sin_output = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=is_prompts)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "cache_transform_triton.py", line 102, in get_xine_cache
prefill_cache_kernel[grid](
File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/runtime/jit.py", line 390, in
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/triton_viz/core/trace.py", line 68, in run
ret = self.interpreter_fn.run(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/runtime/interpreter.py", line 1380, in run
return GridExecutor(fn, self.arg_names, grid)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/triton_viz/core/patch.py", line 488, in _grid_executor_call
run_grid_loops(grid)
File "/home/hwu27/workspace/triton-viz/triton_viz/core/patch.py", line 456, in run_grid_loops
self.fn(**call_args)
File "cache_transform_triton.py", line 27, in prefill_cache_kernel
cos_cache_part = tl.load(
^^^^^^^^
File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/runtime/interpreter.py", line 781, in
new_member = lambda *args, member=member, **kwargs: (member(*args, **
^^^^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/language/core.py", line 42, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/language/core.py", line 2150, in load
return _semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/language/semantic.py", line 1086, in load
return self._load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/language/semantic.py", line 1068, in _load_legacy
self.builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache,
File "/home/hwu27/workspace/triton-viz/triton_viz/core/patch.py", line 215, in
lambda *args, **kwargs: patched_op(*args, **kwargs),
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/triton_viz/core/patch.py", line 187, in call
ret = self.callbacks.op_overrider(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/triton_viz/clients/sanitizer/sanitizer.py", line 1671, in op_load_overrider
self._handle_access_check(ret)
File "/home/hwu27/workspace/triton-viz/triton_viz/clients/sanitizer/sanitizer.py", line 1521, in _handle_access_check
z3_addr, z3_constraints = expr.eval()
^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/triton_viz/clients/sanitizer/sanitizer.py", line 1050, in eval
expr, constraints = self._to_z3()
^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/triton_viz/clients/sanitizer/sanitizer.py", line 1164, in _to_z3
ptr, constraints_ptr = self.ptr._to_z3()
^^^^^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/triton_viz/clients/sanitizer/sanitizer.py", line 1176, in _to_z3
ptr_z3, constraints_ptr = self.ptr._to_z3()
^^^^^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/triton_viz/clients/sanitiz...
maxin_to_z3#134✨ Let Copilot coding agent set things up for you — coding agent works faster and does higher quality work when set up for your repo.