Skip to content

Commit f584469

Browse files
committed
Update fx importer testing route to simplify input handling.
Signed-off-by: zjgarvey <[email protected]>
1 parent 95d0eb0 commit f584469

File tree

2 files changed

+4
-9
lines changed

2 files changed

+4
-9
lines changed

projects/e2e/torch_mlir_e2e_test/configs/fx_importer_backend.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def invoke_func(*torch_inputs):
132132
return result
133133

134134
def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
135+
artifact.eval()
135136
result: Trace = []
136137
for item in trace:
137138
prog: ExportedProgram = torch.export.export(
@@ -141,6 +142,7 @@ def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
141142
prog,
142143
output_type=self._output_type,
143144
func_name=artifact.__class__.__name__,
145+
experimental_support_mutation=True,
144146
# While the current e2e tests don't exercise symbolic shapes,
145147
# enabling this here ensures they don't regress either.
146148
import_symbolic_shape_expressions=True,
@@ -149,14 +151,7 @@ def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
149151
)
150152
module = self._backend.compile(module)
151153
backend_module = self._backend.load(module)
152-
params = {
153-
# **dict(artifact.named_parameters(remove_duplicate=False)),
154-
**dict(artifact.named_buffers(remove_duplicate=False)),
155-
}
156-
params_flat, params_spec = pytree.tree_flatten(params)
157-
params_flat = list(params_flat)
158-
with torch.no_grad():
159-
numpy_inputs = recursively_convert_to_numpy(params_flat + item.inputs)
154+
numpy_inputs = recursively_convert_to_numpy(item.inputs)
160155
outputs = getattr(backend_module, artifact.__class__.__name__)(
161156
*numpy_inputs
162157
)

projects/e2e/torch_mlir_e2e_test/configs/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
def recursively_convert_to_numpy(o: Any):
1313
if isinstance(o, torch.Tensor):
14-
return o.numpy()
14+
return o.detach().numpy()
1515
if isinstance(o, tuple):
1616
return tuple(recursively_convert_to_numpy(x) for x in o)
1717
if isinstance(o, list):

0 commit comments

Comments
 (0)