@@ -67,14 +67,56 @@ struct Token
6767 mlir_data:: MLIR.IR.Value
6868end
6969
70+ function activate_constant_context! (blk:: MLIR.IR.Block )
71+ stack = get! (task_local_storage (), :entry_block ) do
72+ return Tuple{MLIR. IR. Block,Dict{MLIR. IR. Attribute,TracedRArray}}[]
73+ end
74+ Base. push! (stack, (blk, Dict {MLIR.IR.Attribute,TracedRArray} ()))
75+ return nothing
76+ end
77+
78+ function constant_context (; throw_error:: Core.Bool = true )
79+ return last (task_local_storage (:entry_block ))
80+ end
81+
82+ function deactivate_constant_context! (blk:: MLIR.IR.Block )
83+ constant_context ()[1 ] == blk || error (" Deactivating wrong block" )
84+ return Base. pop! (task_local_storage (:entry_block ))
85+ end
86+
7087# constant ops
7188@noinline function constant (
7289 x:: DenseArray{T,N} ; location= mlir_stacktrace (" constant" , @__FILE__ , @__LINE__ )
7390) where {T,N}
7491 value = MLIR. IR. DenseElementsAttribute (x)
75- output = mlir_type (TracedRArray{T,N}, size (x))
76- res = MLIR. IR. result (stablehlo. constant (; output, value, location))
77- return TracedRArray {T,N} ((), res, size (x))
92+ constants = constant_context ()[2 ]
93+ if haskey (constants, value)
94+ return constants[value]
95+ else
96+ output = mlir_type (TracedRArray{T,N}, size (x))
97+
98+ op_ty_results = MLIR. IR. Type[output]
99+ operands = MLIR. IR. Value[]
100+ owned_regions = MLIR. IR. Region[]
101+ successors = MLIR. IR. Block[]
102+ attributes = MLIR. IR. NamedAttribute[MLIR. Dialects. namedattribute (" value" , value),]
103+
104+ cstop = MLIR. IR. create_operation (
105+ " stablehlo.constant" ,
106+ location;
107+ operands,
108+ owned_regions,
109+ successors,
110+ attributes,
111+ results= op_ty_results,
112+ result_inference= false ,
113+ )
114+
115+ res = MLIR. IR. result (cstop)
116+ tres = TracedRArray {T,N} ((), res, size (x))
117+ constants[value] = tres
118+ return tres
119+ end
78120end
79121
80122@noinline function constant (
@@ -1764,6 +1806,7 @@ end
17641806 true_fn_args = true_fn_names[1 ]
17651807
17661808 MLIR. IR. activate! (true_fn_body)
1809+ Ops. activate_constant_context! (true_fn_body)
17671810 tb_result = try
17681811 for (i, arg) in enumerate (tb_linear_args)
17691812 # find the right path to index the traced arg.
@@ -1787,6 +1830,7 @@ end
17871830 end
17881831 Reactant. call_with_reactant (true_fn, tb_traced_args... )
17891832 finally
1833+ Ops. deactivate_constant_context! (true_fn_body)
17901834 MLIR. IR. deactivate! (true_fn_body)
17911835 end
17921836
@@ -1827,6 +1871,7 @@ end
18271871
18281872 false_fn_args = false_fn_names[1 ]
18291873 MLIR. IR. activate! (false_fn_body)
1874+ Ops. activate_constant_context! (false_fn_body)
18301875 fb_result = try
18311876 for (i, arg) in enumerate (fb_linear_args)
18321877 # find the right path to index the traced arg.
@@ -1850,6 +1895,7 @@ end
18501895 end
18511896 Reactant. call_with_reactant (false_fn, fb_traced_args... )
18521897 finally
1898+ Ops. deactivate_constant_context! (false_fn_body)
18531899 MLIR. IR. deactivate! (false_fn_body)
18541900 end
18551901
@@ -1928,6 +1974,7 @@ end
19281974
19291975 # finalize the true branch by adding the missing values
19301976 MLIR. IR. activate! (true_fn_body)
1977+ Ops. activate_constant_context! (true_fn_body)
19311978 tb_corrected_linear_results = Reactant. TracedType[]
19321979 try
19331980 for (i, path) in enumerate (tb_paths)
@@ -1939,10 +1986,12 @@ end
19391986 end
19401987 finally
19411988 MLIR. IR. deactivate! (true_fn_body)
1989+ Ops. deactivate_constant_context! (true_fn_body)
19421990 end
19431991
19441992 # finalize the false branch by adding the missing values
19451993 MLIR. IR. activate! (false_fn_body)
1994+ Ops. activate_constant_context! (false_fn_body)
19461995 fb_corrected_linear_results = Reactant. TracedType[]
19471996 try
19481997 for (i, path) in enumerate (fb_paths)
@@ -1954,6 +2003,7 @@ end
19542003 end
19552004 finally
19562005 MLIR. IR. deactivate! (false_fn_body)
2006+ Ops. deactivate_constant_context! (false_fn_body)
19572007 end
19582008
19592009 # All MissingTracedValues must be replaced with zeroes
@@ -1968,19 +2018,23 @@ end
19682018 res = if tr isa MissingTracedValue
19692019 @assert ! (fr isa MissingTracedValue)
19702020 MLIR. IR. activate! (true_fn_body)
2021+ Ops. activate_constant_context! (true_fn_body)
19712022 try
19722023 tb_corrected_linear_results[i] = zero (fr)
19732024 finally
19742025 MLIR. IR. deactivate! (true_fn_body)
2026+ Ops. deactivate_constant_context! (true_fn_body)
19752027 end
19762028 fr
19772029 elseif fr isa MissingTracedValue
19782030 @assert ! (tr isa MissingTracedValue)
19792031 MLIR. IR. activate! (false_fn_body)
2032+ Ops. activate_constant_context! (false_fn_body)
19802033 try
19812034 fb_corrected_linear_results[i] = zero (tr)
19822035 finally
19832036 MLIR. IR. deactivate! (false_fn_body)
2037+ Ops. deactivate_constant_context! (false_fn_body)
19842038 end
19852039 tr
19862040 else
@@ -1993,6 +2047,7 @@ end
19932047 end
19942048
19952049 MLIR. IR. activate! (true_fn_body)
2050+ Ops. activate_constant_context! (true_fn_body)
19962051 try
19972052 vals = MLIR. IR. Value[
19982053 Reactant. TracedUtils. get_mlir_data (res) for
@@ -2001,9 +2056,11 @@ end
20012056 MLIR. Dialects. stablehlo. return_ (vals)
20022057 finally
20032058 MLIR. IR. deactivate! (true_fn_body)
2059+ Ops. deactivate_constant_context! (true_fn_body)
20042060 end
20052061
20062062 MLIR. IR. activate! (false_fn_body)
2063+ Ops. activate_constant_context! (false_fn_body)
20072064 try
20082065 vals = MLIR. IR. Value[
20092066 Reactant. TracedUtils. get_mlir_data (res) for
@@ -2012,6 +2069,7 @@ end
20122069 MLIR. Dialects. stablehlo. return_ (vals)
20132070 finally
20142071 MLIR. IR. deactivate! (false_fn_body)
2072+ Ops. deactivate_constant_context! (false_fn_body)
20152073 end
20162074
20172075 # With the corrected results, we can compile the true and false branches
0 commit comments