2323from aesara .raise_op import Assert
2424from aesara .scalar import int32 as int_t
2525from aesara .scalar import upcast
26+ from aesara .scalar .basic import Composite
2627from aesara .tensor import basic as at
2728from aesara .tensor import get_vector_length
2829from aesara .tensor .exceptions import NotScalarConstantError
@@ -1552,16 +1553,29 @@ def broadcast_shape_iter(
15521553 # be broadcastable or equal to the one non-broadcastable
15531554 # constant `const_nt_shape_var`.
15541555 assert_dim = Assert ("Could not broadcast dimensions" )
1556+
1557+ scalar_nonconst_nb_shapes = [
1558+ at .scalar_from_tensor (s ) if isinstance (s , TensorVariable ) else s
1559+ for s in nonconst_nb_shapes
1560+ ]
1561+
1562+ dummy_nonconst_nb_shapes = [
1563+ v .type () for v in scalar_nonconst_nb_shapes
1564+ ]
15551565 assert_cond = reduce (
15561566 aes .and_ ,
15571567 (
15581568 aes .or_ (
15591569 aes .eq (nbv , one_at ), aes .eq (nbv , const_nt_shape_var )
15601570 )
1561- for nbv in nonconst_nb_shapes
1571+ for nbv in dummy_nonconst_nb_shapes
15621572 ),
15631573 )
1564- bcast_dim = assert_dim (const_nt_shape_var , assert_cond )
1574+ assert_cond_op = Composite (dummy_nonconst_nb_shapes , [assert_cond ])
1575+
1576+ bcast_dim = assert_dim (
1577+ const_nt_shape_var , assert_cond_op (* scalar_nonconst_nb_shapes )
1578+ )
15651579 else :
15661580 bcast_dim = const_nt_shape_var
15671581 else :
@@ -1579,21 +1593,36 @@ def broadcast_shape_iter(
15791593 result_dims .append (maybe_non_bcast_shapes [0 ])
15801594 continue
15811595
1596+ scalar_maybe_non_bcast_shapes = [
1597+ at .scalar_from_tensor (s ) if isinstance (s , TensorVariable ) else s
1598+ for s in maybe_non_bcast_shapes
1599+ ]
1600+ dummy_maybe_non_bcast_shapes = [
1601+ v .type () for v in scalar_maybe_non_bcast_shapes
1602+ ]
15821603 non_bcast_vec = [
15831604 aes .switch (aes .eq (nbv , 1 ), - one_at , nbv )
1584- for nbv in maybe_non_bcast_shapes
1605+ for nbv in dummy_maybe_non_bcast_shapes
15851606 ]
15861607 dim_max = aes .abs (reduce (aes .scalar_maximum , non_bcast_vec ))
1608+ dim_max_op = Composite (dummy_maybe_non_bcast_shapes , [dim_max ])
1609+
1610+ dummy_dim_max = dim_max_op (* dummy_maybe_non_bcast_shapes )
15871611
15881612 assert_dim = Assert ("Could not broadcast dimensions" )
15891613 assert_cond = reduce (
15901614 aes .and_ ,
15911615 (
1592- aes .or_ (aes .eq (nbv , - one_at ), aes .eq (nbv , dim_max ))
1616+ aes .or_ (aes .eq (nbv , - one_at ), aes .eq (nbv , dummy_dim_max ))
15931617 for nbv in non_bcast_vec
15941618 ),
15951619 )
1596- bcast_dim = assert_dim (dim_max , assert_cond )
1620+ assert_cond_op = Composite (dummy_maybe_non_bcast_shapes , [assert_cond ])
1621+
1622+ bcast_dim = assert_dim (
1623+ dim_max_op (* scalar_maybe_non_bcast_shapes ),
1624+ assert_cond_op (* scalar_maybe_non_bcast_shapes ),
1625+ )
15971626
15981627 result_dims .append (bcast_dim )
15991628
0 commit comments