Skip to content

Commit 0632be3

Browse files
authored
FIx: raw incremental linreg and revised test scope (#2816)
* FIx: raw incremental linreg and revised test scope * oops * oops 2 * lint
1 parent 595cfbb commit 0632be3

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

sklearnex/linear_model/incremental_linear.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,12 +162,12 @@ def _onedal_supported(self, method_name, *data):
162162
_onedal_gpu_supported = _onedal_supported
163163

164164
def _onedal_predict(self, X, queue=None):
165+
xp, _ = get_namespace(X)
166+
165167
if not get_config()["use_raw_input"]:
166168
if sklearn_check_version("1.2"):
167169
self._validate_params()
168170

169-
xp, _ = get_namespace(X)
170-
171171
X = validate_data(
172172
self,
173173
X,

sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,9 @@ def test_incremental_linear_regression_fit_spmd_gold(
118118
@pytest.mark.parametrize("num_blocks", [1, 2])
119119
@pytest.mark.parametrize("macro_block", [None, 1024])
120120
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
121-
@pytest.mark.parametrize("use_raw_input", [True, False])
122121
@pytest.mark.mpi
123122
def test_incremental_linear_regression_partial_fit_spmd_gold(
124-
dataframe, queue, fit_intercept, num_blocks, macro_block, dtype, use_raw_input
123+
dataframe, queue, fit_intercept, num_blocks, macro_block, dtype
125124
):
126125
# Import spmd and non-SPMD algo
127126
from sklearnex.linear_model import IncrementalLinearRegression
@@ -179,9 +178,7 @@ def test_incremental_linear_regression_partial_fit_spmd_gold(
179178
local_dpt_y = _convert_to_dataframe(
180179
split_local_y[i], sycl_queue=queue, target_df=dataframe
181180
)
182-
# Configure raw input status for spmd estimator
183-
with config_context(use_raw_input=use_raw_input):
184-
inclin_spmd.partial_fit(local_dpt_X, local_dpt_y)
181+
inclin_spmd.partial_fit(local_dpt_X, local_dpt_y)
185182

186183
inclin.fit(dpt_X, dpt_y)
187184

@@ -272,6 +269,7 @@ def test_incremental_linear_regression_fit_spmd_random(
272269
@pytest.mark.parametrize("num_features", [5, 10])
273270
@pytest.mark.parametrize("macro_block", [None, 1024])
274271
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
272+
@pytest.mark.parametrize("use_raw_input", [True, False])
275273
@pytest.mark.mpi
276274
def test_incremental_linear_regression_partial_fit_spmd_random(
277275
dataframe,
@@ -282,6 +280,7 @@ def test_incremental_linear_regression_partial_fit_spmd_random(
282280
num_features,
283281
macro_block,
284282
dtype,
283+
use_raw_input,
285284
):
286285
# Import spmd and non-SPMD algo
287286
from sklearnex.linear_model import IncrementalLinearRegression
@@ -328,7 +327,9 @@ def test_incremental_linear_regression_partial_fit_spmd_random(
328327
dpt_X = _convert_to_dataframe(X_split[i], sycl_queue=queue, target_df=dataframe)
329328
dpt_y = _convert_to_dataframe(y_split[i], sycl_queue=queue, target_df=dataframe)
330329

331-
inclin_spmd.partial_fit(local_dpt_X, local_dpt_y)
330+
# Configure raw input status for spmd estimator
331+
with config_context(use_raw_input=use_raw_input):
332+
inclin_spmd.partial_fit(local_dpt_X, local_dpt_y)
332333
inclin.partial_fit(dpt_X, dpt_y)
333334

334335
assert_allclose(_as_numpy(inclin.coef_), _as_numpy(inclin_spmd.coef_), atol=tol)
@@ -337,7 +338,8 @@ def test_incremental_linear_regression_partial_fit_spmd_random(
337338
_as_numpy(inclin.intercept_), _as_numpy(inclin_spmd.intercept_), atol=tol
338339
)
339340

340-
y_pred_spmd = inclin_spmd.predict(dpt_X_test)
341+
with config_context(use_raw_input=use_raw_input):
342+
y_pred_spmd = inclin_spmd.predict(dpt_X_test)
341343
y_pred = inclin.predict(dpt_X_test)
342344

343345
assert_allclose(_as_numpy(y_pred_spmd), _as_numpy(y_pred), atol=tol)

0 commit comments

Comments
 (0)