Skip to content

Conversation

@keshavvinayak01
Copy link
Contributor

@keshavvinayak01 keshavvinayak01 commented Nov 4, 2025

Description

  • Added support for PyTorch's flex_attention Higher-Order Operator in torch-mlir.
  • Implemented Torch_AtenFlexAttentionOp with 6 operands (query, key, value, scale, enable_gqa, return_lse) and 2 optional attributes (score_mod_fn, mask_mod_fn) for function references.
  • The FX importer (_import_hop_flex_attention) correctly extracts score/mask modification functions from get_attr nodes using module IDs, following the while_loop HOP pattern.
  • Includes TODO markers for kernel_options performance tuning parameters.
  • Imports flex_attention from PyTorch FX graphs into valid MLIR.

keshavvinayak01 and others added 17 commits October 22, 2025 09:41
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Change 1: Converts builtin tensors → Torch tensors when entering the loop body
Change 2: Ensures Torch tensors → builtin tensors when yielding back to the loop condition
Without these fixes, the conversion would fail when while loops carry tensor values

Also modified basic_test.py FILECHECK statements.

Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
1. Better documentation for AtenFlexAttentionOp
2. Function referece added as attributes to aten.flex_attention
3. Updates to _import_hop_flex_attention reflecting latest changes of module import.
4. Removed discardable attributes; scored_mod_fn and mask_mod_fn added as optionalAttr

Signed-off-by: Keshav Vinayak Jha <[email protected]>
Remove note about method usage for HOPs.
@keshavvinayak01 keshavvinayak01 changed the title Keshavvinayak01/torch aten flex attention [TORCH] Added flex_attention hop function Nov 4, 2025
Removed TODO note for grouped query attention support in the docstring and comments.
@keshavvinayak01 keshavvinayak01 force-pushed the keshavvinayak01/torch-aten-flex_attention branch from 095cb61 to 5e024f6 Compare November 6, 2025 09:36
@keshavvinayak01 keshavvinayak01 marked this pull request as ready for review November 6, 2025 09:37
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Copy link
Collaborator

@zjgarvey zjgarvey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does enable importing to mlir.

However, the changes don't actually provide "support" for this op, since the torch op can neither be decomposed nor lowered to any other dialects.

Although we could review/merge this and subsequently add a lowering path for the op in MLIR, I would personally prefer the e2e support is added in the same PR as the import support.

This is a rather unique operator, so having passing e2e tests would give me a lot more confidence in the choices made here. Otherwise I'm basically just hoping that what you did generally makes sense (or doing a significant amount of work myself to check it out), because there really isn't much precedent for these kinds of choices in our codebase.

@Groverkss
Copy link
Member

Groverkss commented Nov 11, 2025

This does enable importing to mlir.

However, the changes don't actually provide "support" for this op, since the torch op can neither be decomposed nor lowered to any other dialects.

Although we could review/merge this and subsequently add a lowering path for the op in MLIR, I would personally prefer the e2e support is added in the same PR as the import support.

This is a rather unique operator, so having passing e2e tests would give me a lot more confidence in the choices made here. Otherwise I'm basically just hoping that what you did generally makes sense (or doing a significant amount of work myself to check it out), because there really isn't much precedent for these kinds of choices in our codebase.

The only thing needed to have this passing e2e tests is implementing TilingInterface for this operation:

LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,

With that said, it's an unreasonable bar to set that every operation must compile e2e through torch-mlir. Torch-MLIR is not a compiler, even though it has tests for e2e paths. The project docs explicitly call out this:

Torch-MLIR is primarily a project that is integrated into compilers to bridge them to PyTorch and ONNX. If contemplating a new integration, it may be helpful to refer to existing downstreams:

IREE
Blade
While most of the project is exercised via testing paths, there are some ways that an end user can directly use the APIs without further integration:

It should be okay to land support for ops through the importer without it running e2e tests in torch-mlir. I've looked at the implementation of e2e tests for more complex ops like attention, and they are not good implementations, they don't add much value.

We should as a project allow landing PRs that add support to the importer seperately from e2e tests (Atleast for HOPs). I don't think having a dummy implementation for an op should be the bar to land an operation.

@zjgarvey
Copy link
Collaborator

@Groverkss So this torch op lowers to a tm tensor op? Because I don't see where that is happening.

My blocking is primarily predicated on the fact that this op is imported to something completely unhandled. Even then, I'm happy to unblock and review as is, but it warranted discussion at least. If you would like to add a review yourself, your context on attention ops would be very helpful.

It's simply my preference that we have an e2e test, and I'm not going to block based on that alone.

@Groverkss
Copy link
Member

@Groverkss So this torch op lowers to a tm tensor op? Because I don't see where that is happening.

I think there is a PR running around in IREE that lowers this op to IREE's attention op (found it: iree-org/iree#22441).

I don't think TMTensor is really a requirement anymore, since you can directly lower a torch op in your own project. I think TMTensor is more of a thing of the past, when we really wanted torch-mlir to lower everything for us and we didn't hook patterns into it. For historical context on how TMTensor was used and how it was replaced in IREE (and generally how it should be used now): iree-org/iree#14917

My blocking is primarily predicated on the fact that this op is imported to something completely unhandled. Even then, I'm happy to unblock and review as is, but it warranted discussion at least. If you would like to add a review yourself, your context on attention ops would be very helpful.

I refrained from adding a review on this because I was guiding @keshavvinayak01 through the implementation and didn't want to land this without getting an extra pair of eyes 😅 I think your review on this is invaluable and I'll still let you decide if we should land this as is or not.

It's simply my preference that we have an e2e test, and I'm not going to block based on that alone.

My main worry is that we are tieing the fx_importer to the e2e tests. I personally believe that the e2e test lowering test suite and the fx_importer are seperate pieces of utlity and one should be able to use one without another. I do think the e2e tests are useful though, so I'll recommend @keshavvinayak01 to send a patch implementing TilingInterface for this operation just like we have for the TMTensor op. But that should be seperate from this patch.

@zjgarvey
Copy link
Collaborator

I think there is a PR running around in IREE that lowers this op to IREE's attention op (found it: iree-org/iree#22441).

I don't think TMTensor is really a requirement anymore, since you can directly lower a torch op in your own project. I think TMTensor is more of a thing of the past, when we really wanted torch-mlir to lower everything for us and we didn't hook patterns into it. For historical context on how TMTensor was used and how it was replaced in IREE (and generally how it should be used now): iree-org/iree#14917

Ah, these are both useful context. Thanks. Yeah, if we don't care about having some implementation here, I'm totally fine with that.

I refrained from adding a review on this because I was guiding @keshavvinayak01 through the implementation and didn't want to land this without getting an extra pair of eyes 😅 I think your review on this is invaluable and I'll still let you decide if we should land this as is or not.

That makes sense. I'll review now.

My main worry is that we are tieing the fx_importer to the e2e tests. I personally believe that the e2e test lowering test suite and the fx_importer are seperate pieces of utlity and one should be able to use one without another. I do think the e2e tests are useful though, so I'll recommend @keshavvinayak01 to send a patch implementing TilingInterface for this operation just like we have for the TMTensor op. But that should be seperate from this patch.

Yeah, that sounds good. I just wasn't aware that it was common practice to in-house certain torch lowerings in downstream projects like IREE.

Copy link
Collaborator

@zjgarvey zjgarvey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think two things would be nice before merging, but since none of the changes here really affect anything else in torch-mlir, I'm not going to block anything.

  1. An importer test would be incredibly valuable in my opinion. I'm half-inclined to write one myself just to debug-print the fx graph and mlir so I can review this PR a bit better.

  2. Some explanation of what the enable_gqa arg is doing/not doing. As you can see from my comments, I'm a bit confused by this arg, since it doesn't seem to do anything in pytorch or in the torch-mlir op (where it is hardcoded to False).

}

// CHECK-LABEL: func.func @torch.aten.flex_attention
func.func @torch.aten.flex_attention (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume this is a roundtrip parsing test or something?

This is good to have, but if we don't have any e2e tests, I would at least want an fx_importer lit test for this op. The reason being that I have no idea if the IR here is actually what the importer generates. And if pytorch bumps happen to break the import for this op, I want the CI to flag that.

You added one of these tests for the last HOP PR in this directory:

https://github.com/llvm/torch-mlir/tree/main/test/python/fx_importer

I'd be inclined to have a separate test file for various HOPs if basic_test.py is getting too busy.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can add it to basic_test.py, but it'll spit out an unverified graph module, to which I can add the corresponding FileCheck statements. I'm not sure we want to commit that to the test. Basically this:

"builtin.module"() ({
  "func.func"() <{function_type = (!torch.vtensor<[],f32>, !torch.vtensor<[],si32>, !torch.vtensor<[],si32>, !torch.vtensor<[],si32>, !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32>, sym_name = "sdpa_score0", sym_visibility = "private"}> ({
  ^bb0(%arg7: !torch.vtensor<[],f32>, %arg8: !torch.vtensor<[],si32>, %arg9: !torch.vtensor<[],si32>, %arg10: !torch.vtensor<[],si32>, %arg11: !torch.vtensor<[],si32>):
    %9 = "torch.aten.tanh"(%arg7) : (!torch.vtensor<[],f32>) -> !torch.vtensor<[],f32>
    "func.return"(%9) : (!torch.vtensor<[],f32>) -> ()
  }) : () -> ()
  "func.func"() <{function_type = (!torch.vtensor<[],si32>, !torch.vtensor<[],si32>, !torch.vtensor<[],si32>, !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1>, sym_name = "sdpa_mask0", sym_visibility = "private"}> ({
  ^bb0(%arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>, %arg5: !torch.vtensor<[],si32>, %arg6: !torch.vtensor<[],si32>):
    %3 = "torch.prim.ListConstruct"() : () -> !torch.list<int>
    %4 = "torch.constant.int"() <{value = 11 : i64}> : () -> !torch.int
    %5 = "torch.constant.none"() : () -> !torch.none
    %6 = "torch.constant.device"() <{value = "cpu"}> : () -> !torch.Device
    %7 = "torch.constant.bool"() <{value = false}> : () -> !torch.bool
    %8 = "torch.aten.new_ones"(%arg3, %3, %4, %5, %6, %7) : (!torch.vtensor<[],si32>, !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool) -> !torch.vtensor<[],i1>
    "func.return"(%8) : (!torch.vtensor<[],i1>) -> ()
  }) : () -> ()
  "func.func"() <{function_type = (!torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>), sym_name = "test_attention"}> ({
  ^bb0(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>):
    %0 = "torch.constant.float"() <{value = 1.000000e+00 : f64}> : () -> !torch.float
    %1 = "torch.constant.bool"() <{value = 0 : i0}> : () -> !torch.bool
    %2:2 = "torch.aten.flex_attention"(%arg0, %arg1, %arg2, %0, %1) <{mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0}> : (!torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool) -> (!torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>)
    "func.return"(%2#0, %2#1) : (!torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>) -> ()
  }) : () -> ()
}) : () -> ()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me, that sounds like there is an issue with the importer logic. If the IR doesn't verify, something is wrong, no?

E.g., in some places you have

    %1 = "torch.constant.bool"() <{value = 0 : i0}> : () -> !torch.bool

and in others

    %7 = "torch.constant.bool"() <{value = false}> : () -> !torch.bool

Which one of these is correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed it, but still the same thing is spit out. It's because I can't find torch.aten.flex_attention in the registry.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, that also seems like a bug. Would you mind pushing the local test to the remote branch so I can see the error message in the CI? We need to add a test anyway, so it will be helpful if we can both look at it.

Signed-off-by: Keshav Vinayak Jha <[email protected]>
Copy link
Collaborator

@zjgarvey zjgarvey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think removing the unused arg makes sense, thanks for doing that.

Based on the comments, this PR definitely needs to have at least one importer test, but I would highly recommend adding tests for both default and non-default mod functions.

}

// CHECK-LABEL: func.func @torch.aten.flex_attention
func.func @torch.aten.flex_attention (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me, that sounds like there is an issue with the importer logic. If the IR doesn't verify, something is wrong, no?

E.g., in some places you have

    %1 = "torch.constant.bool"() <{value = 0 : i0}> : () -> !torch.bool

and in others

    %7 = "torch.constant.bool"() <{value = false}> : () -> !torch.bool

Which one of these is correct?

Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants