From 4cdb9eca13492538c6a203538088f8d05e2f55ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Thu, 13 Nov 2025 18:27:21 +0000 Subject: [PATCH 1/6] Use `REACTANT_DEFAULT_DEVICE` to set default device ID with env var --- src/xla/Client.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/xla/Client.jl b/src/xla/Client.jl index ccf715c1ba..ce517b6231 100644 --- a/src/xla/Client.jl +++ b/src/xla/Client.jl @@ -13,4 +13,8 @@ function get_device end function get_addressable_device end function platform_name end -default_device(client::AbstractClient) = first(addressable_devices(client)) +function default_device(client::AbstractClient) + return addressable_devices(client)[something( + tryparse(Int, get(ENV, "REACTANT_DEFAULT_DEVICE", "1")), 1 + )] +end From 8c186e27e78a04d25e7c776f137e6ac4a4e1b255 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Thu, 13 Nov 2025 18:47:06 +0000 Subject: [PATCH 2/6] Interpret `REACTANT_DEFAULT_DEVICE` as 0-based --- src/xla/Client.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/xla/Client.jl b/src/xla/Client.jl index ce517b6231..30233154ec 100644 --- a/src/xla/Client.jl +++ b/src/xla/Client.jl @@ -15,6 +15,8 @@ function platform_name end function default_device(client::AbstractClient) return addressable_devices(client)[something( - tryparse(Int, get(ENV, "REACTANT_DEFAULT_DEVICE", "1")), 1 + # `REACTANT_DEFAULT_DEVICE` is interpreted as 0-based. + tryparse(Int, get(ENV, "REACTANT_DEFAULT_DEVICE", "0")) + 1, + 1, )] end From 541087eceb0001d8ee1689060e8ac442b5a2980c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Thu, 13 Nov 2025 19:26:11 +0000 Subject: [PATCH 3/6] Use global `const` `Reactant.XLA.DEFAULT_DEVICE` instead of env var --- src/xla/Client.jl | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/xla/Client.jl b/src/xla/Client.jl index 30233154ec..31925311b2 100644 --- a/src/xla/Client.jl +++ b/src/xla/Client.jl @@ -13,10 +13,13 @@ function get_device end function get_addressable_device end function platform_name end +""" + DEFAULT_DEVICE :: Ref{Int} + +0-based index of default device to use, by default 0 (first available device). +""" +const DEFAULT_DEVICE = Ref{Int}(0) + function default_device(client::AbstractClient) - return addressable_devices(client)[something( - # `REACTANT_DEFAULT_DEVICE` is interpreted as 0-based. - tryparse(Int, get(ENV, "REACTANT_DEFAULT_DEVICE", "0")) + 1, - 1, - )] + return addressable_devices(client)[DEFAULT_DEVICE[] + 1] end From 3bb3be9eed4a196bd18d8e6196d22b24e741d224 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Wed, 3 Dec 2025 11:58:41 +0000 Subject: [PATCH 4/6] Make `DEFAULT_DEVICE` 1-based --- src/xla/Client.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/xla/Client.jl b/src/xla/Client.jl index 31925311b2..460fb57133 100644 --- a/src/xla/Client.jl +++ b/src/xla/Client.jl @@ -16,10 +16,10 @@ function platform_name end """ DEFAULT_DEVICE :: Ref{Int} -0-based index of default device to use, by default 0 (first available device). +1-based index of default device to use, by default 1 (first available device). """ -const DEFAULT_DEVICE = Ref{Int}(0) +const DEFAULT_DEVICE = Ref{Int}(1) function default_device(client::AbstractClient) - return addressable_devices(client)[DEFAULT_DEVICE[] + 1] + return addressable_devices(client)[DEFAULT_DEVICE[]] end From 26a143897fd7a34ec1e99bbc2fe65c3300f2c599 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Wed, 3 Dec 2025 12:01:14 +0000 Subject: [PATCH 5/6] Revert "Make `DEFAULT_DEVICE` 1-based" This reverts commit 3bb3be9eed4a196bd18d8e6196d22b24e741d224. --- src/xla/Client.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/xla/Client.jl b/src/xla/Client.jl index 460fb57133..31925311b2 100644 --- a/src/xla/Client.jl +++ b/src/xla/Client.jl @@ -16,10 +16,10 @@ function platform_name end """ DEFAULT_DEVICE :: Ref{Int} -1-based index of default device to use, by default 1 (first available device). +0-based index of default device to use, by default 0 (first available device). """ -const DEFAULT_DEVICE = Ref{Int}(1) +const DEFAULT_DEVICE = Ref{Int}(0) function default_device(client::AbstractClient) - return addressable_devices(client)[DEFAULT_DEVICE[]] + return addressable_devices(client)[DEFAULT_DEVICE[] + 1] end From 8ea159a738eb1564495fe0e9388c9a9a779de82a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Wed, 3 Dec 2025 12:12:52 +0000 Subject: [PATCH 6/6] Use the `REACTANT_DEFAULT_DEVICE` environment variable --- src/xla/Client.jl | 3 ++- src/xla/XLA.jl | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/xla/Client.jl b/src/xla/Client.jl index 31925311b2..3d0e3dd117 100644 --- a/src/xla/Client.jl +++ b/src/xla/Client.jl @@ -16,7 +16,8 @@ function platform_name end """ DEFAULT_DEVICE :: Ref{Int} -0-based index of default device to use, by default 0 (first available device). +0-based index of default device to use. +By default, the value of the environment variable `REACTANT_DEFAULT_DEVICE` is used when set to a non-negative integer, otherwise it is set to 0 (first available device). """ const DEFAULT_DEVICE = Ref{Int}(0) diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl index 5fa34457b1..31335f2803 100644 --- a/src/xla/XLA.jl +++ b/src/xla/XLA.jl @@ -158,6 +158,13 @@ function __init__() 1 end + if haskey(ENV, "REACTANT_DEFAULT_DEVICE") + DEFAULT_DEVICE[] = max( + 0, something(tryparse(Int, ENV["REACTANT_DEFAULT_DEVICE"]), 0) + ) + @debug "REACTANT_DEFAULT_DEVICE: " DEFAULT_DEVICE[] maxlog = 1 + end + @debug "REACTANT_XLA_RUNTIME: " REACTANT_XLA_RUNTIME maxlog = 1 @ccall MLIR.API.mlir_c.RegisterEnzymeXLACPUHandler()::Cvoid