Commit 8880657

refactor: Move ML model class from lstm_model.py to ml_model.py
1 parent fa1efe1 commit 8880657

2 files changed: +5 −13 lines

experiment/main.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@
 from experiment.plot import save_plot
 from experiment.src.data_loader import read_detected_data, read_metadata, join_label, get_y_labels
 from experiment.src.features import prepare_data
-from experiment.src.lstm_model import MlModel
+from experiment.src.ml_model import MlModel
 from experiment.src.model_config_preprocess import model_config_preprocess
 from experiment.src.prepare_data import prepare_train_data, data_checksum

experiment/src/lstm_model.py renamed to experiment/src/ml_model.py

Lines changed: 4 additions & 12 deletions
@@ -12,12 +12,7 @@

 class MlModel(nn.Module):

-    def __init__(self,
-                 line_shape: tuple,
-                 variable_shape: tuple,
-                 value_shape: tuple,
-                 feature_shape: tuple,
-                 hp=None):
+    def __init__(self, line_shape: tuple, variable_shape: tuple, value_shape: tuple, feature_shape: tuple, hp=None):
         super(MlModel, self).__init__()
         if hp is None:
             hp = {}
@@ -56,8 +51,8 @@ def __init__(self,
         self.b_dropout = nn.Dropout(dense_b_dropout_rate)

     @staticmethod
-    def __get_hyperparam(param_name: str, hp=None) -> Any:
-        if param := hp.get(param_name):
+    def __get_hyperparam(param_name: str, hyperparameters=None) -> Any:
+        if param := hyperparameters.get(param_name):
             if isinstance(param, float):
                 print(f"'{param_name}' is {param}")
             return param
@@ -66,10 +61,7 @@ def __get_hyperparam(param_name: str, hp=None) -> Any:
         else:
             raise ValueError(f"'{param_name}' was not defined during initialization of the model.")

-    def forward(self,
-                line_input: torch.Tensor,
-                variable_input: torch.Tensor,
-                value_input: torch.Tensor,
+    def forward(self, line_input: torch.Tensor, variable_input: torch.Tensor, value_input: torch.Tensor,
                 feature_input: torch.Tensor):
         line_out, _ = self.line_lstm(line_input)
         line_out = self.line_dropout(line_out[:, -1, :])
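Since the class keeps its name across the rename, callers only need the new import path shown in main.py above. The snippet below is a minimal, self-contained sketch of the hyperparameter-lookup pattern used inside MlModel.__get_hyperparam (walrus assignment plus an explicit error for a missing key); the standalone get_hyperparam function and the example keys are hypothetical stand-ins for illustration, not part of the repository.

from typing import Any


def get_hyperparam(param_name: str, hyperparameters=None) -> Any:
    # Illustrative stand-in for the private MlModel.__get_hyperparam:
    # return a required hyperparameter or fail loudly when it is missing.
    hyperparameters = hyperparameters or {}
    if param := hyperparameters.get(param_name):
        if isinstance(param, float):
            print(f"'{param_name}' is {param}")
        return param
    raise ValueError(f"'{param_name}' was not defined during initialization of the model.")


# Hypothetical usage; the real hyperparameter names live in the experiment code.
hp = {"dense_a_dropout_rate": 0.3, "dense_b_dropout_rate": 0.2}
rate = get_hyperparam("dense_b_dropout_rate", hp)
print(rate)  # prints 0.2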
