@@ -132,6 +132,7 @@ def invoke_func(*torch_inputs):
         return result

     def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
+        artifact.eval()
         result: Trace = []
         for item in trace:
             prog: ExportedProgram = torch.export.export(
@@ -141,6 +142,7 @@ def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
                 prog,
                 output_type=self._output_type,
                 func_name=artifact.__class__.__name__,
+                experimental_support_mutation=True,
                 # While the current e2e tests don't exercise symbolic shapes,
                 # enabling this here ensures they don't regress either.
                 import_symbolic_shape_expressions=True,
@@ -149,14 +151,7 @@ def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
             )
             module = self._backend.compile(module)
             backend_module = self._backend.load(module)
-            params = {
-                # **dict(artifact.named_parameters(remove_duplicate=False)),
-                **dict(artifact.named_buffers(remove_duplicate=False)),
-            }
-            params_flat, params_spec = pytree.tree_flatten(params)
-            params_flat = list(params_flat)
-            with torch.no_grad():
-                numpy_inputs = recursively_convert_to_numpy(params_flat + item.inputs)
+            numpy_inputs = recursively_convert_to_numpy(item.inputs)
             outputs = getattr(backend_module, artifact.__class__.__name__)(
                 *numpy_inputs
             )
0 commit comments