Commit c7aba73
[JAX] add support for gather/scatter batching dims following the new attributes in stablehlo.
This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.
See openxla/stablehlo#2259
PiperOrigin-RevId: 6476478251 parent db3f7d1 commit c7aba73
1 file changed
+12
-4
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
68 | 68 | | |
69 | 69 | | |
70 | 70 | | |
71 | | - | |
72 | | - | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
73 | 77 | | |
74 | 78 | | |
75 | 79 | | |
76 | 80 | | |
77 | 81 | | |
78 | 82 | | |
79 | | - | |
80 | | - | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
81 | 89 | | |
82 | 90 | | |
83 | 91 | | |
| |||
0 commit comments