Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 9b7021c

Browse files
Use Composite graphs in aesara.tensor.extra_ops.broadcast_shape_iter
1 parent 63ca73d commit 9b7021c

File tree

1 file changed

+34
-5
lines changed

1 file changed

+34
-5
lines changed

aesara/tensor/extra_ops.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from aesara.raise_op import Assert
2424
from aesara.scalar import int32 as int_t
2525
from aesara.scalar import upcast
26+
from aesara.scalar.basic import Composite
2627
from aesara.tensor import basic as at
2728
from aesara.tensor import get_vector_length
2829
from aesara.tensor.exceptions import NotScalarConstantError
@@ -1552,16 +1553,29 @@ def broadcast_shape_iter(
15521553
# be broadcastable or equal to the one non-broadcastable
15531554
# constant `const_nt_shape_var`.
15541555
assert_dim = Assert("Could not broadcast dimensions")
1556+
1557+
scalar_nonconst_nb_shapes = [
1558+
at.scalar_from_tensor(s) if isinstance(s, TensorVariable) else s
1559+
for s in nonconst_nb_shapes
1560+
]
1561+
1562+
dummy_nonconst_nb_shapes = [
1563+
v.type() for v in scalar_nonconst_nb_shapes
1564+
]
15551565
assert_cond = reduce(
15561566
aes.and_,
15571567
(
15581568
aes.or_(
15591569
aes.eq(nbv, one_at), aes.eq(nbv, const_nt_shape_var)
15601570
)
1561-
for nbv in nonconst_nb_shapes
1571+
for nbv in dummy_nonconst_nb_shapes
15621572
),
15631573
)
1564-
bcast_dim = assert_dim(const_nt_shape_var, assert_cond)
1574+
assert_cond_op = Composite(dummy_nonconst_nb_shapes, [assert_cond])
1575+
1576+
bcast_dim = assert_dim(
1577+
const_nt_shape_var, assert_cond_op(*scalar_nonconst_nb_shapes)
1578+
)
15651579
else:
15661580
bcast_dim = const_nt_shape_var
15671581
else:
@@ -1579,21 +1593,36 @@ def broadcast_shape_iter(
15791593
result_dims.append(maybe_non_bcast_shapes[0])
15801594
continue
15811595

1596+
scalar_maybe_non_bcast_shapes = [
1597+
at.scalar_from_tensor(s) if isinstance(s, TensorVariable) else s
1598+
for s in maybe_non_bcast_shapes
1599+
]
1600+
dummy_maybe_non_bcast_shapes = [
1601+
v.type() for v in scalar_maybe_non_bcast_shapes
1602+
]
15821603
non_bcast_vec = [
15831604
aes.switch(aes.eq(nbv, 1), -one_at, nbv)
1584-
for nbv in maybe_non_bcast_shapes
1605+
for nbv in dummy_maybe_non_bcast_shapes
15851606
]
15861607
dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))
1608+
dim_max_op = Composite(dummy_maybe_non_bcast_shapes, [dim_max])
1609+
1610+
dummy_dim_max = dim_max_op(*dummy_maybe_non_bcast_shapes)
15871611

15881612
assert_dim = Assert("Could not broadcast dimensions")
15891613
assert_cond = reduce(
15901614
aes.and_,
15911615
(
1592-
aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dim_max))
1616+
aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dummy_dim_max))
15931617
for nbv in non_bcast_vec
15941618
),
15951619
)
1596-
bcast_dim = assert_dim(dim_max, assert_cond)
1620+
assert_cond_op = Composite(dummy_maybe_non_bcast_shapes, [assert_cond])
1621+
1622+
bcast_dim = assert_dim(
1623+
dim_max_op(*scalar_maybe_non_bcast_shapes),
1624+
assert_cond_op(*scalar_maybe_non_bcast_shapes),
1625+
)
15971626

15981627
result_dims.append(bcast_dim)
15991628

0 commit comments

Comments
 (0)