
Commit d5307db

feat: Implement hyperparameter search with Optuna
Parent: c9b79e3

File tree: 3 files changed, +84 -21 lines


experiment/main.py

Lines changed: 42 additions & 2 deletions
@@ -16,6 +16,9 @@
 from sklearn.metrics import f1_score, precision_score, recall_score, log_loss, accuracy_score
 from sklearn.model_selection import train_test_split
 from sklearn.utils import compute_class_weight
+import optuna
+from optuna.samplers import TPESampler
+from optuna.pruners import HyperbandPruner

 from experiment.plot import save_plot
 from experiment.src.data_loader import read_detected_data, read_metadata, join_label, get_y_labels
@@ -25,6 +28,36 @@
 from experiment.src.prepare_data import prepare_train_data, data_checksum


+def objective(trial, x_train, y_train, x_test, y_test, hp, device):
+    # Sample each hyperparameter from its (low, high, step) search range.
+    params = {}
+    for param_name, ((low, high, step), default) in hp.items():
+        params[param_name] = trial.suggest_float(param_name, low, high, step=step)
+
+    model = MlModel(*[x.shape for x in x_train], params).to(device)
+    optimizer = optim.Adam(model.parameters(), lr=0.001)
+    criterion = nn.BCELoss()
+
+    dataset = TensorDataset(*[torch.tensor(x, dtype=torch.float32) for x in x_train],
+                            torch.tensor(y_train, dtype=torch.float32))
+    data_loader = DataLoader(dataset, batch_size=1024, shuffle=True)
+
+    model.train()
+    for _ in range(5):
+        for batch in data_loader:
+            x_tensors = [x.to(device) for x in batch[:-1]]
+            y_batch = batch[-1].to(device)
+            optimizer.zero_grad()
+            outputs = model(*x_tensors).squeeze()
+            loss = criterion(outputs, y_batch)
+            loss.backward()
+            optimizer.step()
+
+    # Evaluate on the held-out split. BCELoss expects tensors, with
+    # predictions first and targets second, so keep everything as tensors.
+    model.eval()
+    with torch.no_grad():
+        predictions = model(*[torch.tensor(x, dtype=torch.float32, device=device)
+                              for x in x_test]).squeeze()
+        val_loss = criterion(predictions, torch.tensor(y_test, dtype=torch.float32, device=device))
+    return val_loss.item()
+
+
 def evaluate_model(thresholds: dict,
                    model: nn.Module,
                    x_data: List[np.ndarray],
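Worth noting: HyperbandPruner only ever prunes a trial when the objective reports intermediate values and checks trial.should_prune(), which the objective above never does, so the pruner is effectively inert in this commit. A self-contained toy sketch of the report/prune protocol (the objective and its values are invented for illustration):

    import optuna

    def pruned_objective(trial: optuna.Trial) -> float:
        x = trial.suggest_float("x", -10.0, 10.0)
        loss = float("inf")
        for epoch in range(5):
            loss = (x - 2.0) ** 2 / (epoch + 1)  # stand-in for one real training epoch
            trial.report(loss, step=epoch)       # expose the intermediate loss to the pruner
            if trial.should_prune():             # Hyperband decides to cut the trial short
                raise optuna.TrialPruned()
        return loss

    study = optuna.create_study(sampler=optuna.samplers.TPESampler(),
                                pruner=optuna.pruners.HyperbandPruner(),
                                direction="minimize")
    study.optimize(pruned_objective, n_trials=20)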
@@ -169,9 +202,16 @@ def main(cred_data_location: str,
     x_train = [x_train_line, x_train_variable, x_train_value, x_train_features]
     x_test = [x_test_line, x_test_variable, x_test_value, x_test_features]

-    param_kwargs = {k: v[1] for k, v in hp_dict.items()}
+    if use_tuner:
+        print("Starting model training with hyperparameter optimization")
+        study = optuna.create_study(sampler=TPESampler(), pruner=HyperbandPruner(), direction="minimize")
+        study.optimize(lambda trial: objective(trial, x_train, y_train, x_test, y_test, hp_dict, device),
+                       n_trials=20)
+        param_kwargs = study.best_params
+        print(f"Best hyperparameters: {param_kwargs}")
+    else:
+        param_kwargs = {k: v[1] for k, v in hp_dict.items()}

-    print(f"Model is trained with params from dict:{param_kwargs}")
+    print(f"Model will be trained using the following params: {param_kwargs}")

     # repeat train step to obtain actual history chart
     ml_model = MlModel(x_full_line.shape, x_full_variable.shape, x_full_value.shape, x_full_features.shape,
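For context, objective() and the fallback branch above both unpack hp_dict entries as ((low, high, step), default). The actual dictionary is defined elsewhere in the repo; a hypothetical illustration of the expected layout:

    # Hypothetical hp_dict layout: {name: ((low, high, step), default)}.
    # Ranges here are invented for illustration; the 0.45 defaults match the
    # values previously hard-coded in lstm_model.py.
    hp_dict = {
        "value_lstm_dropout_rate": ((0.1, 0.6, 0.05), 0.45),
        "line_lstm_dropout_rate": ((0.1, 0.6, 0.05), 0.45),
        "variable_lstm_dropout_rate": ((0.1, 0.6, 0.05), 0.45),
        "dense_a_lstm_dropout_rate": ((0.1, 0.6, 0.05), 0.45),
        "dense_b_lstm_dropout_rate": ((0.1, 0.6, 0.05), 0.45),
    }

    # With use_tuner=False the fallback reduces to the defaults:
    param_kwargs = {k: v[1] for k, v in hp_dict.items()}  # every rate -> 0.45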

experiment/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ tensorrt==10.8.0.43
 tf2onnx==1.16.1
 wrapt==1.14.1
 torch==2.6.0
+optuna==4.2.1

 # version insensitive
 types-tensorflow

experiment/src/lstm_model.py

Lines changed: 41 additions & 19 deletions
@@ -9,31 +9,38 @@

 dtype = torch.float32

+
 class MlModel(nn.Module):

-    def __init__(
-            self,
-            line_shape: tuple,
-            variable_shape: tuple,
-            value_shape: tuple,
-            feature_shape: tuple,
-            hp=None,
-    ):
+    def __init__(self,
+                 line_shape: tuple,
+                 variable_shape: tuple,
+                 value_shape: tuple,
+                 feature_shape: tuple,
+                 hp=None):
         super(MlModel, self).__init__()
         if hp is None:
             hp = {}
-        value_lstm_dropout_rate = hp.get("value_lstm_dropout_rate", 0.45)
-        line_lstm_dropout_rate = hp.get("line_lstm_dropout_rate", 0.45)
-        variable_lstm_dropout_rate = hp.get("variable_lstm_dropout_rate", 0.45)
-        dense_a_dropout_rate = hp.get("dense_a_lstm_dropout_rate", 0.45)
-        dense_b_dropout_rate = hp.get("dense_b_lstm_dropout_rate", 0.45)
-        #print(f"Input hyperparameters: {hp}")
-        print(f"Run model with parameters: value_lstm_dropout_rate={value_lstm_dropout_rate}, line_lstm_dropout_rate={line_lstm_dropout_rate}, variable_lstm_dropout_rate={variable_lstm_dropout_rate}, dense_a_dropout_rate={dense_a_dropout_rate}, dense_b_dropout_rate={dense_b_dropout_rate}")
+        value_lstm_dropout_rate = self.__get_hyperparam("value_lstm_dropout_rate", hp)
+        line_lstm_dropout_rate = self.__get_hyperparam("line_lstm_dropout_rate", hp)
+        variable_lstm_dropout_rate = self.__get_hyperparam("variable_lstm_dropout_rate", hp)
+        dense_a_dropout_rate = self.__get_hyperparam("dense_a_lstm_dropout_rate", hp)
+        dense_b_dropout_rate = self.__get_hyperparam("dense_b_lstm_dropout_rate", hp)
+
         self.d_type = torch.float32

-        self.line_lstm = nn.LSTM(input_size=line_shape[2], hidden_size=line_shape[1], batch_first=True, bidirectional=True)
-        self.variable_lstm = nn.LSTM(input_size=variable_shape[2], hidden_size=variable_shape[1], batch_first=True, bidirectional=True)
-        self.value_lstm = nn.LSTM(input_size=value_shape[2], hidden_size=value_shape[1], batch_first=True, bidirectional=True)
+        self.line_lstm = nn.LSTM(input_size=line_shape[2],
+                                 hidden_size=line_shape[1],
+                                 batch_first=True,
+                                 bidirectional=True)
+        self.variable_lstm = nn.LSTM(input_size=variable_shape[2],
+                                     hidden_size=variable_shape[1],
+                                     batch_first=True,
+                                     bidirectional=True)
+        self.value_lstm = nn.LSTM(input_size=value_shape[2],
+                                  hidden_size=value_shape[1],
+                                  batch_first=True,
+                                  bidirectional=True)

         self.line_dropout = nn.Dropout(line_lstm_dropout_rate)
         self.variable_dropout = nn.Dropout(variable_lstm_dropout_rate)
@@ -48,7 +55,22 @@ def __init__(
         self.a_dropout = nn.Dropout(dense_a_dropout_rate)
         self.b_dropout = nn.Dropout(dense_b_dropout_rate)

-    def forward(self, line_input, variable_input, value_input, feature_input):
+    @staticmethod
+    def __get_hyperparam(param_name: str, hp=None) -> Any:
+        # Compare against None rather than truthiness, so a legitimate
+        # 0.0 rate is not mistaken for a missing hyperparameter.
+        if (param := hp.get(param_name)) is not None:
+            if isinstance(param, float):
+                print(f"'{param_name}' is {param}")
+                return param
+            else:
+                raise ValueError(f"Unexpected '{param_name}': {param}")
+        else:
+            raise ValueError(f"'{param_name}' was not defined during initialization of the model.")
+
+    def forward(self,
+                line_input: torch.Tensor,
+                variable_input: torch.Tensor,
+                value_input: torch.Tensor,
+                feature_input: torch.Tensor):
         line_out, _ = self.line_lstm(line_input)
         line_out = self.line_dropout(line_out[:, -1, :])
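One behavioral consequence of this refactor: __get_hyperparam raises instead of falling back to the old 0.45 defaults, so every dropout rate must now be present in hp. (The Any return annotation also assumes from typing import Any at the top of the file, which this hunk does not show.) A usage sketch with invented shapes:

    # Illustrative only: shapes are (batch, seq_len, embedding_dim) tuples;
    # nn.LSTM above reads shape[2] as input_size and shape[1] as hidden_size.
    hp = {
        "value_lstm_dropout_rate": 0.45,
        "line_lstm_dropout_rate": 0.45,
        "variable_lstm_dropout_rate": 0.45,
        "dense_a_lstm_dropout_rate": 0.45,
        "dense_b_lstm_dropout_rate": 0.45,
    }
    model = MlModel((None, 40, 96), (None, 16, 96), (None, 32, 96), (None, 10), hp)
    # Dropping any key now raises ValueError instead of silently using 0.45.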
