Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions tests/test_sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,56 @@ def test_atomic_cas():
atomic_cas_kernel[grid](y, cmp_value=0.0, new_value=5.0)
# Note: The sanitizer analyzes symbolically, so the actual value may not be updated
# This test verifies that the operation doesn't crash


# ======== Reduce Operations (max, min) =========
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the tests are not as good as the previous tests. Should be refactored

def test_reduce_max_expr_eval():
"""Test that max reduce operation evaluates correctly."""
# Test with array of values - create const SymbolicExpr with numpy array
input_arr = SymbolicExpr("const", np.array([1, 5, 3, 2]), tl.int32)
max_expr = SymbolicExpr("max", input_arr, None, False)
result, _ = max_expr.eval()
assert result == 5


def test_reduce_min_expr_eval():
"""Test that min reduce operation evaluates correctly."""
# Test with array of values - create const SymbolicExpr with numpy array
input_arr = SymbolicExpr("const", np.array([1, 5, 3, 2]), tl.int32)
min_expr = SymbolicExpr("min", input_arr, None, False)
result, _ = min_expr.eval()
assert result == 1


def test_reduce_max_single_element():
"""Test that max reduce operation works with single element."""
# Test with single element - create const SymbolicExpr with numpy array
input_arr = SymbolicExpr("const", np.array([42]), tl.int32)
max_expr = SymbolicExpr("max", input_arr, None, False)
result, _ = max_expr.eval()
assert result == 42


def test_reduce_min_single_element():
"""Test that min reduce operation works with single element."""
# Test with single element - create const SymbolicExpr with numpy array
input_arr = SymbolicExpr("const", np.array([42]), tl.int32)
min_expr = SymbolicExpr("min", input_arr, None, False)
result, _ = min_expr.eval()
assert result == 42


def test_reduce_max_empty_array():
"""Test that max reduce operation raises ValueError for empty array."""
input_arr = SymbolicExpr("const", np.array([]), tl.int32)
max_expr = SymbolicExpr("max", input_arr, None, False)
with pytest.raises(ValueError, match="Cannot compute max of empty array"):
max_expr.eval()


def test_reduce_min_empty_array():
"""Test that min reduce operation raises ValueError for empty array."""
input_arr = SymbolicExpr("const", np.array([]), tl.int32)
min_expr = SymbolicExpr("min", input_arr, None, False)
with pytest.raises(ValueError, match="Cannot compute min of empty array"):
min_expr.eval()
12 changes: 9 additions & 3 deletions triton_viz/clients/sanitizer/sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import namedtuple
from collections.abc import Callable
from dataclasses import dataclass, field
from functools import cached_property
from functools import cached_property, reduce
from typing import Any, Optional, Union
import re

Expand Down Expand Up @@ -1182,10 +1182,16 @@ def _to_z3(self) -> tuple[ArithRef, list]:
self._z3 = Sum(arr)

if self.op == "max":
raise NotImplementedError("_to_z3 of max is not implemented yet")
arr, self._constraints = self.input._to_z3()
if not arr:
raise ValueError("Cannot compute max of empty array")
self._z3 = reduce(lambda a, b: If(a >= b, a, b), arr)

if self.op == "min":
raise NotImplementedError("_to_z3 of min is not implemented yet")
arr, self._constraints = self.input._to_z3()
if not arr:
raise ValueError("Cannot compute min of empty array")
self._z3 = reduce(lambda a, b: If(a <= b, a, b), arr)
Comment on lines 1184 to +1194

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Handle scalar inputs in max/min reduction

The new implementation assumes self.input._to_z3() returns an iterable, but many expressions (e.g., a reduction over a scalar tensor or the result of a symbolic load) yield a single ArithRef. In those cases if not arr: raises TypeError: Symbolic expressions cannot be cast to bool and the subsequent reduce(...) call raises TypeError: 'ArithRef' object is not iterable, so max/min still crash. sum works for both lists and scalars because Sum handles either case, but the max/min branches now fail for valid scalar inputs. Consider wrapping non-iterables in a list or skipping the emptiness check unless the result is a sequence.

Useful? React with 👍 / 👎.


if self.op == "load" or self.op == "store":
# Load and store operations
Expand Down