55
66from aesara .graph import utils
77from aesara .graph .basic import Constant , Variable
8- from aesara .graph .utils import MetaObject
8+ from aesara .graph .utils import MetaType
99
1010
1111D = TypeVar ("D" )
1212
1313
14- class Type (MetaObject , Generic [D ]):
14+ class NewTypeMeta (type ):
15+ # pass
16+ def __call__ (cls , * args , ** kwargs ):
17+ raise RuntimeError ("Use subtype" )
18+ # return super().__call__(*args, **kwargs)
19+
20+ def subtype (cls , * args , ** kwargs ):
21+ return super ().__call__ (* args , ** kwargs )
22+
23+
24+ class Type (Generic [D ], metaclass = NewTypeMeta ):
1525 """
1626 Interface specification for variable type instances.
1727
@@ -35,6 +45,12 @@ class Type(MetaObject, Generic[D]):
3545 The `Type` that will be created by a call to `Type.make_constant`.
3646 """
3747
48+ __props__ : tuple [str , ...] = ()
49+
50+ @classmethod
51+ def create (cls , ** kwargs ):
52+ MetaType (f"{ cls .__name__ } [{ kwargs } ]" , (cls ,), kwargs )
53+
3854 def in_same_class (self , otype : "Type" ) -> Optional [bool ]:
3955 """Determine if another `Type` represents a subset from the same "class" of types represented by `self`.
4056
@@ -214,7 +230,7 @@ def make_constant(self, value: D, name: Optional[Text] = None) -> constant_type:
214230
215231 def clone (self , * args , ** kwargs ) -> "Type" :
216232 """Clone a copy of this type with the given arguments/keyword values, if any."""
217- return type (self )(* args , ** kwargs )
233+ return type (self ). subtype (* args , ** kwargs )
218234
219235 def __call__ (self , name : Optional [Text ] = None ) -> variable_type :
220236 """Return a new `Variable` instance of Type `self`.
@@ -261,6 +277,41 @@ def values_eq_approx(cls, a: D, b: D) -> bool:
261277 """
262278 return cls .values_eq (a , b )
263279
280+ def _props (self ):
281+ """
282+ Tuple of properties of all attributes
283+ """
284+ return tuple (getattr (self , a ) for a in self .__props__ )
285+
286+ def _props_dict (self ):
287+ """This return a dict of all ``__props__`` key-> value.
288+
289+ This is useful in optimization to swap op that should have the
290+ same props. This help detect error that the new op have at
291+ least all the original props.
292+
293+ """
294+ return {a : getattr (self , a ) for a in self .__props__ }
295+
296+ def __hash__ (self ):
297+ return hash ((type (self ), tuple (getattr (self , a ) for a in self .__props__ )))
298+
299+ def __eq__ (self , other ):
300+ return type (self ) == type (other ) and tuple (
301+ getattr (self , a ) for a in self .__props__
302+ ) == tuple (getattr (other , a ) for a in self .__props__ )
303+
304+ def __str__ (self ):
305+ if self .__props__ is None or len (self .__props__ ) == 0 :
306+ return f"{ self .__class__ .__name__ } ()"
307+ else :
308+ return "{}{{{}}}" .format (
309+ self .__class__ .__name__ ,
310+ ", " .join (
311+ "{}={!r}" .format (p , getattr (self , p )) for p in self .__props__
312+ ),
313+ )
314+
264315
265316DataType = str
266317
0 commit comments