diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 5356c30c9..91e300a71 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -791,9 +791,19 @@ def calculate_empty_frame_accuracy(self, ground_df, predictions_df): def on_validation_epoch_end(self): """Compute metrics.""" + # Check if we are in a standalone validation run + if self.trainer is not None: + is_validate_phase = self.trainer.state.fn == "validate" + else: + is_validate_phase = False + + # Evaluate every n epochs or during standalone validation + evaluate_this_epoch = is_validate_phase or ( + self.config["validation"]["val_accuracy_interval"] + <= self.config["train"]["epochs"] and + self.current_epoch % self.config["validation"]["val_accuracy_interval"] == 0) - #Evaluate every n epochs - if self.current_epoch % self.config["validation"]["val_accuracy_interval"] == 0: + if evaluate_this_epoch: if len(self.predictions) == 0: return None diff --git a/tests/test_main.py b/tests/test_main.py index 333b0ffc5..cbd8774ea 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -995,3 +995,14 @@ def test_set_labels_invalid_length(m): # Expect a ValueError when setting an inv invalid_mapping = {"Object": 0, "Extra": 1} with pytest.raises(ValueError): m.set_labels(invalid_mapping) + +def test_validation_interval_greater_than_epochs(m): + # Set interval higher than max_epochs to disable evaluation + m.config["validation"]["val_accuracy_interval"] = 3 + m.config["train"]["epochs"] = 2 + m.create_trainer() + m.trainer.fit(m) + + assert "box_precision" not in m.trainer.logged_metrics + assert "box_recall" not in m.trainer.logged_metrics + assert "empty_frame_accuracy" not in m.trainer.logged_metrics \ No newline at end of file