Skip to content

Commit 8903f44

Browse files
authored
pnnx fuse more transformer attention variants (#6371)
* pnnx fuse more transformer attention variants
* test new transformers
* always eliminate contiguous
* unified view and reshape
* better float param string
1 parent de84c3e commit 8903f44

File tree

82 files changed

+1864
-773
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

82 files changed

+1864
-773
lines changed

.github/workflows/pnnx.yml

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -234,22 +234,22 @@ jobs:
234234
fail-fast: false
235235
matrix:
236236
include:
237-
- { python: '3.8', numpy: '1.24.4', opencv: '4.5.*', torch: '1.8.1', torchvision: '0.9.1', torchaudio: '0.8.1' }
238-
- { python: '3.8', numpy: '1.24.4', opencv: '4.5.*', torch: '1.9.1', torchvision: '0.10.1', torchaudio: '0.9.1' }
239-
- { python: '3.8', numpy: '1.24.4', opencv: '4.6.*', torch: '1.10.0', torchvision: '0.11.1', torchaudio: '0.10.0+cpu' }
240-
- { python: '3.9', numpy: '1.26.4', opencv: '4.6.*', torch: '1.11.0', torchvision: '0.12.0', torchaudio: '0.11.0+cpu' }
241-
- { python: '3.9', numpy: '1.26.4', opencv: '4.7.*', torch: '1.12.0', torchvision: '0.13.0', torchaudio: '0.12.0+cpu' }
242-
- { python: '3.10', numpy: '1.26.4', opencv: '4.7.*', torch: '1.13.0', torchvision: '0.14.0', torchaudio: '0.13.0+cpu' }
243-
- { python: '3.10', numpy: '1.26.4', opencv: '4.8.*', torch: '2.0.0', torchvision: '0.15.1', torchaudio: '2.0.0+cpu' }
244-
- { python: '3.10', numpy: '1.26.4', opencv: '4.8.*', torch: '2.1.0', torchvision: '0.16.0', torchaudio: '2.1.0+cpu' }
245-
- { python: '3.11', numpy: '1.26.4', opencv: '4.9.*', torch: '2.2.1', torchvision: '0.17.1', torchaudio: '2.2.1+cpu' }
246-
- { python: '3.11', numpy: '1.26.4', opencv: '4.9.*', torch: '2.3.0', torchvision: '0.18.0', torchaudio: '2.3.0+cpu' }
247-
- { python: '3.11', numpy: '2.2.5', opencv: '4.10.*', torch: '2.4.0', torchvision: '0.19.0', torchaudio: '2.4.0+cpu' }
248-
- { python: '3.12', numpy: '2.2.5', opencv: '4.10.*', torch: '2.5.0', torchvision: '0.20.0', torchaudio: '2.5.0+cpu' }
249-
- { python: '3.12', numpy: '2.2.5', opencv: '4.11.*', torch: '2.6.0', torchvision: '0.21.0', torchaudio: '2.6.0+cpu' }
250-
- { python: '3.12', numpy: '2.2.5', opencv: '4.11.*', torch: '2.7.0', torchvision: '0.22.0', torchaudio: '2.7.0+cpu' }
251-
- { python: '3.13', numpy: '2.2.5', opencv: '4.12.*', torch: '2.8.0', torchvision: '0.23.0', torchaudio: '2.8.0+cpu' }
252-
- { python: '3.13', numpy: '2.2.5', opencv: '4.12.*', torch: '2.9.0', torchvision: '0.24.0', torchaudio: '2.9.0+cpu' }
237+
- { python: '3.8', numpy: '1.24.4', opencv: '4.5.*', torch: '1.8.1', torchvision: '0.9.1', torchaudio: '0.8.1', transformers: '4.52.1' }
238+
- { python: '3.8', numpy: '1.24.4', opencv: '4.5.*', torch: '1.9.1', torchvision: '0.10.1', torchaudio: '0.9.1', transformers: '4.52.1' }
239+
- { python: '3.8', numpy: '1.24.4', opencv: '4.6.*', torch: '1.10.0', torchvision: '0.11.1', torchaudio: '0.10.0+cpu', transformers: '4.52.1' }
240+
- { python: '3.9', numpy: '1.26.4', opencv: '4.6.*', torch: '1.11.0', torchvision: '0.12.0', torchaudio: '0.11.0+cpu', transformers: '4.52.1' }
241+
- { python: '3.9', numpy: '1.26.4', opencv: '4.7.*', torch: '1.12.0', torchvision: '0.13.0', torchaudio: '0.12.0+cpu', transformers: '4.52.1' }
242+
- { python: '3.10', numpy: '1.26.4', opencv: '4.7.*', torch: '1.13.0', torchvision: '0.14.0', torchaudio: '0.13.0+cpu', transformers: '4.52.1' }
243+
- { python: '3.10', numpy: '1.26.4', opencv: '4.8.*', torch: '2.0.0', torchvision: '0.15.1', torchaudio: '2.0.0+cpu', transformers: '4.52.1' }
244+
- { python: '3.10', numpy: '1.26.4', opencv: '4.8.*', torch: '2.1.0', torchvision: '0.16.0', torchaudio: '2.1.0+cpu', transformers: '4.52.1' }
245+
- { python: '3.11', numpy: '1.26.4', opencv: '4.9.*', torch: '2.2.1', torchvision: '0.17.1', torchaudio: '2.2.1+cpu', transformers: '4.52.1' }
246+
- { python: '3.11', numpy: '1.26.4', opencv: '4.9.*', torch: '2.3.0', torchvision: '0.18.0', torchaudio: '2.3.0+cpu', transformers: '4.52.1' }
247+
- { python: '3.11', numpy: '2.2.5', opencv: '4.10.*', torch: '2.4.0', torchvision: '0.19.0', torchaudio: '2.4.0+cpu', transformers: '4.52.1' }
248+
- { python: '3.12', numpy: '2.2.5', opencv: '4.10.*', torch: '2.5.0', torchvision: '0.20.0', torchaudio: '2.5.0+cpu', transformers: '4.52.1' }
249+
- { python: '3.12', numpy: '2.2.5', opencv: '4.11.*', torch: '2.6.0', torchvision: '0.21.0', torchaudio: '2.6.0+cpu', transformers: '4.52.1' }
250+
- { python: '3.12', numpy: '2.2.5', opencv: '4.11.*', torch: '2.7.0', torchvision: '0.22.0', torchaudio: '2.7.0+cpu', transformers: '4.52.1' }
251+
- { python: '3.13', numpy: '2.2.5', opencv: '4.12.*', torch: '2.8.0', torchvision: '0.23.0', torchaudio: '2.8.0+cpu', transformers: '4.56.2' }
252+
- { python: '3.13', numpy: '2.2.5', opencv: '4.12.*', torch: '2.9.0', torchvision: '0.24.0', torchaudio: '2.9.0+cpu', transformers: '4.56.2' }
253253

254254
name: test-${{ matrix.torch }}-py${{ matrix.python }}
255255

@@ -319,7 +319,7 @@ jobs:
319319
pip3 install --user pytest wheel twine requests einops numpy==${{ matrix.numpy }} opencv-python==${{ matrix.opencv }}
320320
pip3 install --user torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu torchaudio==${{ matrix.torchaudio }} --index-url https://download.pytorch.org/whl/cpu
321321
pip3 install --user onnx onnxscript onnxruntime
322-
pip3 install --user "transformers<=4.52.1" diffusers "safetensors<=0.6.0"
322+
pip3 install --user "transformers<=${{ matrix.transformers }}" diffusers "safetensors<=0.6.2"
323323
324324
- name: setup-pytorch-execstack-or-patchelf
325325
if: ${{ matrix.python }} == '3.8' || ${{ matrix.python }} == '3.9'

tools/pnnx/src/CMakeLists.txt

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ set(pnnx_pass_level1_SRCS
110110
)
111111

112112
set(pnnx_pass_level2_SRCS
113+
pass_level2/eliminate_contiguous.cpp
113114
pass_level2/eliminate_size_numtotensor_int.cpp
114115
pass_level2/functionize.cpp
115116
pass_level2/fuse_constantlist.cpp
@@ -187,7 +188,6 @@ set(pnnx_pass_level2_SRCS
187188
pass_level2/F_upsample_bilinear.cpp
188189
pass_level2/F_upsample_nearest.cpp
189190
pass_level2/F_upsample.cpp
190-
pass_level2/Tensor_contiguous.cpp
191191
pass_level2/Tensor_copy.cpp
192192
pass_level2/Tensor_expand.cpp
193193
pass_level2/Tensor_expand_as.cpp
@@ -207,7 +207,6 @@ set(pnnx_pass_level2_SRCS
207207
pass_level2/Tensor_slice.cpp
208208
pass_level2/Tensor_to.cpp
209209
pass_level2/Tensor_type_as.cpp
210-
pass_level2/Tensor_view.cpp
211210
pass_level2/torch_addmm.cpp
212211
pass_level2/torch_amax.cpp
213212
pass_level2/torch_amin.cpp
@@ -362,7 +361,7 @@ set(pnnx_pass_level5_SRCS
362361
pass_level5/eliminate_noop_pad.cpp
363362
pass_level5/eliminate_noop_upsample.cpp
364363
pass_level5/eliminate_noop_slice.cpp
365-
pass_level5/eliminate_noop_view_reshape.cpp
364+
pass_level5/eliminate_noop_reshape.cpp
366365
pass_level5/eliminate_reshape_shape_expression.cpp
367366
pass_level5/eliminate_type_as.cpp
368367
pass_level5/eval_expression.cpp
@@ -376,7 +375,6 @@ set(pnnx_pass_level5_SRCS
376375
pass_level5/fuse_convtranspose1d_batchnorm1d.cpp
377376
pass_level5/fuse_convtranspose2d_batchnorm2d.cpp
378377
pass_level5/fuse_convtranspose3d_batchnorm3d.cpp
379-
pass_level5/fuse_contiguous_view.cpp
380378
pass_level5/fuse_linear_batchnorm1d.cpp
381379
pass_level5/fuse_pad_conv1d.cpp
382380
pass_level5/fuse_pad_conv2d.cpp
@@ -583,12 +581,10 @@ set(pnnx_pass_ncnn_SRCS
583581
pass_ncnn/nn_UpsamplingBilinear2d.cpp
584582
pass_ncnn/nn_UpsamplingNearest2d.cpp
585583
pass_ncnn/nn_ZeroPad2d.cpp
586-
pass_ncnn/Tensor_contiguous.cpp
587584
pass_ncnn/Tensor_permute.cpp
588585
pass_ncnn/Tensor_reshape.cpp
589586
pass_ncnn/Tensor_reshape_as.cpp
590587
pass_ncnn/Tensor_repeat.cpp
591-
pass_ncnn/Tensor_view.cpp
592588
pass_ncnn/torch_addmm.cpp
593589
pass_ncnn/torch_amax.cpp
594590
pass_ncnn/torch_amin.cpp

tools/pnnx/src/ir.cpp

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1529,7 +1529,8 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath, con
15291529
}
15301530
if (param.type == 3)
15311531
{
1532-
fprintf(pyfp, "%f", param.f);
1532+
std::string fs = float_to_string(param.f);
1533+
fprintf(pyfp, "%s", fs.c_str());
15331534
}
15341535
if (param.type == 4)
15351536
{
@@ -1569,7 +1570,8 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath, con
15691570
fprintf(pyfp, "(");
15701571
for (size_t i = 0; i < param.af.size(); i++)
15711572
{
1572-
fprintf(pyfp, "%f", param.af[i]);
1573+
std::string afs = float_to_string(param.af[i]);
1574+
fprintf(pyfp, "%s", afs.c_str());
15731575
if (i + 1 != param.af.size() || param.af.size() == 1)
15741576
fprintf(pyfp, ",");
15751577
}
@@ -1640,7 +1642,8 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath, con
16401642
}
16411643

16421644
fprintf(pyfp, " self.%s.set_weight_bias(self_%s_weight, self_%s_bias)\n", sanitize_identifier(op->name).c_str(), sanitize_identifier(op->name).c_str(), sanitize_identifier(op->name).c_str());
1643-
fprintf(pyfp, " self.%s.scale = %f\n", sanitize_identifier(op->name).c_str(), op->params.at("scale").f);
1645+
std::string scale_str = float_to_string(op->params.at("scale").f);
1646+
fprintf(pyfp, " self.%s.scale = %s\n", sanitize_identifier(op->name).c_str(), scale_str.c_str());
16441647
fprintf(pyfp, " self.%s.zero_point = %d\n", sanitize_identifier(op->name).c_str(), op->params.at("zero_point").i);
16451648

16461649
continue;
@@ -1857,9 +1860,9 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath, con
18571860
}
18581861
fprintf(pyfp, ")\n");
18591862
}
1860-
else if (op->type == "Tensor.view" || op->type == "Tensor.reshape")
1863+
else if (op->type == "Tensor.reshape")
18611864
{
1862-
// view reshape
1865+
// reshape
18631866
fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str());
18641867
if (op->inputs.size() == 2)
18651868
{
@@ -1887,7 +1890,7 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath, con
18871890
}
18881891
else if (op->type == "Tensor.repeat")
18891892
{
1890-
// view reshape
1893+
// repeat
18911894
fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str());
18921895
if (op->inputs.size() == 2)
18931896
{
@@ -2252,7 +2255,8 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath, con
22522255
{
22532256
for (size_t i = 0; i < op->inputs.size(); i++)
22542257
{
2255-
if (!op->inputnames[i].empty())
2258+
bool is_input = i == 0 && op->inputnames[0] == "input";
2259+
if (!op->inputnames[i].empty() && !is_input)
22562260
continue;
22572261

22582262
fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str());
@@ -2262,7 +2266,8 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath, con
22622266

22632267
for (size_t i = 0; i < op->inputs.size(); i++)
22642268
{
2265-
if (op->inputnames[i].empty())
2269+
bool is_input = i == 0 && op->inputnames[0] == "input";
2270+
if (op->inputnames[i].empty() || is_input)
22662271
continue;
22672272

22682273
fprintf(pyfp, "%s=v_%s", op->inputnames[i].c_str(), sanitize_identifier(op->inputs[i]->name).c_str());
@@ -2416,10 +2421,8 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath, con
24162421
fprintf(pyfp, "(");
24172422
for (size_t i = 0; i < param.af.size(); i++)
24182423
{
2419-
if (param.af[i] == (int)param.af[i])
2420-
fprintf(pyfp, "%.1f", param.af[i]);
2421-
else
2422-
fprintf(pyfp, "%g", param.af[i]);
2424+
std::string afs = float_to_string(param.af[i]);
2425+
fprintf(pyfp, "%s", afs.c_str());
24232426
if (i + 1 != param.af.size() || param.af.size() == 1)
24242427
fprintf(pyfp, ",");
24252428
}

tools/pnnx/src/pass_level2.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <unordered_map>
1010
#include <unordered_set>
1111

12+
#include "pass_level2/eliminate_contiguous.h"
1213
#include "pass_level2/eliminate_size_numtotensor_int.h"
1314
#include "pass_level2/functionize.h"
1415
#include "pass_level2/fuse_constantlist.h"
@@ -1137,6 +1138,8 @@ void pass_level2(Graph& g)
11371138
{
11381139
functionize(g);
11391140

1141+
eliminate_contiguous(g);
1142+
11401143
eliminate_size_numtotensor_int(g);
11411144

11421145
fuse_constantlist(g);

tools/pnnx/src/pass_level2/F_hardshrink.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,4 +122,34 @@ pnnx.Output output 1 0 out
122122

123123
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardshrink_onnx, 100)
124124

125+
class F_hardshrink_onnx_1 : public GraphRewriterPass
126+
{
127+
public:
128+
const char* match_pattern_graph() const
129+
{
130+
return R"PNNXIR(7767517
131+
7 6
132+
pnnx.Input input 0 1 input
133+
aten::abs op_0 1 1 input pnnx_8
134+
prim::Constant op_1 0 1 val_2 value=%lambd
135+
torch.le op_2 2 1 pnnx_8 val_2 pnnx_9
136+
prim::Constant op_3 0 1 scalar_tensor_default_8 value=0.0
137+
torch.where op_4 3 1 pnnx_9 scalar_tensor_default_8 input out
138+
pnnx.Output output 1 0 out
139+
)PNNXIR";
140+
}
141+
142+
const char* type_str() const
143+
{
144+
return "F.hardshrink";
145+
}
146+
147+
void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
148+
{
149+
op->params["lambd"] = captured_params.at("lambd");
150+
}
151+
};
152+
153+
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardshrink_onnx_1, 100)
154+
125155
} // namespace pnnx

tools/pnnx/src/pass_level2/F_hardsigmoid.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,4 +337,49 @@ pnnx.Output output 1 0 out
337337

338338
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardsigmoid_onnx_1, 101)
339339

340+
class F_hardsigmoid_onnx_2 : public GraphRewriterPass
341+
{
342+
public:
343+
const char* match_pattern_graph() const
344+
{
345+
return R"PNNXIR(7767517
346+
7 6
347+
pnnx.Input input 0 1 input
348+
prim::Constant op_0 0 1 scalar_tensor_default_8 value=3.0
349+
aten::add op_1 2 1 input scalar_tensor_default_8 pnnx_8
350+
torch.clamp op_2 1 1 pnnx_8 pnnx_9 max=6.0 min=0.0
351+
prim::Constant op_3 0 1 max_val_cast value=6.0
352+
aten::div op_4 2 1 pnnx_9 max_val_cast out
353+
pnnx.Output output 1 0 out
354+
)PNNXIR";
355+
}
356+
357+
const char* type_str() const
358+
{
359+
return "F.hardsigmoid";
360+
}
361+
};
362+
363+
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardsigmoid_onnx_2, 100)
364+
365+
class F_hardsigmoid_onnx_3 : public F_hardsigmoid_2_2
366+
{
367+
public:
368+
const char* match_pattern_graph() const
369+
{
370+
return R"PNNXIR(7767517
371+
7 6
372+
pnnx.Input input 0 1 input
373+
prim::Constant op_0 0 1 scalar_tensor_default_8_pnnxshadow3 value=3.0
374+
aten::add op_1 2 1 input scalar_tensor_default_8_pnnxshadow3 pnnx_16
375+
torch.clamp op_2 1 1 pnnx_16 pnnx_17 max=6.0 min=0.0
376+
prim::Constant op_3 0 1 val_12 value=%v1p6
377+
aten::mul op_4 2 1 pnnx_17 val_12 out
378+
pnnx.Output output 1 0 out
379+
)PNNXIR";
380+
}
381+
};
382+
383+
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardsigmoid_onnx_3, 100)
384+
340385
} // namespace pnnx

tools/pnnx/src/pass_level2/F_linear.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,41 @@ pnnx.Output output 1 0 out
268268

269269
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_linear_onnx_4, 110)
270270

271+
class F_linear_onnx_5 : public GraphRewriterPass
272+
{
273+
public:
274+
const char* match_pattern_graph() const
275+
{
276+
return R"PNNXIR(7767517
277+
5 4
278+
pnnx.Input input 0 1 input
279+
pnnx.Attribute weight 0 1 weight @data=(%out_features,%in_features)f32
280+
torch.transpose op_0 1 1 weight t dim0=1 dim1=0
281+
torch.mm op_1 2 1 input t out
282+
pnnx.Output output 1 0 out
283+
)PNNXIR";
284+
}
285+
286+
const char* replace_pattern_graph() const
287+
{
288+
return R"PNNXIR(7767517
289+
4 3
290+
pnnx.Input input 0 1 input
291+
pnnx.Attribute weight 0 1 weight
292+
F.linear linear 2 1 input weight out bias=None $weight=weight
293+
pnnx.Output output 1 0 out
294+
)PNNXIR";
295+
}
296+
297+
void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
298+
{
299+
Operator* op_weight = ops.at("weight");
300+
op_weight->attrs["data"] = captured_attrs.at("weight.data");
301+
}
302+
};
303+
304+
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_linear_onnx_5, 110)
305+
271306
class F_linear_tnn : public GraphRewriterPass
272307
{
273308
public:

0 commit comments

Comments
 (0)