1212
1313class MlModel (nn .Module ):
1414
15- def __init__ (self ,
16- line_shape : tuple ,
17- variable_shape : tuple ,
18- value_shape : tuple ,
19- feature_shape : tuple ,
20- hp = None ):
15+ def __init__ (self , line_shape : tuple , variable_shape : tuple , value_shape : tuple , feature_shape : tuple , hp = None ):
2116 super (MlModel , self ).__init__ ()
2217 if hp is None :
2318 hp = {}
@@ -56,8 +51,8 @@ def __init__(self,
5651 self .b_dropout = nn .Dropout (dense_b_dropout_rate )
5752
5853 @staticmethod
59- def __get_hyperparam (param_name : str , hp = None ) -> Any :
60- if param := hp .get (param_name ):
54+ def __get_hyperparam (param_name : str , hyperparameters = None ) -> Any :
55+ if param := hyperparameters .get (param_name ):
6156 if isinstance (param , float ):
6257 print (f"'{ param_name } ' is { param } " )
6358 return param
@@ -66,10 +61,7 @@ def __get_hyperparam(param_name: str, hp=None) -> Any:
6661 else :
6762 raise ValueError (f"'{ param_name } ' was not defined during initialization of the model." )
6863
69- def forward (self ,
70- line_input : torch .Tensor ,
71- variable_input : torch .Tensor ,
72- value_input : torch .Tensor ,
64+ def forward (self , line_input : torch .Tensor , variable_input : torch .Tensor , value_input : torch .Tensor ,
7365 feature_input : torch .Tensor ):
7466 line_out , _ = self .line_lstm (line_input )
7567 line_out = self .line_dropout (line_out [:, - 1 , :])
0 commit comments