@@ -40,24 +40,15 @@ def forward(self, x, y, z, w):
4040
4141 w = F .interpolate (w , scale_factor = (2.976744 ,2.976744 ), mode = 'nearest' , recompute_scale_factor = False )
4242
43- if version .parse (torch .__version__ ) >= version .parse ('1.11' ):
44- x = F .interpolate (x , size = 12 , mode = 'nearest-exact' ) + 2
45- x = F .interpolate (x , scale_factor = (3 ), mode = 'nearest-exact' )
46- y = F .interpolate (y , size = (11 ,12 ), mode = 'nearest-exact' ) + 3
47- y = F .interpolate (y , scale_factor = (3 ,2 ), mode = 'nearest-exact' )
48- z = F .interpolate (z , size = (11 ,12 ,13 ), mode = 'nearest-exact' ) + 4
49- z = F .interpolate (z , scale_factor = (3 ,1 ,2 ), mode = 'nearest-exact' )
50-
51- return x0 , x1 , x2 , y0 , y1 , y2 , y3 , z0 , z1 , z2 , z3 , w , x , y , z
52- else :
53- return x0 , x1 , x2 , y0 , y1 , y2 , y3 , z0 , z1 , z2 , z3 , w
43+ return x0 , x1 , x2 , y0 , y1 , y2 , y3 , z0 , z1 , z2 , z3 , w
5444 else :
5545 x = F .interpolate (x , size = 16 )
5646 x = F .interpolate (x , scale_factor = 2 , mode = 'nearest' )
5747 x = F .interpolate (x , size = (20 ), mode = 'nearest' )
5848 x = F .interpolate (x , scale_factor = (4 ), mode = 'nearest' )
59- x = F .interpolate (x , size = 12 , mode = 'nearest-exact' ) + 2
60- x = F .interpolate (x , scale_factor = (3 ), mode = 'nearest-exact' )
49+ if version .parse (torch .__version__ ) >= version .parse ('2.9' ):
50+ x = F .interpolate (x , size = 12 , mode = 'nearest-exact' ) + 2
51+ x = F .interpolate (x , scale_factor = (3 ), mode = 'nearest-exact' )
6152 x = F .interpolate (x , size = 16 , mode = 'linear' )
6253 x = F .interpolate (x , scale_factor = 2 , mode = 'linear' )
6354 x = F .interpolate (x , size = (24 ), mode = 'linear' , align_corners = True )
@@ -73,8 +64,9 @@ def forward(self, x, y, z, w):
7364 y = F .interpolate (y , scale_factor = (4 ,4 ), mode = 'nearest' )
7465 y = F .interpolate (y , size = (16 ,24 ), mode = 'nearest' )
7566 y = F .interpolate (y , scale_factor = (2 ,3 ), mode = 'nearest' )
76- y = F .interpolate (y , size = (11 ,12 ), mode = 'nearest-exact' ) + 3
77- y = F .interpolate (y , scale_factor = (3 ,2 ), mode = 'nearest-exact' )
67+ if version .parse (torch .__version__ ) >= version .parse ('2.9' ):
68+ y = F .interpolate (y , size = (11 ,12 ), mode = 'nearest-exact' ) + 3
69+ y = F .interpolate (y , scale_factor = (3 ,2 ), mode = 'nearest-exact' )
7870 y = F .interpolate (y , size = 16 , mode = 'bilinear' )
7971 y = F .interpolate (y , scale_factor = 2 , mode = 'bilinear' )
8072 y = F .interpolate (y , size = (20 ,20 ), mode = 'bilinear' , align_corners = False )
@@ -101,8 +93,9 @@ def forward(self, x, y, z, w):
10193 z = F .interpolate (z , scale_factor = (4 ,4 ,4 ), mode = 'nearest' )
10294 z = F .interpolate (z , size = (16 ,24 ,20 ), mode = 'nearest' )
10395 z = F .interpolate (z , scale_factor = (2 ,3 ,4 ), mode = 'nearest' )
104- z = F .interpolate (z , size = (11 ,12 ,13 ), mode = 'nearest-exact' ) + 4
105- z = F .interpolate (z , scale_factor = (3 ,1 ,2 ), mode = 'nearest-exact' )
96+ if version .parse (torch .__version__ ) >= version .parse ('2.9' ):
97+ z = F .interpolate (z , size = (11 ,12 ,13 ), mode = 'nearest-exact' ) + 4
98+ z = F .interpolate (z , scale_factor = (3 ,1 ,2 ), mode = 'nearest-exact' )
10699 z = F .interpolate (z , size = 16 , mode = 'trilinear' )
107100 z = F .interpolate (z , scale_factor = 2 , mode = 'trilinear' )
108101 z = F .interpolate (z , size = (20 ,20 ,20 ), mode = 'trilinear' , align_corners = False )
0 commit comments