Skip to content

Commit 937b209

Browse files
better fix without validating data twice
1 parent 0632be3 commit 937b209

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

sklearnex/neighbors/common.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030

3131
from .._utils import PatchingConditionsChain
3232
from ..base import oneDALEstimator
33-
from ..utils._array_api import get_namespace
3433
from ..utils.validation import check_feature_names
3534

3635

@@ -68,7 +67,10 @@ def _fit_validation(self, X, y=None):
6867

6968
if not isinstance(X, (KDTree, BallTree, _sklearn_NeighborsBase)):
7069
self._fit_X = _check_array(
71-
X, dtype=[np.float64, np.float32], accept_sparse=True
70+
X,
71+
dtype=[np.float64, np.float32],
72+
accept_sparse=True,
73+
force_all_finite=not self.effective_metric_.startswith("nan"),
7274
)
7375
self.n_samples_fit_ = _num_samples(self._fit_X)
7476
self.n_features_in_ = _num_features(self._fit_X)

0 commit comments

Comments
 (0)