diff --git a/tools/pnnx/src/pass_level1/nn_Upsample.cpp b/tools/pnnx/src/pass_level1/nn_Upsample.cpp index 238b336df395..7757d80e8500 100644 --- a/tools/pnnx/src/pass_level1/nn_Upsample.cpp +++ b/tools/pnnx/src/pass_level1/nn_Upsample.cpp @@ -21,13 +21,16 @@ class Upsample : public FuseModulePass void write(Operator* op, const TorchGraphProxy& graph) const { const TorchNodeProxy* upsample_nearest1d = graph.find_node_by_kind("aten::upsample_nearest1d"); + const TorchNodeProxy* upsample_nearest_exact1d = graph.find_node_by_kind("aten::_upsample_nearest_exact1d"); const TorchNodeProxy* upsample_linear1d = graph.find_node_by_kind("aten::upsample_linear1d"); const TorchNodeProxy* upsample_nearest2d = graph.find_node_by_kind("aten::upsample_nearest2d"); + const TorchNodeProxy* upsample_nearest_exact2d = graph.find_node_by_kind("aten::_upsample_nearest_exact2d"); const TorchNodeProxy* upsample_bilinear2d = graph.find_node_by_kind("aten::upsample_bilinear2d"); const TorchNodeProxy* upsample_bicubic2d = graph.find_node_by_kind("aten::upsample_bicubic2d"); const TorchNodeProxy* upsample_nearest3d = graph.find_node_by_kind("aten::upsample_nearest3d"); + const TorchNodeProxy* upsample_nearest_exact3d = graph.find_node_by_kind("aten::_upsample_nearest_exact3d"); const TorchNodeProxy* upsample_trilinear3d = graph.find_node_by_kind("aten::upsample_trilinear3d"); const TorchNodeProxy* upsample = 0; @@ -36,6 +39,11 @@ class Upsample : public FuseModulePass upsample = upsample_nearest1d; op->params["mode"] = "nearest"; } + else if (upsample_nearest_exact1d) + { + upsample = upsample_nearest_exact1d; + op->params["mode"] = "nearest-exact"; + } else if (upsample_linear1d) { upsample = upsample_linear1d; @@ -46,6 +54,11 @@ class Upsample : public FuseModulePass upsample = upsample_nearest2d; op->params["mode"] = "nearest"; } + else if (upsample_nearest_exact2d) + { + upsample = upsample_nearest_exact2d; + op->params["mode"] = "nearest-exact"; + } else if (upsample_bilinear2d) { upsample = upsample_bilinear2d; @@ -61,6 +74,11 @@ class Upsample : public FuseModulePass upsample = upsample_nearest3d; op->params["mode"] = "nearest"; } + else if (upsample_nearest_exact3d) + { + upsample = upsample_nearest_exact3d; + op->params["mode"] = "nearest-exact"; + } else if (upsample_trilinear3d) { upsample = upsample_trilinear3d; diff --git a/tools/pnnx/src/pass_level2/F_interpolate.cpp b/tools/pnnx/src/pass_level2/F_interpolate.cpp index fb6ae304cc1e..069da73cfab9 100644 --- a/tools/pnnx/src/pass_level2/F_interpolate.cpp +++ b/tools/pnnx/src/pass_level2/F_interpolate.cpp @@ -650,6 +650,141 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_7, 110) +class F_interpolate_nearest_exact1d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 size value=None +prim::Constant op_1 0 1 scale_factor value=%scale_factor +aten::_upsample_nearest_exact1d op_2 3 1 input size scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.interpolate"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["scale_factor"] = captured_params.at("scale_factor"); + op->params["mode"] = "nearest-exact"; + op->params["recompute_scale_factor"] = false; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_nearest_exact1d, 110) + +class F_interpolate_nearest_exact1d_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 size value=%size +prim::Constant op_1 0 1 scale_factor value=None +aten::_upsample_nearest_exact1d op_2 3 1 input size scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.interpolate"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["size"] = captured_params.at("size"); + op->params["mode"] = "nearest-exact"; + op->params["recompute_scale_factor"] = false; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_nearest_exact1d_1, 110) + +class F_interpolate_nearest_exact2d : public F_interpolate_nearest_exact1d +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 size value=None +prim::Constant op_1 0 1 scale_factor value=%scale_factor +aten::_upsample_nearest_exact2d op_2 3 1 input size scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_nearest_exact2d, 110) + +class F_interpolate_nearest_exact2d_1 : public F_interpolate_nearest_exact1d_1 +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 size value=%size +prim::Constant op_1 0 1 scale_factor_h value=None +prim::Constant op_2 0 1 scale_factor_w value=None +aten::_upsample_nearest_exact2d op_3 4 1 input size scale_factor_h scale_factor_w out +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_nearest_exact2d_1, 110) + +class F_interpolate_nearest_exact3d : public F_interpolate_nearest_exact1d +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 size value=None +prim::Constant op_1 0 1 scale_factor value=%scale_factor +aten::_upsample_nearest_exact3d op_2 3 1 input size scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_nearest_exact3d, 110) + +class F_interpolate_nearest_exact3d_1 : public F_interpolate_nearest_exact1d_1 +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 size value=%size +prim::Constant op_1 0 1 scale_factor_d value=None +prim::Constant op_2 0 1 scale_factor_h value=None +prim::Constant op_3 0 1 scale_factor_w value=None +aten::_upsample_nearest_exact3d op_4 5 1 input size scale_factor_d scale_factor_h scale_factor_w out +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_nearest_exact3d_1, 110) + class F_interpolate_onnx : public GraphRewriterPass { public: @@ -1004,7 +1139,7 @@ static void linear_coeffs(int w, int outw, bool align_corner, std::vector& fx = (float)(dx * scale); } - int sx = (float)floor(fx); + int sx = (int)floor(fx); fx -= sx; ia[dx] = sx; @@ -1154,4 +1289,226 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_onnx_1d_linear, 110) +static bool resolve_nearest_exact_1d(int w, int outw, const int64_t* pindex) +{ + double scale = (double)w / outw; + for (int i = 0; i < outw; i++) + { + float fx = (float)((i + 0.5f) * scale); + int sx = (int)floor(fx); + + if (pindex[i] != sx) + return false; + } + + return true; +} + +class F_interpolate_onnx_1d_nearest_exact : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input 0 1 input +Tensor.permute op_0 1 1 input pnnx_4 dims=(2,0,1) +pnnx.Attribute op_1 0 1 index @data=(%size,1)i64 +GatherND op_2 2 1 pnnx_4 index pnnx_5 batch_dims=0 +Tensor.permute op_3 1 1 pnnx_5 out dims=(1,2,0) +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.interpolate"; + } + + bool match(const std::map& matched_operators, const std::map& captured_params, const std::map& captured_attrs) const + { + const int size = captured_params.at("size").i; + + auto index = captured_attrs.at("op_1.data"); + + const int64_t* pindex = (const int64_t*)index.data.data(); + + const int w = matched_operators.at("op_0")->inputs[0]->shape[2]; + + bool nearest_exact = resolve_nearest_exact_1d(w, size, pindex); + + return nearest_exact; + } + + void write(Operator* op, const std::map& captured_params) const + { + const int size = captured_params.at("size").i; + op->params["size"] = {size}; + op->params["mode"] = "nearest-exact"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_onnx_1d_nearest_exact, 111) + +static bool resolve_nearest_exact_2d(int w, int h, int outw, int outh, const int64_t* pindex) +{ + double scale_w = (double)w / outw; + double scale_h = (double)h / outh; + for (int i = 0; i < outh; i++) + { + float fy = (float)((i + 0.5f) * scale_h); + int sy = (int)floor(fy); + + for (int j = 0; j < outw; j++) + { + float fx = (float)((j + 0.5f) * scale_w); + int sx = (int)floor(fx); + + int py = pindex[0]; + int px = pindex[1]; + pindex += 2; + + if (px != sx || py != sy) + return false; + } + } + + return true; +} + +class F_interpolate_onnx_2d_nearest_exact : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input 0 1 input +Tensor.permute op_0 1 1 input pnnx_48 dims=(2,3,0,1) +pnnx.Attribute op_1 0 1 index @data=(%size_h,%size_w,2)i64 +GatherND op_2 2 1 pnnx_48 index pnnx_49 batch_dims=0 +Tensor.permute op_3 1 1 pnnx_49 out dims=(2,3,0,1) +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.interpolate"; + } + + bool match(const std::map& matched_operators, const std::map& captured_params, const std::map& captured_attrs) const + { + const int size_h = captured_params.at("size_h").i; + const int size_w = captured_params.at("size_w").i; + + auto index = captured_attrs.at("op_1.data"); + + const int64_t* pindex = (const int64_t*)index.data.data(); + + const int h = matched_operators.at("op_0")->inputs[0]->shape[2]; + const int w = matched_operators.at("op_0")->inputs[0]->shape[3]; + + bool nearest_exact = resolve_nearest_exact_2d(w, h, size_w, size_h, pindex); + + return nearest_exact; + } + + void write(Operator* op, const std::map& captured_params) const + { + const int size_h = captured_params.at("size_h").i; + const int size_w = captured_params.at("size_w").i; + op->params["size"] = {size_h, size_w}; + op->params["mode"] = "nearest-exact"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_onnx_2d_nearest_exact, 111) + +static bool resolve_nearest_exact_3d(int w, int h, int d, int outw, int outh, int outd, const int64_t* pindex) +{ + double scale_w = (double)w / outw; + double scale_h = (double)h / outh; + double scale_d = (double)d / outd; + for (int i = 0; i < outd; i++) + { + float fz = (float)((i + 0.5f) * scale_d); + int sz = (int)floor(fz); + + for (int j = 0; j < outh; j++) + { + float fy = (float)((j + 0.5f) * scale_h); + int sy = (int)floor(fy); + + for (int k = 0; k < outw; k++) + { + float fx = (float)((k + 0.5f) * scale_w); + int sx = (int)floor(fx); + + int pz = pindex[0]; + int py = pindex[1]; + int px = pindex[2]; + pindex += 3; + + if (px != sx || py != sy || pz != sz) + return false; + } + } + } + + return true; +} + +class F_interpolate_onnx_3d_nearest_exact : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input 0 1 input +Tensor.permute op_0 1 1 input pnnx_48 dims=(2,3,4,0,1) +pnnx.Attribute op_1 0 1 index @data=(%size_d,%size_h,%size_w,3)i64 +GatherND op_2 2 1 pnnx_48 index pnnx_49 batch_dims=0 +Tensor.permute op_3 1 1 pnnx_49 out dims=(3,4,0,1,2) +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.interpolate"; + } + + bool match(const std::map& matched_operators, const std::map& captured_params, const std::map& captured_attrs) const + { + const int size_d = captured_params.at("size_d").i; + const int size_h = captured_params.at("size_h").i; + const int size_w = captured_params.at("size_w").i; + + auto index = captured_attrs.at("op_1.data"); + + const int64_t* pindex = (const int64_t*)index.data.data(); + + const int d = matched_operators.at("op_0")->inputs[0]->shape[2]; + const int h = matched_operators.at("op_0")->inputs[0]->shape[3]; + const int w = matched_operators.at("op_0")->inputs[0]->shape[4]; + + bool nearest_exact = resolve_nearest_exact_3d(w, h, d, size_w, size_h, size_d, pindex); + + return nearest_exact; + } + + void write(Operator* op, const std::map& captured_params) const + { + const int size_d = captured_params.at("size_d").i; + const int size_h = captured_params.at("size_h").i; + const int size_w = captured_params.at("size_w").i; + op->params["size"] = {size_d, size_h, size_w}; + op->params["mode"] = "nearest-exact"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_onnx_3d_nearest_exact, 111) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_interpolate.cpp b/tools/pnnx/src/pass_ncnn/F_interpolate.cpp index 8b3426a3d41b..1568234f4f23 100644 --- a/tools/pnnx/src/pass_ncnn/F_interpolate.cpp +++ b/tools/pnnx/src/pass_ncnn/F_interpolate.cpp @@ -45,6 +45,11 @@ pnnx.Output output 1 0 out if (mode == "nearest") op->params["0"] = 1; + if (mode == "nearest-exact") + { + fprintf(stderr, "unsupported interpolate mode nearest-exact\n"); + op->params["0"] = 1; + } if (mode == "bilinear" || mode == "linear") op->params["0"] = 2; if (mode == "bicubic") @@ -109,6 +114,11 @@ pnnx.Output output 1 0 out if (mode == "nearest") op->params["0"] = 1; + if (mode == "nearest-exact") + { + fprintf(stderr, "unsupported interpolate mode nearest-exact\n"); + op->params["0"] = 1; + } if (mode == "bilinear" || mode == "linear") op->params["0"] = 2; if (mode == "bicubic") @@ -173,6 +183,11 @@ pnnx.Output output 1 0 out if (mode == "nearest") op->params["0"] = 1; + if (mode == "nearest-exact") + { + fprintf(stderr, "unsupported interpolate mode nearest-exact\n"); + op->params["0"] = 1; + } if (mode == "bilinear" || mode == "linear") op->params["0"] = 2; if (mode == "bicubic") @@ -237,6 +252,11 @@ pnnx.Output output 1 0 out if (mode == "nearest") op->params["0"] = 1; + if (mode == "nearest-exact") + { + fprintf(stderr, "unsupported interpolate mode nearest-exact\n"); + op->params["0"] = 1; + } if (mode == "bilinear" || mode == "linear") op->params["0"] = 2; if (mode == "bicubic") diff --git a/tools/pnnx/src/pass_ncnn/nn_Upsample.cpp b/tools/pnnx/src/pass_ncnn/nn_Upsample.cpp index 5f1b237304f3..9ff8e72c7490 100644 --- a/tools/pnnx/src/pass_ncnn/nn_Upsample.cpp +++ b/tools/pnnx/src/pass_ncnn/nn_Upsample.cpp @@ -38,6 +38,11 @@ pnnx.Output output 1 0 out if (mode == "nearest") op->params["0"] = 1; + if (mode == "nearest-exact") + { + fprintf(stderr, "unsupported interpolate mode nearest-exact\n"); + op->params["0"] = 1; + } if (mode == "bilinear" || mode == "linear") op->params["0"] = 2; if (mode == "bicubic") @@ -104,6 +109,11 @@ pnnx.Output output 1 0 out if (mode == "nearest") op->params["0"] = 1; + if (mode == "nearest-exact") + { + fprintf(stderr, "unsupported interpolate mode nearest-exact\n"); + op->params["0"] = 1; + } if (mode == "bilinear" || mode == "linear") op->params["0"] = 2; if (mode == "bicubic") @@ -161,6 +171,11 @@ pnnx.Output output 1 0 out if (mode == "nearest") op->params["0"] = 1; + if (mode == "nearest-exact") + { + fprintf(stderr, "unsupported interpolate mode nearest-exact\n"); + op->params["0"] = 1; + } if (mode == "bilinear" || mode == "linear") op->params["0"] = 2; if (mode == "bicubic") @@ -227,6 +242,11 @@ pnnx.Output output 1 0 out if (mode == "nearest") op->params["0"] = 1; + if (mode == "nearest-exact") + { + fprintf(stderr, "unsupported interpolate mode nearest-exact\n"); + op->params["0"] = 1; + } if (mode == "bilinear" || mode == "linear") op->params["0"] = 2; if (mode == "bicubic") diff --git a/tools/pnnx/src/utils.cpp b/tools/pnnx/src/utils.cpp index 4bf9ee77dfd0..6befd6dfdf18 100644 --- a/tools/pnnx/src/utils.cpp +++ b/tools/pnnx/src/utils.cpp @@ -126,21 +126,24 @@ std::string float_to_string(float f) return std::string(buffer); } - const int len = snprintf(buffer, sizeof(buffer), "%.8f", f); + const int len = snprintf(buffer, sizeof(buffer), "%g", f); - // remove tail zeros - char* end = buffer + len - 1; - while (end > buffer && *end == '0') + bool is_integer = true; + for (int i = 0; i < len; i++) { - *end = '\0'; - end--; + if (buffer[i] == '.' || buffer[i] == 'e' || buffer[i] == 'E') + { + is_integer = false; + break; + } } // maintain point-zero - if (*end == '.') + if (is_integer) { - *(end + 1) = '0'; - *(end + 2) = '\0'; + buffer[len] = '.'; + buffer[len + 1] = '0'; + buffer[len + 2] = '\0'; } return std::string(buffer); @@ -160,21 +163,24 @@ std::string double_to_string(double d) return std::string(buffer); } - const int len = snprintf(buffer, sizeof(buffer), "%.17f", d); + const int len = snprintf(buffer, sizeof(buffer), "%g", d); - // remove tail zeros - char* end = buffer + len - 1; - while (end > buffer && *end == '0') + bool is_integer = true; + for (int i = 0; i < len; i++) { - *end = '\0'; - end--; + if (buffer[i] == '.' || buffer[i] == 'e' || buffer[i] == 'E') + { + is_integer = false; + break; + } } // maintain point-zero - if (*end == '.') + if (is_integer) { - *(end + 1) = '0'; - *(end + 2) = '\0'; + buffer[len] = '.'; + buffer[len + 1] = '0'; + buffer[len + 2] = '\0'; } return std::string(buffer); diff --git a/tools/pnnx/tests/onnx/test_F_interpolate.py b/tools/pnnx/tests/onnx/test_F_interpolate.py index 4771840748e0..84d3f63f12da 100644 --- a/tools/pnnx/tests/onnx/test_F_interpolate.py +++ b/tools/pnnx/tests/onnx/test_F_interpolate.py @@ -39,12 +39,16 @@ def forward(self, x, y, z, w): z3 = F.interpolate(z3, scale_factor=2, mode='trilinear') w = F.interpolate(w, scale_factor=(2.976744,2.976744), mode='nearest', recompute_scale_factor=False) + return x0, x1, x2, y0, y1, y2, y3, z0, z1, z2, z3, w else: x = F.interpolate(x, size=16) x = F.interpolate(x, scale_factor=2, mode='nearest') x = F.interpolate(x, size=(20), mode='nearest') x = F.interpolate(x, scale_factor=(4), mode='nearest') + if version.parse(torch.__version__) >= version.parse('2.9'): + x = F.interpolate(x, size=12, mode='nearest-exact') + 2 + x = F.interpolate(x, scale_factor=(3), mode='nearest-exact') x = F.interpolate(x, size=16, mode='linear') x = F.interpolate(x, scale_factor=2, mode='linear') x = F.interpolate(x, size=(24), mode='linear', align_corners=True) @@ -60,6 +64,9 @@ def forward(self, x, y, z, w): y = F.interpolate(y, scale_factor=(4,4), mode='nearest') y = F.interpolate(y, size=(16,24), mode='nearest') y = F.interpolate(y, scale_factor=(2,3), mode='nearest') + if version.parse(torch.__version__) >= version.parse('2.9'): + y = F.interpolate(y, size=(11,12), mode='nearest-exact') + 3 + y = F.interpolate(y, scale_factor=(3,2), mode='nearest-exact') y = F.interpolate(y, size=16, mode='bilinear') y = F.interpolate(y, scale_factor=2, mode='bilinear') y = F.interpolate(y, size=(20,20), mode='bilinear', align_corners=False) @@ -86,6 +93,9 @@ def forward(self, x, y, z, w): z = F.interpolate(z, scale_factor=(4,4,4), mode='nearest') z = F.interpolate(z, size=(16,24,20), mode='nearest') z = F.interpolate(z, scale_factor=(2,3,4), mode='nearest') + if version.parse(torch.__version__) >= version.parse('2.9'): + z = F.interpolate(z, size=(11,12,13), mode='nearest-exact') + 4 + z = F.interpolate(z, scale_factor=(3,1,2), mode='nearest-exact') z = F.interpolate(z, size=16, mode='trilinear') z = F.interpolate(z, scale_factor=2, mode='trilinear') z = F.interpolate(z, size=(20,20,20), mode='trilinear', align_corners=False) diff --git a/tools/pnnx/tests/onnx/test_nn_Upsample.py b/tools/pnnx/tests/onnx/test_nn_Upsample.py index cd1fa0c3101c..cc4f9a802d74 100644 --- a/tools/pnnx/tests/onnx/test_nn_Upsample.py +++ b/tools/pnnx/tests/onnx/test_nn_Upsample.py @@ -81,6 +81,14 @@ def __init__(self): self.up_w = nn.Upsample(scale_factor=(1.499,1.499), mode='nearest') + if version.parse(torch.__version__) >= version.parse('2.9'): + self.up_1d_0_4 = nn.Upsample(size=12, mode='nearest-exact') + self.up_1d_0_5 = nn.Upsample(scale_factor=(3), mode='nearest-exact') + self.up_2d_0_6 = nn.Upsample(size=(11,12), mode='nearest-exact') + self.up_2d_0_7 = nn.Upsample(scale_factor=(3,2), mode='nearest-exact') + self.up_3d_0_6 = nn.Upsample(size=(11,12,13), mode='nearest-exact') + self.up_3d_0_7 = nn.Upsample(scale_factor=(3,1,2), mode='nearest-exact') + def forward(self, x, y, z, w): if version.parse(torch.__version__) < version.parse('1.12'): x0 = self.up_1d_0_0(x) @@ -116,6 +124,9 @@ def forward(self, x, y, z, w): x = self.up_1d_0_1(x) x = self.up_1d_0_2(x) x = self.up_1d_0_3(x) + if version.parse(torch.__version__) >= version.parse('2.9'): + x = self.up_1d_0_4(x) + 2 + x = self.up_1d_0_5(x) x = self.up_1d_1_0(x) x = self.up_1d_1_1(x) x = self.up_1d_1_2(x) @@ -127,6 +138,9 @@ def forward(self, x, y, z, w): y = self.up_2d_0_3(y) y = self.up_2d_0_4(y) y = self.up_2d_0_5(y) + if version.parse(torch.__version__) >= version.parse('2.9'): + y = self.up_2d_0_6(y) + 3 + y = self.up_2d_0_7(y) y = self.up_2d_1_0(y) y = self.up_2d_1_1(y) y = self.up_2d_1_2(y) @@ -146,6 +160,9 @@ def forward(self, x, y, z, w): z = self.up_3d_0_3(z) z = self.up_3d_0_4(z) z = self.up_3d_0_5(z) + if version.parse(torch.__version__) >= version.parse('2.9'): + z = self.up_3d_0_6(z) + 4 + z = self.up_3d_0_7(z) z = self.up_3d_1_0(z) z = self.up_3d_1_1(z) z = self.up_3d_1_2(z) diff --git a/tools/pnnx/tests/test_F_interpolate.py b/tools/pnnx/tests/test_F_interpolate.py index eddecf7404bb..7c9ada92d8b1 100644 --- a/tools/pnnx/tests/test_F_interpolate.py +++ b/tools/pnnx/tests/test_F_interpolate.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from packaging import version class Model(nn.Module): def __init__(self): @@ -14,6 +15,9 @@ def forward(self, x, y, z, w): x = F.interpolate(x, scale_factor=2, mode='nearest') x = F.interpolate(x, size=(20), mode='nearest') x = F.interpolate(x, scale_factor=(4), mode='nearest') + if version.parse(torch.__version__) >= version.parse('1.11'): + x = F.interpolate(x, size=12, mode='nearest-exact') + x = F.interpolate(x, scale_factor=(3), mode='nearest-exact') x = F.interpolate(x, size=16, mode='linear') x = F.interpolate(x, scale_factor=2, mode='linear') x = F.interpolate(x, size=(24), mode='linear', align_corners=True) @@ -29,6 +33,9 @@ def forward(self, x, y, z, w): y = F.interpolate(y, scale_factor=(4,4), mode='nearest') y = F.interpolate(y, size=(16,24), mode='nearest') y = F.interpolate(y, scale_factor=(2,3), mode='nearest') + if version.parse(torch.__version__) >= version.parse('1.11'): + y = F.interpolate(y, size=(11,12), mode='nearest-exact') + y = F.interpolate(y, scale_factor=(3,2), mode='nearest-exact') y = F.interpolate(y, size=16, mode='bilinear') y = F.interpolate(y, scale_factor=2, mode='bilinear') y = F.interpolate(y, size=(20,20), mode='bilinear', align_corners=False) @@ -54,6 +61,9 @@ def forward(self, x, y, z, w): z = F.interpolate(z, scale_factor=(4,4,4), mode='nearest') z = F.interpolate(z, size=(16,24,20), mode='nearest') z = F.interpolate(z, scale_factor=(2,3,4), mode='nearest') + if version.parse(torch.__version__) >= version.parse('1.11'): + z = F.interpolate(z, size=(11,12,13), mode='nearest-exact') + z = F.interpolate(z, scale_factor=(3,1,2), mode='nearest-exact') z = F.interpolate(z, size=16, mode='trilinear') z = F.interpolate(z, scale_factor=2, mode='trilinear') z = F.interpolate(z, size=(20,20,20), mode='trilinear', align_corners=False) diff --git a/tools/pnnx/tests/test_nn_Upsample.py b/tools/pnnx/tests/test_nn_Upsample.py index c6055b99be24..bb539935b8b9 100644 --- a/tools/pnnx/tests/test_nn_Upsample.py +++ b/tools/pnnx/tests/test_nn_Upsample.py @@ -14,6 +14,9 @@ def __init__(self): self.up_1d_0_1 = nn.Upsample(scale_factor=2, mode='nearest') self.up_1d_0_2 = nn.Upsample(size=(20), mode='nearest') self.up_1d_0_3 = nn.Upsample(scale_factor=(4), mode='nearest') + if version.parse(torch.__version__) >= version.parse('1.11'): + self.up_1d_0_4 = nn.Upsample(size=12, mode='nearest-exact') + self.up_1d_0_5 = nn.Upsample(scale_factor=(3), mode='nearest-exact') self.up_1d_1_0 = nn.Upsample(size=16, mode='linear') self.up_1d_1_1 = nn.Upsample(scale_factor=2, mode='linear') self.up_1d_1_2 = nn.Upsample(size=(24), mode='linear', align_corners=True) @@ -31,6 +34,9 @@ def __init__(self): self.up_2d_0_3 = nn.Upsample(scale_factor=(4,4), mode='nearest') self.up_2d_0_4 = nn.Upsample(size=(16,24), mode='nearest') self.up_2d_0_5 = nn.Upsample(scale_factor=(2,3), mode='nearest') + if version.parse(torch.__version__) >= version.parse('1.11'): + self.up_2d_0_6 = nn.Upsample(size=(11,12), mode='nearest-exact') + self.up_2d_0_7 = nn.Upsample(scale_factor=(3,2), mode='nearest-exact') self.up_2d_1_0 = nn.Upsample(size=16, mode='bilinear') self.up_2d_1_1 = nn.Upsample(scale_factor=2, mode='bilinear') self.up_2d_1_2 = nn.Upsample(size=(20,20), mode='bilinear', align_corners=False) @@ -59,6 +65,9 @@ def __init__(self): self.up_3d_0_3 = nn.Upsample(scale_factor=(4,4,4), mode='nearest') self.up_3d_0_4 = nn.Upsample(size=(16,24,20), mode='nearest') self.up_3d_0_5 = nn.Upsample(scale_factor=(2,3,4), mode='nearest') + if version.parse(torch.__version__) >= version.parse('1.11'): + self.up_3d_0_6 = nn.Upsample(size=(11,12,13), mode='nearest-exact') + self.up_3d_0_7 = nn.Upsample(scale_factor=(3,1,2), mode='nearest-exact') self.up_3d_1_0 = nn.Upsample(size=16, mode='trilinear') self.up_3d_1_1 = nn.Upsample(scale_factor=2, mode='trilinear') self.up_3d_1_2 = nn.Upsample(size=(20,20,20), mode='trilinear', align_corners=False) @@ -76,6 +85,9 @@ def forward(self, x, y, z, w): x = self.up_1d_0_1(x) x = self.up_1d_0_2(x) x = self.up_1d_0_3(x) + if version.parse(torch.__version__) >= version.parse('1.11'): + x = self.up_1d_0_4(x) + x = self.up_1d_0_5(x) x = self.up_1d_1_0(x) x = self.up_1d_1_1(x) x = self.up_1d_1_2(x) @@ -87,6 +99,9 @@ def forward(self, x, y, z, w): y = self.up_2d_0_3(y) y = self.up_2d_0_4(y) y = self.up_2d_0_5(y) + if version.parse(torch.__version__) >= version.parse('1.11'): + y = self.up_2d_0_6(y) + y = self.up_2d_0_7(y) y = self.up_2d_1_0(y) y = self.up_2d_1_1(y) y = self.up_2d_1_2(y) @@ -106,6 +121,9 @@ def forward(self, x, y, z, w): z = self.up_3d_0_3(z) z = self.up_3d_0_4(z) z = self.up_3d_0_5(z) + if version.parse(torch.__version__) >= version.parse('1.11'): + z = self.up_3d_0_6(z) + z = self.up_3d_0_7(z) z = self.up_3d_1_0(z) z = self.up_3d_1_1(z) z = self.up_3d_1_2(z)