Skip to content

Commit f743659

Browse files
committed
Drop duplicate id_column rows in uq_df before the merge to avoid a cross-product of matching rows
1 parent ccc1d56 commit f743659

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

src/workbench/core/artifacts/endpoint_core.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -475,17 +475,20 @@ def cross_fold_inference(self, nfolds: int = 5) -> Tuple[dict, pd.DataFrame]:
475475
training_df = fs.view("training").pull_dataframe()
476476

477477
# Run inference on the endpoint to get UQ outputs
478-
full_inference_df = self.inference(training_df)
478+
uq_df = self.inference(training_df)
479479

480480
# Identify UQ-specific columns (quantiles and prediction_std)
481-
uq_columns = [col for col in full_inference_df.columns if col.startswith("q_") or col == "prediction_std"]
481+
uq_columns = [col for col in uq_df.columns if col.startswith("q_") or col == "prediction_std"]
482482

483483
# Merge UQ columns with out-of-fold predictions
484484
if uq_columns:
485-
# Keep id_column and UQ columns, drop 'prediction' to avoid conflict
486-
merge_columns = [id_column] + uq_columns
487-
uq_df = full_inference_df[merge_columns]
485+
# Keep id_column and UQ columns, drop 'prediction' to avoid conflict when merging
486+
uq_df = uq_df[[id_column] + uq_columns]
488487

488+
# Drop duplicates in uq_df based on id_column
489+
uq_df = uq_df.drop_duplicates(subset=[id_column])
490+
491+
# Merge UQ columns into out_of_fold_df
489492
out_of_fold_df = pd.merge(out_of_fold_df, uq_df, on=id_column, how="left")
490493
additional_columns = uq_columns
491494
self.log.info(f"Added UQ columns: {', '.join(additional_columns)}")

0 commit comments

Comments
 (0)