@@ -310,32 +310,69 @@ def test_name_validation(self):
310310 )
311311
312312 def test_standardize_shape_with_none (self ):
313- """Tests standardizing shape with None."""
314313 with self .assertRaisesRegex (
315314 ValueError , "Undefined shapes are not supported."
316315 ):
317316 standardize_shape (None )
318317
319318 def test_standardize_shape_with_non_iterable (self ):
320- """Tests shape standardization with non-iterables."""
321319 with self .assertRaisesRegex (
322320 ValueError , "Cannot convert '42' to a shape."
323321 ):
324322 standardize_shape (42 )
325323
326324 def test_standardize_shape_with_valid_input (self ):
327- """Tests standardizing shape with valid input."""
325+ shape = (3 , 4 , 5 )
326+ standardized_shape = standardize_shape (shape )
327+ self .assertEqual (standardized_shape , (3 , 4 , 5 ))
328+
329+ def test_standardize_shape_with_valid_inputWith_none (self ):
330+ shape = (3 , None , 5 )
331+ standardized_shape = standardize_shape (shape )
332+ self .assertEqual (standardized_shape , (3 , None , 5 ))
333+
334+ def test_standardize_shape_with_valid_not_tuple_input (self ):
328335 shape = [3 , 4 , 5 ]
329336 standardized_shape = standardize_shape (shape )
330337 self .assertEqual (standardized_shape , (3 , 4 , 5 ))
331338
332- def test_standardize_shape_with_negative_entry (self ):
333- """Tests standardizing shape with negative entries."""
339+ def test_standardize_shape_with_numpy (self ):
340+ shape = [3 , np .int32 (4 ), np .int64 (5 )]
341+ standardized_shape = standardize_shape (shape )
342+ self .assertEqual (standardized_shape , (3 , 4 , 5 ))
343+ for d in standardized_shape :
344+ self .assertIsInstance (d , int )
345+
346+ def test_standardize_shape_with_string (self ):
347+ shape_with_string = (3 , 4 , "5" )
348+ with self .assertRaisesRegex (
349+ ValueError ,
350+ "Cannot convert .* to a shape. Found invalid dimension '5'." ,
351+ ):
352+ standardize_shape (shape_with_string )
353+
354+ def test_standardize_shape_with_float (self ):
355+ shape_with_float = (3 , 4 , 5.0 )
356+ with self .assertRaisesRegex (
357+ ValueError ,
358+ "Cannot convert .* to a shape. Found invalid dimension '5.0'." ,
359+ ):
360+ standardize_shape (shape_with_float )
361+
362+ def test_standardize_shape_with_object (self ):
363+ shape_with_float = (3 , 4 , object ())
334364 with self .assertRaisesRegex (
335365 ValueError ,
336- "Cannot convert ' \\ (3, 4, -5 \\ )' to a shape. Negative dimensions " ,
366+ "Cannot convert .* to a shape. Found invalid dimension .*object " ,
337367 ):
338- standardize_shape ([3 , 4 , - 5 ])
368+ standardize_shape (shape_with_float )
369+
370+ def test_standardize_shape_with_negative_dimension (self ):
371+ with self .assertRaisesRegex (
372+ ValueError ,
373+ "Cannot convert .* to a shape. Negative dimensions" ,
374+ ):
375+ standardize_shape ((3 , 4 , - 5 ))
339376
340377 def test_shape_equal_length_mismatch (self ):
341378 """Test mismatch in lengths of shapes."""
@@ -1138,138 +1175,29 @@ def test_xor(self, dtypes):
11381175 reason = "Tests for standardize_shape with Torch backend" ,
11391176)
11401177class TestStandardizeShapeWithTorch (test_case .TestCase ):
1141- """Tests for standardize_shape with Torch backend."""
1142-
1143- def test_standardize_shape_with_torch_size_containing_negative_value (self ):
1144- """Tests shape with a negative value."""
1145- shape_with_negative_value = (3 , 4 , - 5 )
1146- with self .assertRaisesRegex (
1147- ValueError ,
1148- "Cannot convert '\\ (3, 4, -5\\ )' to a shape. Negative dimensions" ,
1149- ):
1150- _ = standardize_shape (shape_with_negative_value )
1151-
1152- def test_standardize_shape_with_torch_size_valid (self ):
1153- """Tests a valid shape."""
1154- shape_valid = (3 , 4 , 5 )
1155- standardized_shape = standardize_shape (shape_valid )
1156- self .assertEqual (standardized_shape , (3 , 4 , 5 ))
1157-
1158- def test_standardize_shape_with_torch_size_multidimensional (self ):
1159- """Tests shape of a multi-dimensional tensor."""
1178+ def test_standardize_shape_with_torch_size (self ):
11601179 import torch
11611180
11621181 tensor = torch .randn (3 , 4 , 5 )
11631182 shape = tensor .size ()
11641183 standardized_shape = standardize_shape (shape )
11651184 self .assertEqual (standardized_shape , (3 , 4 , 5 ))
1166-
1167- def test_standardize_shape_with_torch_size_single_dimension (self ):
1168- """Tests shape of a single-dimensional tensor."""
1169- import torch
1170-
1171- tensor = torch .randn (10 )
1172- shape = tensor .size ()
1173- standardized_shape = standardize_shape (shape )
1174- self .assertEqual (standardized_shape , (10 ,))
1175-
1176- def test_standardize_shape_with_torch_size_with_valid_1_dimension (self ):
1177- """Tests a valid shape."""
1178- shape_valid = [3 ]
1179- standardized_shape = standardize_shape (shape_valid )
1180- self .assertEqual (standardized_shape , (3 ,))
1181-
1182- def test_standardize_shape_with_torch_size_with_valid_2_dimension (self ):
1183- """Tests a valid shape."""
1184- shape_valid = [3 , 4 ]
1185- standardized_shape = standardize_shape (shape_valid )
1186- self .assertEqual (standardized_shape , (3 , 4 ))
1187-
1188- def test_standardize_shape_with_torch_size_with_valid_3_dimension (self ):
1189- """Tests a valid shape."""
1190- shape_valid = [3 , 4 , 5 ]
1191- standardized_shape = standardize_shape (shape_valid )
1192- self .assertEqual (standardized_shape , (3 , 4 , 5 ))
1193-
1194- def test_standardize_shape_with_torch_size_with_negative_value (self ):
1195- """Tests shape with a negative value appended."""
1196- import torch
1197-
1198- tensor = torch .randn (3 , 4 , 5 )
1199- shape = tuple (tensor .size ())
1200- shape_with_negative = shape + (- 1 ,)
1201- with self .assertRaisesRegex (
1202- ValueError ,
1203- "Cannot convert .* to a shape. Negative dimensions are not" ,
1204- ):
1205- _ = standardize_shape (shape_with_negative )
1206-
1207- def test_standardize_shape_with_non_integer_entry (self ):
1208- """Tests shape with a non-integer value."""
1209- with self .assertRaisesRegex (
1210- # different error message for torch
1211- ValueError ,
1212- r"invalid literal for int\(\) with base 10: 'a'" ,
1213- ):
1214- standardize_shape ([3 , 4 , "a" ])
1215-
1216- def test_standardize_shape_with_negative_entry (self ):
1217- """Tests shape with a negative value."""
1218- with self .assertRaisesRegex (
1219- ValueError ,
1220- "Cannot convert '\\ (3, 4, -5\\ )' to a shape. Negative dimensions" ,
1221- ):
1222- standardize_shape ([3 , 4 , - 5 ])
1223-
1224- def test_standardize_shape_with_valid_not_tuple (self ):
1225- """Tests a valid shape."""
1226- shape_valid = [3 , 4 , 5 ]
1227- standardized_shape = standardize_shape (shape_valid )
1228- self .assertEqual (standardized_shape , (3 , 4 , 5 ))
1185+ self .assertIs (type (standardized_shape ), tuple )
1186+ for d in standardized_shape :
1187+ self .assertIsInstance (d , int )
12291188
12301189
12311190@pytest .mark .skipif (
1232- backend .backend () == "torch " ,
1233- reason = "Tests for standardize_shape with others backend" ,
1191+ backend .backend () != "tensorflow " ,
1192+ reason = "Tests for standardize_shape with TensorFlow backend" ,
12341193)
1235- class TestStandardizeShapeWithOutTorch (test_case .TestCase ):
1236- """Tests for standardize_shape with others backend."""
1194+ class TestStandardizeShapeWithTensorflow (test_case .TestCase ):
1195+ def test_standardize_shape_with_tensor_size (self ):
1196+ import tensorflow as tf
12371197
1238- def test_standardize_shape_with_out_torch_negative_value (self ):
1239- """Tests shape with a negative value."""
1240- shape_with_negative_value = (3 , 4 , - 5 )
1241- with self .assertRaisesRegex (
1242- ValueError ,
1243- "Cannot convert '\\ (3, 4, -5\\ )' to a shape. Negative dimensions" ,
1244- ):
1245- _ = standardize_shape (shape_with_negative_value )
1246-
1247- def test_standardize_shape_with_out_torch_string (self ):
1248- """Tests shape with a string value."""
1249- shape_with_string = (3 , 4 , "5" )
1250- with self .assertRaisesRegex (
1251- ValueError ,
1252- "Cannot convert .* to a shape. Found invalid entry '5'." ,
1253- ):
1254- _ = standardize_shape (shape_with_string )
1255-
1256- def test_standardize_shape_with_out_torch_float (self ):
1257- """Tests shape with a float value."""
1258- shape_with_float = (3 , 4 , 5.0 )
1259- with self .assertRaisesRegex (
1260- ValueError ,
1261- "Cannot convert .* to a shape. Found invalid entry '5.0'." ,
1262- ):
1263- _ = standardize_shape (shape_with_float )
1264-
1265- def test_standardize_shape_with_out_torch_valid (self ):
1266- """Tests a valid shape."""
1267- shape_valid = (3 , 4 , 5 )
1268- standardized_shape = standardize_shape (shape_valid )
1269- self .assertEqual (standardized_shape , (3 , 4 , 5 ))
1270-
1271- def test_standardize_shape_with_out_torch_valid_not_tuple (self ):
1272- """Tests a valid shape."""
1273- shape_valid = [3 , 4 , 5 ]
1274- standardized_shape = standardize_shape (shape_valid )
1198+ shape = (3 , tf .constant (4 , dtype = tf .int64 ), 5 )
1199+ standardized_shape = standardize_shape (shape )
12751200 self .assertEqual (standardized_shape , (3 , 4 , 5 ))
1201+ self .assertIs (type (standardized_shape ), tuple )
1202+ for d in standardized_shape :
1203+ self .assertIsInstance (d , int )
0 commit comments