Skip to content

Commit 415fd83

Browse files
thibsej authored and GaelVaroquaux committed
[MRG+2] Add float32 support for Linear Discriminant Analysis (scikit-learn#13273)
* [skip ci] Empty commit to trigger PR
* Add dtype testing
* Fix: dtype testing
* Fix test_estimators[OneVsRestClassifier-check_estimators_dtypes]
* TST refactor using parametrize + Add failing test for int32
* Fix for int32
* Fix code according to review + Fix PEP8 violation
* Fix dtype for int32 and complex
* Fix pep8 violation
* Update whatsnew + test COSMIT
1 parent b65afbf commit 415fd83

File tree

3 files changed

+35
-3
lines changed

3 files changed

+35
-3
lines changed

doc/whats_new/v0.21.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ Support for Python 3.4 and below has been officially dropped.
9393
:mod:`sklearn.discriminant_analysis`
9494
....................................
9595

96+
- |Enhancement| :class:`discriminant_analysis.LinearDiscriminantAnalysis` now
97+
preserves ``float32`` and ``float64`` dtypes. :issue:`8769` and
98+
:issue:`11000` by :user:`Thibault Sejourne <thibsej>`.
99+
96100
- |Fix| A ``ChangedBehaviourWarning`` is now raised when
97101
:class:`discriminant_analysis.LinearDiscriminantAnalysis` is given as
98102
parameter ``n_components > min(n_features, n_classes - 1)``, and

sklearn/discriminant_analysis.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,8 @@ def fit(self, X, y):
427427
Target values.
428428
"""
429429
# FIXME: Future warning to be removed in 0.23
430-
X, y = check_X_y(X, y, ensure_min_samples=2, estimator=self)
430+
X, y = check_X_y(X, y, ensure_min_samples=2, estimator=self,
431+
dtype=[np.float64, np.float32])
431432
self.classes_ = unique_labels(y)
432433
n_samples, _ = X.shape
433434
n_classes = len(self.classes_)
@@ -485,9 +486,10 @@ def fit(self, X, y):
485486
raise ValueError("unknown solver {} (valid solvers are 'svd', "
486487
"'lsqr', and 'eigen').".format(self.solver))
487488
if self.classes_.size == 2: # treat binary case as a special case
488-
self.coef_ = np.array(self.coef_[1, :] - self.coef_[0, :], ndmin=2)
489+
self.coef_ = np.array(self.coef_[1, :] - self.coef_[0, :], ndmin=2,
490+
dtype=X.dtype)
489491
self.intercept_ = np.array(self.intercept_[1] - self.intercept_[0],
490-
ndmin=1)
492+
ndmin=1, dtype=X.dtype)
491493
return self
492494

493495
def transform(self, X):

sklearn/tests/test_discriminant_analysis.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sklearn.utils.testing import (assert_array_equal, assert_no_warnings,
88
assert_warns_message)
99
from sklearn.utils.testing import assert_array_almost_equal
10+
from sklearn.utils.testing import assert_allclose
1011
from sklearn.utils.testing import assert_equal
1112
from sklearn.utils.testing import assert_almost_equal
1213
from sklearn.utils.testing import assert_raises
@@ -296,6 +297,31 @@ def test_lda_dimension_warning(n_classes, n_features):
296297
assert_warns_message(FutureWarning, future_msg, lda.fit, X, y)
297298

298299

300+
@pytest.mark.parametrize("data_type, expected_type", [
301+
(np.float32, np.float32),
302+
(np.float64, np.float64),
303+
(np.int32, np.float64),
304+
(np.int64, np.float64)
305+
])
306+
def test_lda_dtype_match(data_type, expected_type):
307+
for (solver, shrinkage) in solver_shrinkage:
308+
clf = LinearDiscriminantAnalysis(solver=solver, shrinkage=shrinkage)
309+
clf.fit(X.astype(data_type), y.astype(data_type))
310+
assert clf.coef_.dtype == expected_type
311+
312+
313+
def test_lda_numeric_consistency_float32_float64():
314+
for (solver, shrinkage) in solver_shrinkage:
315+
clf_32 = LinearDiscriminantAnalysis(solver=solver, shrinkage=shrinkage)
316+
clf_32.fit(X.astype(np.float32), y.astype(np.float32))
317+
clf_64 = LinearDiscriminantAnalysis(solver=solver, shrinkage=shrinkage)
318+
clf_64.fit(X.astype(np.float64), y.astype(np.float64))
319+
320+
# Check value consistency between types
321+
rtol = 1e-6
322+
assert_allclose(clf_32.coef_, clf_64.coef_, rtol=rtol)
323+
324+
299325
def test_qda():
300326
# QDA classification.
301327
# This checks that QDA implements fit and predict and returns

0 commit comments

Comments
 (0)