@@ -66,40 +66,4 @@ TEST(Partitioning, ComputeMobileNetFallbackGraphCorrectly) {
   auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results, trt_results, 2e-6));
 }
-
-TEST(Partitioning, ComputeResNet50HalfFallbackGraphCorrectly) {
-  torch::jit::script::Module mod;
-  try {
-    mod = torch::jit::load("tests/modules/resnet50_traced.jit.pt");
-  } catch (const c10::Error& e) {
-    std::cerr << "error loading the model\n";
-    return;
-  }
-
-  mod.to(torch::kHalf);
-
-  const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}};
-  std::vector<torch::jit::IValue> jit_inputs_ivalues;
-  std::vector<torch::jit::IValue> trt_inputs_ivalues;
-  for (auto in_shape : input_shapes) {
-    auto in = at::randint(5, in_shape, {at::kCUDA}).to(torch::kHalf);
-    jit_inputs_ivalues.push_back(in.clone());
-    trt_inputs_ivalues.push_back(in.clone());
-  }
-
-  auto in_shape = torch_tensorrt::core::ir::Input({1, 3, 224, 224});
-  in_shape.dtype = nvinfer1::DataType::kHALF;
-
-  std::vector<torch_tensorrt::core::ir::Input> input_ranges({in_shape});
-  auto g = mod.get_method("forward").graph();
-  torch_tensorrt::core::CompileSpec cfg(input_ranges);
-  cfg.partition_info.enabled = true;
-  cfg.partition_info.forced_fallback_operators.push_back("aten::add");
-
-  auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
-  auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
-  auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
-  // Lower threshold because FP16
-  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results, trt_results, 2e-1));
-}
 #endif