
Commit fe0a626

Add lr scheduler interval config

1 parent 6cc6da1
File tree

2 files changed: 19 additions & 8 deletions


src/pytorch_tabular/config/config.py

Lines changed: 14 additions & 6 deletions
@@ -192,9 +192,9 @@ class DataConfig:
     )
 
     def __post_init__(self):
-        assert (
-            len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0
-        ), "There should be at-least one feature defined in categorical, continuous, or date columns"
+        assert len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0, (
+            "There should be at-least one feature defined in categorical, continuous, or date columns"
+        )
         _validate_choices(self)
         if os.name == "nt" and self.num_workers != 0:
             print("Windows does not support num_workers > 0. Setting num_workers to 0")
@@ -255,9 +255,9 @@ class InferredConfig:
 
     def __post_init__(self):
         if self.embedding_dims is not None:
-            assert all(
-                (isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims
-            ), "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
+            assert all((isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims), (
+                "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
+            )
             self.embedded_cat_dim = sum([t[1] for t in self.embedding_dims])
         else:
             self.embedded_cat_dim = 0
@@ -677,6 +677,9 @@ class OptimizerConfig:
         lr_scheduler_monitor_metric (Optional[str]): Used with ReduceLROnPlateau, where the plateau is
             decided based on this metric
 
+        lr_scheduler_interval (Optional[str]): Interval at which to step the LR Scheduler, one of "epoch"
+            or "step". Defaults to `epoch`.
+
     """
 
     optimizer: str = field(
@@ -709,6 +712,11 @@ class OptimizerConfig:
         metadata={"help": "Used with ReduceLROnPlateau, where the plateau is decided based on this metric"},
     )
 
+    lr_scheduler_interval: Optional[str] = field(
+        default="epoch",
+        metadata={"help": "Interval at which to step the LR Scheduler, one of `epoch` or `step`. Defaults to `epoch`."},
+    )
+
     @staticmethod
     def read_from_yaml(filename: str = "config/optimizer_config.yml"):
         config = _read_yaml(filename)
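
For anyone wanting to try the new option, a minimal usage sketch follows. It assumes the public `OptimizerConfig` import path and the existing `lr_scheduler` / `lr_scheduler_params` fields; the `OneCycleLR` hyperparameters are illustrative only and not part of this commit.

# Hedged sketch: setting the new lr_scheduler_interval field on
# pytorch_tabular's OptimizerConfig as patched above. The OneCycleLR
# settings are illustrative, not taken from this commit.
from pytorch_tabular.config import OptimizerConfig

optimizer_config = OptimizerConfig(
    optimizer="Adam",
    lr_scheduler="OneCycleLR",
    lr_scheduler_params={"max_lr": 1e-2, "total_steps": 1000},
    # OneCycleLR adjusts the LR every optimizer step, so step the
    # scheduler per "step" rather than the default "epoch".
    lr_scheduler_interval="step",
)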

src/pytorch_tabular/models/base_model.py

Lines changed: 5 additions & 2 deletions
@@ -588,8 +588,11 @@ def configure_optimizers(self):
             }
             return {
                 "optimizer": opt,
-                "lr_scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
-                "monitor": self.hparams.lr_scheduler_monitor_metric,
+                "lr_scheduler": {
+                    "scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
+                    "monitor": self.hparams.lr_scheduler_monitor_metric,
+                    "interval": self.hparams.lr_scheduler_interval,
+                },
             }
         else:
             return opt
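
The nested dictionary is PyTorch Lightning's scheduler-configuration format, in which the `interval` key decides whether Lightning calls `scheduler.step()` after every optimizer step or once per epoch. A self-contained sketch of that contract (the `LitModel` below is a hypothetical toy model, independent of pytorch_tabular):

# Hedged sketch of the Lightning lr_scheduler dict contract that the
# patched configure_optimizers now follows. LitModel is a toy model,
# not part of this commit.
import torch
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=1e-3)
        sched = torch.optim.lr_scheduler.StepLR(opt, step_size=10)
        return {
            "optimizer": opt,
            "lr_scheduler": {
                "scheduler": sched,
                # "step": scheduler.step() after every optimizer step;
                # "epoch": once per training epoch (Lightning's default).
                "interval": "step",
            },
        }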
