Skip to content

Commit 8ee2566

Browse files
committed
Preserve constexprs returned from functions
It is very easy to convert a constexpr into a tensor but difficult to go in the reverse direction. The current behavior makes it very difficult to support certain patterns without ceremony. While technically a compatibility break, I expect this to be rarely hit because many operations (like assignment) decay constexprs to tensors anyway.
1 parent 49e174c commit 8ee2566

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

python/test/unit/language/test_frontend.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,3 +652,18 @@ def max_kernel(a: tl.constexpr, b: tl.constexpr):
652652

653653
with pytest.raises(CompilationError):
654654
run_parser(max_kernel, args=(1.0, -0.0))
655+
656+
657+
@pytest.mark.interpreter
658+
def test_constexpr_return():
659+
660+
@triton.jit
661+
def get_constexpr_value():
662+
return tl.constexpr(42)
663+
664+
@triton.jit
665+
def test():
666+
x: tl.constexpr = get_constexpr_value()
667+
tl.static_assert(x == 42)
668+
669+
test[(1, )]()

python/triton/compiler/code_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ def visit_Return(self, node):
550550
def decay(value):
551551
if isinstance(value, language.tuple):
552552
return _apply_to_tuple_values(value, decay)
553-
elif isinstance(value, (language.constexpr, int, float)):
553+
elif isinstance(value, (int, float)):
554554
return self.semantic.to_tensor(value)
555555
return value
556556

0 commit comments

Comments
 (0)