
Commit 03be691

Prepare switch from class to metaclass for Type

1 parent 94c6a00
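
In short: Type and its subclasses stop being instantiated directly. A new metaclass, NewTypeMeta, makes the bare constructor raise, and every call site is rewritten from SomeType(...) to SomeType.subtype(...). A minimal self-contained sketch of the pattern introduced in aesara/graph/type.py below (DemoType is a hypothetical stand-in for Type, used only for illustration):

    class NewTypeMeta(type):
        # Calling the class directly is forbidden during the migration.
        def __call__(cls, *args, **kwargs):
            raise RuntimeError("Use subtype")

        # subtype() is the one sanctioned constructor; it defers to the
        # normal instance-creation machinery (type.__call__).
        def subtype(cls, *args, **kwargs):
            return super().__call__(*args, **kwargs)


    class DemoType(metaclass=NewTypeMeta):  # hypothetical example class
        def __init__(self, dtype):
            self.dtype = dtype


    t = DemoType.subtype("float64")  # ok: t.dtype == "float64"
    DemoType("float64")              # RuntimeError: Use subtype

This is why every SomeType(...) construction site in the diff below becomes SomeType.subtype(...).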

104 files changed: +1008 −857 lines changed

Note: large commits have some content hidden by default; only a subset of the 104 changed files is shown below.

aesara/breakpoint.py

Lines changed: 1 addition & 1 deletion
@@ -143,7 +143,7 @@ def perform(self, node, inputs, output_storage):
             output_storage[i][0] = inputs[i + 1]

     def grad(self, inputs, output_gradients):
-        return [DisconnectedType()()] + output_gradients
+        return [DisconnectedType.subtype()()] + output_gradients

     def infer_shape(self, fgraph, inputs, input_shapes):
         # Return the shape of every input but the condition (first input)
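
Note the double call in DisconnectedType.subtype()(): the first call builds the Type instance, and the second call goes through Type.__call__, which returns a new Variable of that type (see the __call__ docstring in the aesara/graph/type.py diff below). Schematically:

    disconnected_type = DisconnectedType.subtype()  # a Type instance
    disconnected_var = disconnected_type()          # a Variable of that Type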

aesara/gradient.py

Lines changed: 4 additions & 4 deletions
@@ -90,7 +90,7 @@ def grad_not_implemented(op, x_pos, x, comment=""):
     """

     return (
-        NullType(
+        NullType.subtype(
             (
                 "This variable is Null because the grad method for "
                 f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
@@ -113,7 +113,7 @@ def grad_undefined(op, x_pos, x, comment=""):
     """

     return (
-        NullType(
+        NullType.subtype(
             (
                 "This variable is Null because the grad method for "
                 f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
@@ -158,7 +158,7 @@ def __str__(self):
         return "DisconnectedType"


-disconnected_type = DisconnectedType()
+disconnected_type = DisconnectedType.subtype()


 def Rop(
@@ -1803,7 +1803,7 @@ def verify_grad(
     )

     tensor_pt = [
-        aesara.tensor.type.TensorType(
+        aesara.tensor.type.TensorType.subtype(
             aesara.tensor.as_tensor_variable(p).dtype,
             aesara.tensor.as_tensor_variable(p).broadcastable,
         )(name=f"input {i}")

aesara/graph/null_type.py

Lines changed: 1 addition & 1 deletion
@@ -42,4 +42,4 @@ def __str__(self):
         return "NullType"


-null_type = NullType()
+null_type = NullType.subtype()

aesara/graph/type.py

Lines changed: 54 additions & 3 deletions
@@ -5,13 +5,23 @@

 from aesara.graph import utils
 from aesara.graph.basic import Constant, Variable
-from aesara.graph.utils import MetaObject
+from aesara.graph.utils import MetaType


 D = TypeVar("D")


-class Type(MetaObject, Generic[D]):
+class NewTypeMeta(type):
+    # pass
+    def __call__(cls, *args, **kwargs):
+        raise RuntimeError("Use subtype")
+        # return super().__call__(*args, **kwargs)
+
+    def subtype(cls, *args, **kwargs):
+        return super().__call__(*args, **kwargs)
+
+
+class Type(Generic[D], metaclass=NewTypeMeta):
     """
     Interface specification for variable type instances.

@@ -35,6 +45,12 @@ class Type(MetaObject, Generic[D]):
     The `Type` that will be created by a call to `Type.make_constant`.
     """

+    __props__: tuple[str, ...] = ()
+
+    @classmethod
+    def create(cls, **kwargs):
+        MetaType(f"{cls.__name__}[{kwargs}]", (cls,), kwargs)
+
     def in_same_class(self, otype: "Type") -> Optional[bool]:
         """Determine if another `Type` represents a subset from the same "class" of types represented by `self`.

@@ -214,7 +230,7 @@ def make_constant(self, value: D, name: Optional[Text] = None) -> constant_type:

     def clone(self, *args, **kwargs) -> "Type":
         """Clone a copy of this type with the given arguments/keyword values, if any."""
-        return type(self)(*args, **kwargs)
+        return type(self).subtype(*args, **kwargs)

     def __call__(self, name: Optional[Text] = None) -> variable_type:
         """Return a new `Variable` instance of Type `self`.

@@ -261,6 +277,41 @@ def values_eq_approx(cls, a: D, b: D) -> bool:
         """
         return cls.values_eq(a, b)

+    def _props(self):
+        """
+        Tuple of properties of all attributes
+        """
+        return tuple(getattr(self, a) for a in self.__props__)
+
+    def _props_dict(self):
+        """This return a dict of all ``__props__`` key-> value.
+
+        This is useful in optimization to swap op that should have the
+        same props. This help detect error that the new op have at
+        least all the original props.
+
+        """
+        return {a: getattr(self, a) for a in self.__props__}
+
+    def __hash__(self):
+        return hash((type(self), tuple(getattr(self, a) for a in self.__props__)))
+
+    def __eq__(self, other):
+        return type(self) == type(other) and tuple(
+            getattr(self, a) for a in self.__props__
+        ) == tuple(getattr(other, a) for a in self.__props__)
+
+    def __str__(self):
+        if self.__props__ is None or len(self.__props__) == 0:
+            return f"{self.__class__.__name__}()"
+        else:
+            return "{}{{{}}}".format(
+                self.__class__.__name__,
+                ", ".join(
+                    "{}={!r}".format(p, getattr(self, p)) for p in self.__props__
+                ),
+            )
+

 DataType = str
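
The block added above moves __props__-based identity onto Type itself: __eq__ and __hash__ compare the concrete class plus the tuple of declared property values, and __str__ renders those properties. A rough illustration with a hypothetical subclass (the attribute names are made up for the example):

    class PairType(Type):  # hypothetical subclass, for illustration only
        __props__ = ("dtype", "ndim")

        def __init__(self, dtype, ndim):
            self.dtype = dtype
            self.ndim = ndim


    a = PairType.subtype("float64", 2)
    b = PairType.subtype("float64", 2)
    assert a == b and hash(a) == hash(b)  # identity derives from __props__
    print(a)  # PairType{dtype='float64', ndim=2}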

aesara/link/c/params_type.py

Lines changed: 1 addition & 1 deletion
@@ -626,7 +626,7 @@ def extended(self, **kwargs):
         """
         self_to_dict = {self.fields[i]: self.types[i] for i in range(self.length)}
         self_to_dict.update(kwargs)
-        return ParamsType(**self_to_dict)
+        return ParamsType.subtype(**self_to_dict)

     # Returns a Params object with expected attributes or (in strict mode) checks that data has expected attributes.
     def filter(self, data, strict=False, allow_downcast=None):
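
For context, extended() rebuilds a ParamsType from the current fields plus any keyword overrides; the only change here is that the rebuild now goes through subtype like every other construction site. An illustrative (untested, hypothetical) combination, reusing names that appear elsewhere in this commit:

    base = ParamsType.subtype(exc_type=exception_type)  # as in aesara/raise_op.py
    wider = base.extended(extra=generic)  # hypothetical extra field; generic is defined in aesara/link/c/type.py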

aesara/link/c/type.py

Lines changed: 8 additions & 2 deletions
@@ -115,7 +115,7 @@ def __str__(self):
         return self.__class__.__name__


-generic = Generic()
+generic = Generic.subtype()

 _cdata_type = None

@@ -497,7 +497,10 @@ def __repr__(self):
     def __getattr__(self, key):
         if key in self:
             return self[key]
-        return CType.__getattr__(self, key)
+        else:
+            raise AttributeError(
+                f"{self.__class__.__name__} object has no attribute or enum value {key}"
+            )

     def __setattr__(self, key, value):
         if key in self:
@@ -530,6 +533,9 @@ def __eq__(self, other):
             and all(self.aliases[a] == other.aliases[a] for a in self.aliases)
         )

+    def __ne__(self, other):
+        return not self == other
+
     # EnumType should be used to create constants available in both Python and C code.
     # However, for convenience, we make sure EnumType can have a value, like other common types,
     # such that it could be used as-is as an op param.
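
The __getattr__ change above replaces the old fallback to CType.__getattr__ with an explicit AttributeError, so a missing enum name now fails with a clear message (and hasattr behaves as expected). A minimal analogue of the lookup logic, using a hypothetical dict-backed stand-in for EnumType:

    class EnumLike(dict):  # hypothetical stand-in for EnumType
        def __getattr__(self, key):
            # Attribute access falls back to enum-value lookup.
            if key in self:
                return self[key]
            else:
                raise AttributeError(
                    f"{self.__class__.__name__} object has no attribute or enum value {key}"
                )


    e = EnumLike(RED=0, GREEN=1)
    e.RED   # -> 0
    e.BLUE  # -> AttributeError: EnumLike object has no attribute or enum value BLUE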

aesara/raise_op.py

Lines changed: 3 additions & 3 deletions
@@ -22,7 +22,7 @@ def __hash__(self):
         return hash(type(self))


-exception_type = ExceptionType()
+exception_type = ExceptionType.subtype()


 class CheckAndRaise(COp):
@@ -38,7 +38,7 @@ class CheckAndRaise(COp):
     view_map = {0: [0]}

     check_input = False
-    params_type = ParamsType(exc_type=exception_type)
+    params_type = ParamsType.subtype(exc_type=exception_type)

     def __init__(self, exc_type, msg=""):

@@ -100,7 +100,7 @@ def perform(self, node, inputs, outputs, params):
         raise self.exc_type(self.msg)

     def grad(self, input, output_gradients):
-        return output_gradients + [DisconnectedType()()] * (len(input) - 1)
+        return output_gradients + [DisconnectedType.subtype()()] * (len(input) - 1)

     def connection_pattern(self, node):
         return [[1]] + [[0]] * (len(node.inputs) - 1)

aesara/sandbox/multinomial.py

Lines changed: 2 additions & 2 deletions
@@ -71,7 +71,7 @@ def c_code(self, node, name, ins, outs, sub):
         if self.odtype == "auto":
             t = f"PyArray_TYPE({pvals})"
         else:
-            t = ScalarType(self.odtype).dtype_specs()[1]
+            t = ScalarType.subtype(self.odtype).dtype_specs()[1]
             if t.startswith("aesara_complex"):
                 t = t.replace("aesara_complex", "NPY_COMPLEX")
             else:
@@ -263,7 +263,7 @@ def c_code(self, node, name, ins, outs, sub):
         if self.odtype == "auto":
             t = "NPY_INT64"
         else:
-            t = ScalarType(self.odtype).dtype_specs()[1]
+            t = ScalarType.subtype(self.odtype).dtype_specs()[1]
             if t.startswith("aesara_complex"):
                 t = t.replace("aesara_complex", "NPY_COMPLEX")
             else:

aesara/sandbox/rng_mrg.py

Lines changed: 2 additions & 2 deletions
@@ -325,7 +325,7 @@ def mrg_next_value(rstate, new_rstate, NORM, mask, offset):
 class mrg_uniform_base(Op):
     # TODO : need description for class, parameter
     __props__ = ("output_type", "inplace")
-    params_type = ParamsType(
+    params_type = ParamsType.subtype(
         inplace=bool_t,
         # following params will come from self.output_type.
         # NB: As output object may not be allocated in C code,
@@ -392,7 +392,7 @@ def new(cls, rstate, ndim, dtype, size):
         v_size = as_tensor_variable(size)
         if ndim is None:
             ndim = get_vector_length(v_size)
-        op = cls(TensorType(dtype, (False,) * ndim))
+        op = cls(TensorType.subtype(dtype, (False,) * ndim))
         return op(rstate, v_size)

     def perform(self, node, inp, out, params):

aesara/scalar/basic.py

Lines changed: 4 additions & 4 deletions
@@ -298,7 +298,7 @@ def __init__(self, dtype):
     def clone(self, dtype=None, **kwargs):
         if dtype is None:
             dtype = self.dtype
-        return type(self)(dtype)
+        return type(self).subtype(dtype)

     @staticmethod
     def may_share_memory(a, b):
@@ -679,7 +679,7 @@ def get_scalar_type(dtype, cache: Dict[str, ScalarType] = {}) -> ScalarType:

     """
     if dtype not in cache:
-        cache[dtype] = ScalarType(dtype=dtype)
+        cache[dtype] = ScalarType.subtype(dtype=dtype)
     return cache[dtype]


@@ -2405,13 +2405,13 @@ def grad(self, inputs, gout):
         (gz,) = gout
         if y.type in continuous_types:
             # x is disconnected because the elements of x are not used
-            return DisconnectedType()(), gz
+            return DisconnectedType.subtype()(), gz
         else:
             # when y is discrete, we assume the function can be extended
             # to deal with real-valued inputs by rounding them to the
             # nearest integer. f(x+eps) thus equals f(x) so the gradient
             # is zero, not disconnected or undefined
-            return DisconnectedType()(), y.zeros_like()
+            return DisconnectedType.subtype()(), y.zeros_like()


 second = Second(transfer_type(1), name="second")
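
One detail worth noting in the get_scalar_type hunk: the mutable default argument acts as a process-wide cache, so repeated requests for the same dtype return the identical (now subtype-constructed) ScalarType instance:

    s1 = get_scalar_type("float32")
    s2 = get_scalar_type("float32")
    assert s1 is s2  # same cached ScalarType.subtype(...) instance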
