diff --git a/tests/test_sanitizer.py b/tests/test_sanitizer.py index dc184e97..f5e8f635 100644 --- a/tests/test_sanitizer.py +++ b/tests/test_sanitizer.py @@ -391,3 +391,56 @@ def test_atomic_cas(): atomic_cas_kernel[grid](y, cmp_value=0.0, new_value=5.0) # Note: The sanitizer analyzes symbolically, so the actual value may not be updated # This test verifies that the operation doesn't crash + + +# ======== Reduce Operations (max, min) ========= +def test_reduce_max_expr_eval(): + """Test that max reduce operation evaluates correctly.""" + # Test with array of values - create const SymbolicExpr with numpy array + input_arr = SymbolicExpr("const", np.array([1, 5, 3, 2]), tl.int32) + max_expr = SymbolicExpr("max", input_arr, None, False) + result, _ = max_expr.eval() + assert result == 5 + + +def test_reduce_min_expr_eval(): + """Test that min reduce operation evaluates correctly.""" + # Test with array of values - create const SymbolicExpr with numpy array + input_arr = SymbolicExpr("const", np.array([1, 5, 3, 2]), tl.int32) + min_expr = SymbolicExpr("min", input_arr, None, False) + result, _ = min_expr.eval() + assert result == 1 + + +def test_reduce_max_single_element(): + """Test that max reduce operation works with single element.""" + # Test with single element - create const SymbolicExpr with numpy array + input_arr = SymbolicExpr("const", np.array([42]), tl.int32) + max_expr = SymbolicExpr("max", input_arr, None, False) + result, _ = max_expr.eval() + assert result == 42 + + +def test_reduce_min_single_element(): + """Test that min reduce operation works with single element.""" + # Test with single element - create const SymbolicExpr with numpy array + input_arr = SymbolicExpr("const", np.array([42]), tl.int32) + min_expr = SymbolicExpr("min", input_arr, None, False) + result, _ = min_expr.eval() + assert result == 42 + + +def test_reduce_max_empty_array(): + """Test that max reduce operation raises ValueError for empty array.""" + input_arr = SymbolicExpr("const", np.array([]), tl.int32) + max_expr = SymbolicExpr("max", input_arr, None, False) + with pytest.raises(ValueError, match="Cannot compute max of empty array"): + max_expr.eval() + + +def test_reduce_min_empty_array(): + """Test that min reduce operation raises ValueError for empty array.""" + input_arr = SymbolicExpr("const", np.array([]), tl.int32) + min_expr = SymbolicExpr("min", input_arr, None, False) + with pytest.raises(ValueError, match="Cannot compute min of empty array"): + min_expr.eval() diff --git a/triton_viz/clients/sanitizer/sanitizer.py b/triton_viz/clients/sanitizer/sanitizer.py index 85291502..a0821245 100644 --- a/triton_viz/clients/sanitizer/sanitizer.py +++ b/triton_viz/clients/sanitizer/sanitizer.py @@ -2,7 +2,7 @@ from collections import namedtuple from collections.abc import Callable from dataclasses import dataclass, field -from functools import cached_property +from functools import cached_property, reduce from typing import Any, Optional, Union import re @@ -1182,10 +1182,16 @@ def _to_z3(self) -> tuple[ArithRef, list]: self._z3 = Sum(arr) if self.op == "max": - raise NotImplementedError("_to_z3 of max is not implemented yet") + arr, self._constraints = self.input._to_z3() + if not arr: + raise ValueError("Cannot compute max of empty array") + self._z3 = reduce(lambda a, b: If(a >= b, a, b), arr) if self.op == "min": - raise NotImplementedError("_to_z3 of min is not implemented yet") + arr, self._constraints = self.input._to_z3() + if not arr: + raise ValueError("Cannot compute min of empty array") + self._z3 = reduce(lambda a, b: If(a <= b, a, b), arr) if self.op == "load" or self.op == "store": # Load and store operations