Skip to content

Commit 9664432

Browse files
committed
standardize_shape normalizes the dimensions and tuple.
- `standardize_shape` would allow int-like objects (e.g. NumPy scalars) to be in the shape, but would not resolve them to actual ints. This could cause issues later. With this change, objects that can be cast to int will be cast to int, which includes NumPy scalars, but also TensorFlow constant or eager tensors. - `standardize_shape` would have custom code to handle `torch.Size`. Generalized it to turn anything iterable to a plain tuple. - Added units tests - Removed duplicate unit tests - Added TensorFlow specific unit test
1 parent 9bcdbc7 commit 9664432

File tree

2 files changed

+81
-142
lines changed

2 files changed

+81
-142
lines changed

keras/src/backend/common/variables.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,6 @@ def standardize_shape(shape):
599599
# `tf.TensorShape` may contain `Dimension` objects.
600600
# We need to convert the items in it to either int or `None`
601601
shape = shape.as_list()
602-
shape = tuple(shape)
603602

604603
if config.backend() == "jax":
605604
# Replace `_DimExpr` (dimension expression) with None
@@ -609,25 +608,37 @@ def standardize_shape(shape):
609608
None if jax_export.is_symbolic_dim(d) else d for d in shape
610609
)
611610

612-
if config.backend() == "torch":
613-
# `shape` might be `torch.Size`. We need to convert the items in it to
614-
# either int or `None`
615-
shape = tuple(map(lambda x: int(x) if x is not None else None, shape))
616-
617-
for e in shape:
618-
if e is None:
611+
# Handle dimensions that are not ints and not None, verify they're >= 0.
612+
standardized_shape = []
613+
for d in shape:
614+
if d is None:
615+
standardized_shape.append(d)
619616
continue
620-
if not is_int_dtype(type(e)):
617+
618+
# Reject these even if they can be cast to int successfully.
619+
if isinstance(d, (str, float)):
621620
raise ValueError(
622621
f"Cannot convert '{shape}' to a shape. "
623-
f"Found invalid entry '{e}' of type '{type(e)}'. "
622+
f"Found invalid dimension '{d}' of type '{type(d)}'. "
624623
)
625-
if e < 0:
624+
625+
try:
626+
# Cast numpy scalars, tf constant tensors, etc.
627+
d = int(d)
628+
except Exception as e:
629+
raise ValueError(
630+
f"Cannot convert '{shape}' to a shape. "
631+
f"Found invalid dimension '{d}' of type '{type(d)}'. "
632+
) from e
633+
if d < 0:
626634
raise ValueError(
627635
f"Cannot convert '{shape}' to a shape. "
628636
"Negative dimensions are not allowed."
629637
)
630-
return shape
638+
standardized_shape.append(d)
639+
640+
# This also turns subclasses of `tuple` (e.g. `torch.Size`) to plain tuple.
641+
return tuple(standardized_shape)
631642

632643

633644
def shape_equal(a_shape, b_shape):

keras/src/backend/common/variables_test.py

Lines changed: 58 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -310,32 +310,69 @@ def test_name_validation(self):
310310
)
311311

312312
def test_standardize_shape_with_none(self):
313-
"""Tests standardizing shape with None."""
314313
with self.assertRaisesRegex(
315314
ValueError, "Undefined shapes are not supported."
316315
):
317316
standardize_shape(None)
318317

319318
def test_standardize_shape_with_non_iterable(self):
320-
"""Tests shape standardization with non-iterables."""
321319
with self.assertRaisesRegex(
322320
ValueError, "Cannot convert '42' to a shape."
323321
):
324322
standardize_shape(42)
325323

326324
def test_standardize_shape_with_valid_input(self):
327-
"""Tests standardizing shape with valid input."""
325+
shape = (3, 4, 5)
326+
standardized_shape = standardize_shape(shape)
327+
self.assertEqual(standardized_shape, (3, 4, 5))
328+
329+
def test_standardize_shape_with_valid_inputWith_none(self):
330+
shape = (3, None, 5)
331+
standardized_shape = standardize_shape(shape)
332+
self.assertEqual(standardized_shape, (3, None, 5))
333+
334+
def test_standardize_shape_with_valid_not_tuple_input(self):
328335
shape = [3, 4, 5]
329336
standardized_shape = standardize_shape(shape)
330337
self.assertEqual(standardized_shape, (3, 4, 5))
331338

332-
def test_standardize_shape_with_negative_entry(self):
333-
"""Tests standardizing shape with negative entries."""
339+
def test_standardize_shape_with_numpy(self):
340+
shape = [3, np.int32(4), np.int64(5)]
341+
standardized_shape = standardize_shape(shape)
342+
self.assertEqual(standardized_shape, (3, 4, 5))
343+
for d in standardized_shape:
344+
self.assertIsInstance(d, int)
345+
346+
def test_standardize_shape_with_string(self):
347+
shape_with_string = (3, 4, "5")
348+
with self.assertRaisesRegex(
349+
ValueError,
350+
"Cannot convert .* to a shape. Found invalid dimension '5'.",
351+
):
352+
standardize_shape(shape_with_string)
353+
354+
def test_standardize_shape_with_float(self):
355+
shape_with_float = (3, 4, 5.0)
356+
with self.assertRaisesRegex(
357+
ValueError,
358+
"Cannot convert .* to a shape. Found invalid dimension '5.0'.",
359+
):
360+
standardize_shape(shape_with_float)
361+
362+
def test_standardize_shape_with_object(self):
363+
shape_with_float = (3, 4, object())
334364
with self.assertRaisesRegex(
335365
ValueError,
336-
"Cannot convert '\\(3, 4, -5\\)' to a shape. Negative dimensions",
366+
"Cannot convert .* to a shape. Found invalid dimension .*object",
337367
):
338-
standardize_shape([3, 4, -5])
368+
standardize_shape(shape_with_float)
369+
370+
def test_standardize_shape_with_negative_dimension(self):
371+
with self.assertRaisesRegex(
372+
ValueError,
373+
"Cannot convert .* to a shape. Negative dimensions",
374+
):
375+
standardize_shape((3, 4, -5))
339376

340377
def test_shape_equal_length_mismatch(self):
341378
"""Test mismatch in lengths of shapes."""
@@ -1138,138 +1175,29 @@ def test_xor(self, dtypes):
11381175
reason="Tests for standardize_shape with Torch backend",
11391176
)
11401177
class TestStandardizeShapeWithTorch(test_case.TestCase):
1141-
"""Tests for standardize_shape with Torch backend."""
1142-
1143-
def test_standardize_shape_with_torch_size_containing_negative_value(self):
1144-
"""Tests shape with a negative value."""
1145-
shape_with_negative_value = (3, 4, -5)
1146-
with self.assertRaisesRegex(
1147-
ValueError,
1148-
"Cannot convert '\\(3, 4, -5\\)' to a shape. Negative dimensions",
1149-
):
1150-
_ = standardize_shape(shape_with_negative_value)
1151-
1152-
def test_standardize_shape_with_torch_size_valid(self):
1153-
"""Tests a valid shape."""
1154-
shape_valid = (3, 4, 5)
1155-
standardized_shape = standardize_shape(shape_valid)
1156-
self.assertEqual(standardized_shape, (3, 4, 5))
1157-
1158-
def test_standardize_shape_with_torch_size_multidimensional(self):
1159-
"""Tests shape of a multi-dimensional tensor."""
1178+
def test_standardize_shape_with_torch_size(self):
11601179
import torch
11611180

11621181
tensor = torch.randn(3, 4, 5)
11631182
shape = tensor.size()
11641183
standardized_shape = standardize_shape(shape)
11651184
self.assertEqual(standardized_shape, (3, 4, 5))
1166-
1167-
def test_standardize_shape_with_torch_size_single_dimension(self):
1168-
"""Tests shape of a single-dimensional tensor."""
1169-
import torch
1170-
1171-
tensor = torch.randn(10)
1172-
shape = tensor.size()
1173-
standardized_shape = standardize_shape(shape)
1174-
self.assertEqual(standardized_shape, (10,))
1175-
1176-
def test_standardize_shape_with_torch_size_with_valid_1_dimension(self):
1177-
"""Tests a valid shape."""
1178-
shape_valid = [3]
1179-
standardized_shape = standardize_shape(shape_valid)
1180-
self.assertEqual(standardized_shape, (3,))
1181-
1182-
def test_standardize_shape_with_torch_size_with_valid_2_dimension(self):
1183-
"""Tests a valid shape."""
1184-
shape_valid = [3, 4]
1185-
standardized_shape = standardize_shape(shape_valid)
1186-
self.assertEqual(standardized_shape, (3, 4))
1187-
1188-
def test_standardize_shape_with_torch_size_with_valid_3_dimension(self):
1189-
"""Tests a valid shape."""
1190-
shape_valid = [3, 4, 5]
1191-
standardized_shape = standardize_shape(shape_valid)
1192-
self.assertEqual(standardized_shape, (3, 4, 5))
1193-
1194-
def test_standardize_shape_with_torch_size_with_negative_value(self):
1195-
"""Tests shape with a negative value appended."""
1196-
import torch
1197-
1198-
tensor = torch.randn(3, 4, 5)
1199-
shape = tuple(tensor.size())
1200-
shape_with_negative = shape + (-1,)
1201-
with self.assertRaisesRegex(
1202-
ValueError,
1203-
"Cannot convert .* to a shape. Negative dimensions are not",
1204-
):
1205-
_ = standardize_shape(shape_with_negative)
1206-
1207-
def test_standardize_shape_with_non_integer_entry(self):
1208-
"""Tests shape with a non-integer value."""
1209-
with self.assertRaisesRegex(
1210-
# different error message for torch
1211-
ValueError,
1212-
r"invalid literal for int\(\) with base 10: 'a'",
1213-
):
1214-
standardize_shape([3, 4, "a"])
1215-
1216-
def test_standardize_shape_with_negative_entry(self):
1217-
"""Tests shape with a negative value."""
1218-
with self.assertRaisesRegex(
1219-
ValueError,
1220-
"Cannot convert '\\(3, 4, -5\\)' to a shape. Negative dimensions",
1221-
):
1222-
standardize_shape([3, 4, -5])
1223-
1224-
def test_standardize_shape_with_valid_not_tuple(self):
1225-
"""Tests a valid shape."""
1226-
shape_valid = [3, 4, 5]
1227-
standardized_shape = standardize_shape(shape_valid)
1228-
self.assertEqual(standardized_shape, (3, 4, 5))
1185+
self.assertIs(type(standardized_shape), tuple)
1186+
for d in standardized_shape:
1187+
self.assertIsInstance(d, int)
12291188

12301189

12311190
@pytest.mark.skipif(
1232-
backend.backend() == "torch",
1233-
reason="Tests for standardize_shape with others backend",
1191+
backend.backend() != "tensorflow",
1192+
reason="Tests for standardize_shape with TensorFlow backend",
12341193
)
1235-
class TestStandardizeShapeWithOutTorch(test_case.TestCase):
1236-
"""Tests for standardize_shape with others backend."""
1194+
class TestStandardizeShapeWithTensorflow(test_case.TestCase):
1195+
def test_standardize_shape_with_tensor_size(self):
1196+
import tensorflow as tf
12371197

1238-
def test_standardize_shape_with_out_torch_negative_value(self):
1239-
"""Tests shape with a negative value."""
1240-
shape_with_negative_value = (3, 4, -5)
1241-
with self.assertRaisesRegex(
1242-
ValueError,
1243-
"Cannot convert '\\(3, 4, -5\\)' to a shape. Negative dimensions",
1244-
):
1245-
_ = standardize_shape(shape_with_negative_value)
1246-
1247-
def test_standardize_shape_with_out_torch_string(self):
1248-
"""Tests shape with a string value."""
1249-
shape_with_string = (3, 4, "5")
1250-
with self.assertRaisesRegex(
1251-
ValueError,
1252-
"Cannot convert .* to a shape. Found invalid entry '5'.",
1253-
):
1254-
_ = standardize_shape(shape_with_string)
1255-
1256-
def test_standardize_shape_with_out_torch_float(self):
1257-
"""Tests shape with a float value."""
1258-
shape_with_float = (3, 4, 5.0)
1259-
with self.assertRaisesRegex(
1260-
ValueError,
1261-
"Cannot convert .* to a shape. Found invalid entry '5.0'.",
1262-
):
1263-
_ = standardize_shape(shape_with_float)
1264-
1265-
def test_standardize_shape_with_out_torch_valid(self):
1266-
"""Tests a valid shape."""
1267-
shape_valid = (3, 4, 5)
1268-
standardized_shape = standardize_shape(shape_valid)
1269-
self.assertEqual(standardized_shape, (3, 4, 5))
1270-
1271-
def test_standardize_shape_with_out_torch_valid_not_tuple(self):
1272-
"""Tests a valid shape."""
1273-
shape_valid = [3, 4, 5]
1274-
standardized_shape = standardize_shape(shape_valid)
1198+
shape = (3, tf.constant(4, dtype=tf.int64), 5)
1199+
standardized_shape = standardize_shape(shape)
12751200
self.assertEqual(standardized_shape, (3, 4, 5))
1201+
self.assertIs(type(standardized_shape), tuple)
1202+
for d in standardized_shape:
1203+
self.assertIsInstance(d, int)

0 commit comments

Comments
 (0)