Skip to content

Commit 8a6c159

Browse files
committed
Merge branch 'main' into pypi
2 parents 148e942 + 9775f64 commit 8a6c159

File tree

2 files changed

+20
-14
lines changed

2 files changed

+20
-14
lines changed

autoemulate/plotting.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,15 @@ def _validate_inputs(cv_results, model_name):
2929
)
3030

3131

32-
def check_multioutput(y, output_index):
32+
def _check_multioutput(y, output_index):
3333
"""Checks if y is multi-output and if the output_index is valid."""
3434
if y.ndim > 1:
3535
if (output_index > y.shape[1] - 1) | (output_index < 0):
3636
raise ValueError(
3737
f"Output index {output_index} is out of range. The index should be between 0 and {y.shape[1] - 1}."
3838
)
3939
print(
40-
f"""Multiple outputs detected. Plotting the output variable with index {output_index}.
40+
f"""Plotting the output variable with index {output_index}.
4141
To plot other outputs, set `output_index` argument to the desired index."""
4242
)
4343

@@ -148,6 +148,8 @@ def _plot_single_fold(
148148
y_test_std,
149149
ax,
150150
title=f"{model_name} - {title_suffix}",
151+
input_index=input_index,
152+
output_index=output_index,
151153
)
152154
else:
153155
display = PredictionErrorDisplay.from_predictions(
@@ -334,7 +336,7 @@ def _plot_cv(
334336
"""
335337

336338
_validate_inputs(cv_results, model_name)
337-
check_multioutput(y, output_index)
339+
_check_multioutput(y, output_index)
338340

339341
if model_name:
340342
figure = _plot_model_folds(
@@ -449,7 +451,9 @@ def _plot_model(
449451
y_pred[:, out_idx],
450452
y_std[:, out_idx] if y_std is not None else None,
451453
ax=axs[plot_index],
452-
title=f"X{in_idx} vs. y{out_idx}",
454+
title=f"$X_{in_idx}$ vs. $y_{out_idx}$",
455+
input_index=in_idx,
456+
output_index=out_idx,
453457
)
454458
plot_index += 1
455459
else:
@@ -479,7 +483,9 @@ def _plot_model(
479483
return fig
480484

481485

482-
def _plot_Xy(X, y, y_pred, y_std=None, ax=None, title="Xy"):
486+
def _plot_Xy(
487+
X, y, y_pred, y_std=None, ax=None, title="Xy", input_index=0, output_index=0
488+
):
483489
"""
484490
Plots observed and predicted values vs. features, including 2σ error bands where available.
485491
"""
@@ -533,9 +539,9 @@ def _plot_Xy(X, y, y_pred, y_std=None, ax=None, title="Xy"):
533539
label="pred.",
534540
)
535541

536-
ax.set_xlabel("X")
537-
ax.set_ylabel("y")
538-
ax.set_title(title)
542+
ax.set_xlabel(f"$X_{input_index}$", fontsize=13)
543+
ax.set_ylabel(f"$y_{output_index}$", fontsize=13)
544+
ax.set_title(title, fontsize=13)
539545
ax.grid(True, alpha=0.3)
540546

541547
# Get the handles and labels for the scatter plots

tests/test_plotting.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77

88
from autoemulate.compare import AutoEmulate
99
from autoemulate.emulators import RadialBasisFunctions
10+
from autoemulate.plotting import _check_multioutput
1011
from autoemulate.plotting import _plot_cv
1112
from autoemulate.plotting import _plot_model
1213
from autoemulate.plotting import _plot_single_fold
1314
from autoemulate.plotting import _predict_with_optional_std
1415
from autoemulate.plotting import _validate_inputs
15-
from autoemulate.plotting import check_multioutput
1616

1717

1818
@pytest.fixture(scope="module")
@@ -72,7 +72,7 @@ def test_check_multioutput_with_single_output():
7272
y = np.array([1, 2, 3, 4, 5])
7373
output_index = 0
7474
try:
75-
check_multioutput(y, output_index)
75+
_check_multioutput(y, output_index)
7676
except ValueError as e:
7777
assert False, f"Unexpected ValueError: {str(e)}"
7878

@@ -81,7 +81,7 @@ def test_check_multioutput_with_multioutput():
8181
y = np.array([[1, 2, 3], [4, 5, 6]])
8282
output_index = 1
8383
try:
84-
check_multioutput(y, output_index)
84+
_check_multioutput(y, output_index)
8585
except ValueError as e:
8686
assert False, f"Unexpected ValueError: {str(e)}"
8787

@@ -90,7 +90,7 @@ def test_check_multioutput_with_invalid_output_index():
9090
y = np.array([[1, 2, 3], [4, 5, 6]])
9191
output_index = 3
9292
try:
93-
check_multioutput(y, output_index)
93+
_check_multioutput(y, output_index)
9494
assert False, "Expected ValueError to be raised"
9595
except ValueError as e:
9696
assert (
@@ -354,7 +354,7 @@ def test__plot_model_int(ae_single_output):
354354
output_index=0,
355355
)
356356
assert isinstance(fig, plt.Figure)
357-
assert fig.axes[0].get_title() == "X0 vs. y0"
357+
assert all(term in fig.axes[0].get_title() for term in ["X", "y", "vs."])
358358

359359

360360
def test__plot_model_list(ae_single_output):
@@ -367,7 +367,7 @@ def test__plot_model_list(ae_single_output):
367367
output_index=[0],
368368
)
369369
assert isinstance(fig, plt.Figure)
370-
assert fig.axes[1].get_title() == "X1 vs. y0"
370+
assert all(term in fig.axes[1].get_title() for term in ["X", "y", "vs."])
371371

372372

373373
def test__plot_model_int_out_of_range(ae_single_output):

0 commit comments

Comments
 (0)