@@ -29,15 +29,15 @@ def _validate_inputs(cv_results, model_name):
2929 )
3030
3131
32- def check_multioutput (y , output_index ):
32+ def _check_multioutput (y , output_index ):
3333 """Checks if y is multi-output and if the output_index is valid."""
3434 if y .ndim > 1 :
3535 if (output_index > y .shape [1 ] - 1 ) | (output_index < 0 ):
3636 raise ValueError (
3737 f"Output index { output_index } is out of range. The index should be between 0 and { y .shape [1 ] - 1 } ."
3838 )
3939 print (
40- f"""Multiple outputs detected. Plotting the output variable with index { output_index } .
40+ f"""Plotting the output variable with index { output_index } .
4141To plot other outputs, set `output_index` argument to the desired index."""
4242 )
4343
@@ -148,6 +148,8 @@ def _plot_single_fold(
148148 y_test_std ,
149149 ax ,
150150 title = f"{ model_name } - { title_suffix } " ,
151+ input_index = input_index ,
152+ output_index = output_index ,
151153 )
152154 else :
153155 display = PredictionErrorDisplay .from_predictions (
@@ -334,7 +336,7 @@ def _plot_cv(
334336 """
335337
336338 _validate_inputs (cv_results , model_name )
337- check_multioutput (y , output_index )
339+ _check_multioutput (y , output_index )
338340
339341 if model_name :
340342 figure = _plot_model_folds (
@@ -449,7 +451,9 @@ def _plot_model(
449451 y_pred [:, out_idx ],
450452 y_std [:, out_idx ] if y_std is not None else None ,
451453 ax = axs [plot_index ],
452- title = f"X{ in_idx } vs. y{ out_idx } " ,
454+ title = f"$X_{ in_idx } $ vs. $y_{ out_idx } $" ,
455+ input_index = in_idx ,
456+ output_index = out_idx ,
453457 )
454458 plot_index += 1
455459 else :
@@ -479,7 +483,9 @@ def _plot_model(
479483 return fig
480484
481485
482- def _plot_Xy (X , y , y_pred , y_std = None , ax = None , title = "Xy" ):
486+ def _plot_Xy (
487+ X , y , y_pred , y_std = None , ax = None , title = "Xy" , input_index = 0 , output_index = 0
488+ ):
483489 """
484490 Plots observed and predicted values vs. features, including 2σ error bands where available.
485491 """
@@ -533,9 +539,9 @@ def _plot_Xy(X, y, y_pred, y_std=None, ax=None, title="Xy"):
533539 label = "pred." ,
534540 )
535541
536- ax .set_xlabel ("X" )
537- ax .set_ylabel ("y" )
538- ax .set_title (title )
542+ ax .set_xlabel (f"$X_ { input_index } $" , fontsize = 13 )
543+ ax .set_ylabel (f"$y_ { output_index } $" , fontsize = 13 )
544+ ax .set_title (title , fontsize = 13 )
539545 ax .grid (True , alpha = 0.3 )
540546
541547 # Get the handles and labels for the scatter plots
0 commit comments