diff --git a/src/xla/Client.jl b/src/xla/Client.jl index ccf715c1ba..3d0e3dd117 100644 --- a/src/xla/Client.jl +++ b/src/xla/Client.jl @@ -13,4 +13,14 @@ function get_device end function get_addressable_device end function platform_name end -default_device(client::AbstractClient) = first(addressable_devices(client)) +""" + DEFAULT_DEVICE :: Ref{Int} + +0-based index of default device to use. +By default, the value of the environment variable `REACTANT_DEFAULT_DEVICE` is used when set to a non-negative integer, otherwise it is set to 0 (first available device). +""" +const DEFAULT_DEVICE = Ref{Int}(0) + +function default_device(client::AbstractClient) + return addressable_devices(client)[DEFAULT_DEVICE[] + 1] +end diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl index 5fa34457b1..31335f2803 100644 --- a/src/xla/XLA.jl +++ b/src/xla/XLA.jl @@ -158,6 +158,13 @@ function __init__() 1 end + if haskey(ENV, "REACTANT_DEFAULT_DEVICE") + DEFAULT_DEVICE[] = max( + 0, something(tryparse(Int, ENV["REACTANT_DEFAULT_DEVICE"]), 0) + ) + @debug "REACTANT_DEFAULT_DEVICE: " DEFAULT_DEVICE[] maxlog = 1 + end + @debug "REACTANT_XLA_RUNTIME: " REACTANT_XLA_RUNTIME maxlog = 1 @ccall MLIR.API.mlir_c.RegisterEnzymeXLACPUHandler()::Cvoid