Skip to content

Conversation

@wdziurdz
Copy link

New contributor declaration

  • [x ] I am not making a trivial change, such as fixing a typo in a comment.

  • [x ] I have written a PR description following these
    rules.

  • [ x] I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • [ x] I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • I have not added any lit tests.
    • [x ] The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

Description:
PoisonOpAxisInfoVisitor incorrectly returns rank = 1 for ub.poison operations that produce pointer-to-tensor types such as !tt.ptr<tensor<128x64xf16>>.
This incorrect rank then propagates through unrealized_conversion_cast operations generated during lowering, which leads to assertion failures in AxisInfo::join().
More details in issue #8823.
This PR also fixes #8823.

@Jokeren
Copy link
Contributor

Jokeren commented Nov 24, 2025

That's something I'm not sure we still want to support cc @ThomasRaoux

@ThomasRaoux
Copy link
Collaborator

In the current flow pointer of tensor are deprecated and they get lowered early in the TTIR flow, so I don't think we support in this pass and we shouldn't add code related to it.

@wdziurdz
Copy link
Author

@ThomasRaoux Could you clarify where exactly pointers “get lowered early in the TTIR flow”? I’m fixing the issue this way because UnrealizedConversionCastOpAxisInfoVisitor still handles pointers(ref to code below), and with this fix the same logic also appears in PoisonOpAxisInfoVisitor. This is mismatching.

class UnrealizedConversionCastOpAxisInfoVisitor final
: public AxisInfoVisitorImpl<mlir::UnrealizedConversionCastOp> {
public:
using AxisInfoVisitorImpl<
mlir::UnrealizedConversionCastOp>::AxisInfoVisitorImpl;
AxisInfo
getAxisInfo(mlir::UnrealizedConversionCastOp op,
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
auto tensorType = dyn_cast<RankedTensorType>(op.getResultTypes()[0]);
if (tensorType &&
tensorType.getRank() != operands[0]->getValue().getRank()) {
// Do not propagate AxisInfo with incorrect rank. This can cause a crash
// in future visitor applications.
return AxisInfo::getPessimisticValueState(op->getResult(0));
}
return operands[0]->getValue();
}
};

And this code handles pointers in the same way.
/*static*/ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
auto rank = 1;
if (TensorType ty = dyn_cast<TensorType>(value.getType()))
rank = ty.getRank();
if (triton::PointerType ty = dyn_cast<triton::PointerType>(value.getType()))
if (TensorType elemTy = dyn_cast<TensorType>(ty.getPointeeType()))
rank = elemTy.getRank();

@Jokeren
Copy link
Contributor

Jokeren commented Nov 25, 2025

And this code handles pointers in the same way.

I think some legacy code is just completely removed. Can you share with us the frontend code that leads to the IR you proposed to fix?

@ThomasRaoux
Copy link
Collaborator

@ThomasRaoux Could you clarify where exactly pointers “get lowered early in the TTIR flow”? I’m fixing the issue this way because UnrealizedConversionCastOpAxisInfoVisitor still handles pointers(ref to code below), and with this fix the same logic also appears in PoisonOpAxisInfoVisitor. This is mismatching.

class UnrealizedConversionCastOpAxisInfoVisitor final
: public AxisInfoVisitorImpl<mlir::UnrealizedConversionCastOp> {
public:
using AxisInfoVisitorImpl<
mlir::UnrealizedConversionCastOp>::AxisInfoVisitorImpl;
AxisInfo
getAxisInfo(mlir::UnrealizedConversionCastOp op,
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
auto tensorType = dyn_cast<RankedTensorType>(op.getResultTypes()[0]);
if (tensorType &&
tensorType.getRank() != operands[0]->getValue().getRank()) {
// Do not propagate AxisInfo with incorrect rank. This can cause a crash
// in future visitor applications.
return AxisInfo::getPessimisticValueState(op->getResult(0));
}
return operands[0]->getValue();
}
};

And this code handles pointers in the same way.

/*static*/ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
auto rank = 1;
if (TensorType ty = dyn_cast<TensorType>(value.getType()))
rank = ty.getRank();
if (triton::PointerType ty = dyn_cast<triton::PointerType>(value.getType()))
if (TensorType elemTy = dyn_cast<TensorType>(ty.getPointeeType()))
rank = elemTy.getRank();

this pass should remove all the pointer ot tensors:
https://github.com/triton-lang/triton/blob/main/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp

@wdziurdz
Copy link
Author

@ThomasRaoux thank you for your help. I’m closing this PR now.

@wdziurdz wdziurdz closed this Nov 27, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] ub.poison with ptr<tensor> produces incorrect rank in AxisInfo analysis

3 participants