1414from itertools import chain
1515from itertools import product as itertools_product
1616from logging import Logger
17- from typing import Optional
17+ from typing import TYPE_CHECKING , Optional , Union
1818from warnings import warn
1919
2020import numpy as np
21+ from typing_extensions import Literal
2122
2223import aesara
2324from aesara .compile .function .types import (
4243from aesara .utils import NoDuplicateOptWarningFilter , difference , get_unbound_function
4344
4445
45- __docformat__ = "restructuredtext en"
46+ if TYPE_CHECKING :
47+ from aesara .graph .basic import Apply
48+
4649_logger : Logger = logging .getLogger ("aesara.compile.debugmode" )
4750_logger .addFilter (NoDuplicateOptWarningFilter ())
4851
@@ -1108,43 +1111,32 @@ class _FunctionGraphEvent:
11081111
11091112 """
11101113
1111- kind = ""
1112- """
1113- One of 'import', 'change', 'prune'.
1114-
1115- """
1116-
1117- node = None
1118- """
1119- Either 'output' or an Apply instance.
1120-
1121- """
1122-
1123- op = None
1124- """Either 'output' or an Op instance"""
1114+ kind : Literal ["import" , "change" , "prune" ]
1115+ old_node : Optional [Union [Literal ["output" ], "Apply" ]]
1116+ new_node : Optional [Union [Literal ["output" ], "Apply" ]]
1117+ op : Optional [Union [Literal ["output" ], Op ]]
1118+ idx : Optional [int ]
1119+ reason : Optional [str ]
11251120
1126- idx = None
1127- """
1128- Change events involve an position index of the input variable.
1129-
1130- """
1131-
1132- reason = None
1133- """
1134- Change events sometimes have a reason.
1135-
1136- """
1137-
1138- def __init__ (self , kind , node , idx = None , reason = None ):
1121+ def __init__ (
1122+ self ,
1123+ kind : Literal ["import" , "change" , "prune" ],
1124+ old_node : Union [Literal ["output" ], "Apply" ],
1125+ new_node : Union [Literal ["output" ], "Apply" ] = None ,
1126+ idx : Optional [int ] = None ,
1127+ reason : Optional [str ] = None ,
1128+ ):
11391129 self .kind = kind
1140- if node == "output" :
1141- self .node = "output"
1130+ if old_node == "output" :
1131+ self .old_node = "output"
1132+ self .new_node = "output"
11421133 self .op = "output"
11431134 else :
1144- self .node = node
1145- self .op = node .op
1135+ self .old_node = old_node
1136+ self .new_node = new_node
1137+ self .op = old_node .op
11461138 self .idx = idx
1147- self .reason = str (reason )
1139+ self .reason = str (reason ) if reason else None
11481140
11491141 def __str__ (self ):
11501142 if self .kind == "change" :
@@ -1218,21 +1210,21 @@ def on_attach(self, fgraph):
12181210 self .replaced_by = {}
12191211 self .event_list = []
12201212 for node in fgraph .toposort ():
1221- self .on_import (fgraph , node , "on_attach" )
1213+ self .on_import (fgraph , node , reason = "on_attach" )
12221214
12231215 def on_detach (self , fgraph ):
12241216 assert fgraph is self .fgraph
12251217 self .fgraph = None
12261218
12271219 def on_prune (self , fgraph , node , reason ):
1228- self .event_list .append (_FunctionGraphEvent ("prune" , node , reason = str ( reason ) ))
1220+ self .event_list .append (_FunctionGraphEvent ("prune" , node , reason = reason ))
12291221 assert node in self .active_nodes
12301222 assert node not in self .inactive_nodes
12311223 self .active_nodes .remove (node )
12321224 self .inactive_nodes .add (node )
12331225
12341226 def on_import (self , fgraph , node , reason ):
1235- self .event_list .append (_FunctionGraphEvent ("import" , node , reason = str ( reason ) ))
1227+ self .event_list .append (_FunctionGraphEvent ("import" , node , reason = reason ))
12361228
12371229 assert node not in self .active_nodes
12381230 self .active_nodes .add (node )
@@ -1252,31 +1244,36 @@ def on_import(self, fgraph, node, reason):
12521244 self .reasons .setdefault (r , [])
12531245 self .replaced_by .setdefault (r , [])
12541246
1255- def on_change_input (self , fgraph , node , i , r , new_r , reason = None ):
1247+ def on_change_input (
1248+ self , fgraph , old_node , new_node , i , old_var , new_var , reason = None
1249+ ):
12561250 reason = str (reason )
12571251 self .event_list .append (
1258- _FunctionGraphEvent ("change" , node , reason = reason , idx = i )
1252+ _FunctionGraphEvent ("change" , old_node , new_node , idx = i , reason = reason )
12591253 )
12601254
1261- self .reasons .setdefault (new_r , [])
1262- self .replaced_by .setdefault (new_r , [])
1255+ self .on_import (fgraph , new_node , reason = reason )
1256+ self .on_prune (fgraph , old_node , reason = reason )
1257+
1258+ self .reasons .setdefault (new_var , [])
1259+ self .replaced_by .setdefault (new_var , [])
12631260
12641261 append_reason = True
1265- for tup in self .reasons [new_r ]:
1266- if tup [0 ] == reason and tup [1 ] is r :
1262+ for tup in self .reasons [new_var ]:
1263+ if tup [0 ] == reason and tup [1 ] is old_var :
12671264 append_reason = False
12681265
12691266 if append_reason :
12701267 # N.B. compute the debugprint now, because future
12711268 # optimizations will change the graph
12721269 done = dict ()
12731270 used_ids = dict ()
1274- self .reasons [new_r ].append (
1271+ self .reasons [new_var ].append (
12751272 (
12761273 reason ,
1277- r ,
1274+ old_var ,
12781275 _debugprint (
1279- r ,
1276+ old_var ,
12801277 prefix = " " ,
12811278 depth = 6 ,
12821279 file = StringIO (),
@@ -1285,7 +1282,7 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
12851282 used_ids = used_ids ,
12861283 ).getvalue (),
12871284 _debugprint (
1288- new_r ,
1285+ new_var ,
12891286 prefix = " " ,
12901287 depth = 6 ,
12911288 file = StringIO (),
@@ -1295,22 +1292,22 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
12951292 ).getvalue (),
12961293 )
12971294 )
1298- self .replaced_by [r ].append ((reason , new_r ))
1295+ self .replaced_by [old_var ].append ((reason , new_var ))
12991296
1300- if r in self .equiv :
1301- r_set = self .equiv [r ]
1297+ if old_var in self .equiv :
1298+ r_set = self .equiv [old_var ]
13021299 else :
1303- r_set = self .equiv .setdefault (r , {r })
1304- self .all_variables_ever .append (r )
1300+ r_set = self .equiv .setdefault (old_var , {old_var })
1301+ self .all_variables_ever .append (old_var )
13051302
1306- if new_r in self .equiv :
1307- new_r_set = self .equiv [new_r ]
1303+ if new_var in self .equiv :
1304+ new_r_set = self .equiv [new_var ]
13081305 else :
1309- new_r_set = self .equiv .setdefault (new_r , {new_r })
1310- self .all_variables_ever .append (new_r )
1306+ new_r_set = self .equiv .setdefault (new_var , {new_var })
1307+ self .all_variables_ever .append (new_var )
13111308
1312- assert new_r in new_r_set
1313- assert r in r_set
1309+ assert new_var in new_r_set
1310+ assert old_var in r_set
13141311
13151312 # update one equivalence set to contain the other
13161313 # transfer all the elements of the old one to the new one
@@ -1319,8 +1316,8 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
13191316 self .equiv [like_new_r ] = r_set
13201317 assert like_new_r in r_set
13211318
1322- assert self .equiv [r ] is r_set
1323- assert self .equiv [new_r ] is r_set
1319+ assert self .equiv [old_var ] is r_set
1320+ assert self .equiv [new_var ] is r_set
13241321
13251322 def printstuff (self ):
13261323 for key in self .equiv :
0 commit comments