Skip to content

Conversation

@avik-pal
Copy link
Collaborator

@avik-pal avik-pal commented Nov 15, 2025

once this lands we should be in a decent place to target LinearSolve.jl

  • cross
    • emit mlir
    • testing
  • svd
    • emit mlir
    • svdvals
    • ldiv
    • testing
  • cholesky
    • emit mlir
    • ldiv
    • testing
  • qr
    • emit mlir
    • ldiv
    • testing
  • factorize
    • emit mlir
    • ldiv
    • testing

@avik-pal avik-pal mentioned this pull request Nov 15, 2025
3 tasks
@avik-pal
Copy link
Collaborator Author

julia> @code_hlo svd(x_ra; full=true)
module @reactant_svd attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<8x8xf32> {enzymexla.memory_effects = []}) -> (tensor<8x8xf32>, tensor<8xf32>, tensor<8x8xf32>) attributes {enzymexla.memory_effects = []} {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<8x8xf32>) -> tensor<8x8xf32>
    %1:5 = stablehlo.custom_call @cusolver_gesvd_ffi(%0) {api_version = 4 : i32, compute_uv = true, enzymexla.guaranteed_symmetric = false, full_matrices = true, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], transposed = false} : (tensor<8x8xf32>) -> (tensor<8x8xf32>, tensor<8xf32>, tensor<8x8xf32>, tensor<8x8xf32>, tensor<i64>)
    %2 = stablehlo.transpose %1#2, dims = [1, 0] : (tensor<8x8xf32>) -> tensor<8x8xf32>
    %3 = stablehlo.transpose %1#3, dims = [1, 0] : (tensor<8x8xf32>) -> tensor<8x8xf32>
    return %2, %1#1, %3 : tensor<8x8xf32>, tensor<8xf32>, tensor<8x8xf32>
  }
}

julia> @jit svd(x_ra; full=true)
E0000 00:00:1763244269.299329  239787 pjrt_stream_executor_client.cc:2085] Execution of replica 0 failed: INVALID_ARGUMENT: Wrong number of attributes: expected 3 but got 0
ERROR: INVALID_ARGUMENT: Wrong number of attributes: expected 3 but got 0

Stacktrace:
 [1] reactant_err(msg::Cstring)
   @ Reactant.XLA /mnt/software/lux/Reactant.jl/src/xla/Utils.jl:12
 [2] execute
   @ /mnt/software/lux/Reactant.jl/src/xla/IFRT/LoadedExecutable.jl:126 [inlined]
 [3] execute_sharded
   @ /mnt/software/lux/Reactant.jl/src/xla/IFRT/LoadedExecutable.jl:166 [inlined]
 [4] macro expansion
   @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:3308 [inlined]
 [5] (::Reactant.Compiler.Thunk{typeof(svd), Symbol("##svd_reactant#563"), false, Tuple{…}, Reactant.XLA.IFRT.LoadedExecutable, Reactant.XLA.IFRT.Device, Reactant.XLA.IFRT.Client, Tuple{}, Vector{…}})(args::ConcreteIFRTArray{Float32, 2, Nothing})
   @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:3790
 [6] top-level scope
   @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:2597
 [7] top-level scope
   @ none:1
Some type information was truncated. Use `show(err)` to see complete types.

@avik-pal avik-pal force-pushed the ap/more_linalg_coverage branch from 27e3df1 to 9cc4e6a Compare November 16, 2025 15:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants