Skip to content

Commit d3726c4

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 74bae26 commit d3726c4

File tree

1 file changed

+25
-21
lines changed

1 file changed

+25
-21
lines changed

src/pytorch_tabular/tabular_model.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,9 @@ def __init__(
133133
trainer_config = self._read_parse_config(trainer_config, TrainerConfig)
134134
optimizer_config = self._read_parse_config(optimizer_config, OptimizerConfig)
135135
if model_config.task != "ssl":
136-
assert data_config.target is not None, (
137-
f"`target` in data_config should not be None for {model_config.task} task"
138-
)
136+
assert (
137+
data_config.target is not None
138+
), f"`target` in data_config should not be None for {model_config.task} task"
139139
if experiment_config is None:
140140
if self.verbose:
141141
logger.info("Experiment Tracking is turned off")
@@ -763,13 +763,13 @@ def fit(
763763
pl.Trainer: The PyTorch Lightning Trainer instance
764764
765765
"""
766-
assert self.config.task != "ssl", (
767-
"`fit` is not valid for SSL task. Please use `pretrain` for semi-supervised learning"
768-
)
766+
assert (
767+
self.config.task != "ssl"
768+
), "`fit` is not valid for SSL task. Please use `pretrain` for semi-supervised learning"
769769
if metrics is not None:
770-
assert len(metrics) == len(metrics_prob_inputs or []), (
771-
"The length of `metrics` and `metrics_prob_inputs` should be equal"
772-
)
770+
assert len(metrics) == len(
771+
metrics_prob_inputs or []
772+
), "The length of `metrics` and `metrics_prob_inputs` should be equal"
773773
seed = seed or self.config.seed
774774
if seed:
775775
seed_everything(seed)
@@ -846,9 +846,9 @@ def pretrain(
846846
pl.Trainer: The PyTorch Lightning Trainer instance
847847
848848
"""
849-
assert self.config.task == "ssl", (
850-
f"`pretrain` is not valid for {self.config.task} task. Please use `fit` instead."
851-
)
849+
assert (
850+
self.config.task == "ssl"
851+
), f"`pretrain` is not valid for {self.config.task} task. Please use `fit` instead."
852852
seed = seed or self.config.seed
853853
if seed:
854854
seed_everything(seed)
@@ -968,9 +968,9 @@ def create_finetune_model(
968968
config = self.config
969969
optimizer_params = optimizer_params or {}
970970
if target is None:
971-
assert hasattr(config, "target") and config.target is not None, (
972-
"`target` cannot be None if it was not set in the initial `DataConfig`"
973-
)
971+
assert (
972+
hasattr(config, "target") and config.target is not None
973+
), "`target` cannot be None if it was not set in the initial `DataConfig`"
974974
else:
975975
assert isinstance(target, list), "`target` should be a list of strings"
976976
config.target = target
@@ -1097,9 +1097,9 @@ def finetune(
10971097
pl.Trainer: The trainer object
10981098
10991099
"""
1100-
assert self._is_finetune_model, (
1101-
"finetune() can only be called on a finetune model created using `TabularModel.create_finetune_model()`"
1102-
)
1100+
assert (
1101+
self._is_finetune_model
1102+
), "finetune() can only be called on a finetune model created using `TabularModel.create_finetune_model()`"
11031103
seed_everything(self.config.seed)
11041104
if freeze_backbone:
11051105
for param in self.model.backbone.parameters():
@@ -2361,9 +2361,13 @@ def bagging_predict(
23612361
"regression",
23622362
], "Bagging is only available for classification and regression"
23632363
if not callable(aggregate):
2364-
assert aggregate in ["mean", "median", "min", "max", "hard_voting"], (
2365-
"aggregate should be one of 'mean', 'median', 'min', 'max', or 'hard_voting'"
2366-
)
2364+
assert aggregate in [
2365+
"mean",
2366+
"median",
2367+
"min",
2368+
"max",
2369+
"hard_voting",
2370+
], "aggregate should be one of 'mean', 'median', 'min', 'max', or 'hard_voting'"
23672371
if self.config.task == "regression":
23682372
assert aggregate != "hard_voting", "hard_voting is only available for classification"
23692373
cv = self._check_cv(cv)

0 commit comments

Comments
 (0)