Commit 3e715fd

brentfaganrth authored and committed
FIX Raise warning in sklearn/linear_model/cd_fast.pyx for cases when the main loop exits without reaching the desired tolerance (scikit-learn#11754)
1 parent 415fd83 commit 3e715fd

File tree

4 files changed: +58 -13 lines
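The commit moves the convergence check out of the Python wrappers and into the Cython solvers, using Python's for/else construct: the else suite runs only when the loop finishes without hitting break, i.e. when max_iter is exhausted before the duality gap drops below the tolerance. A minimal sketch of that pattern in plain Python (toy_solver and its constants are hypothetical, for illustration only):

import warnings

def toy_solver(max_iter, tol):
    gap = 1.0
    for n_iter in range(max_iter):
        gap *= 0.5  # stand-in for one coordinate-descent sweep
        if gap < tol:
            break  # converged: the else suite below is skipped
    else:
        # runs only when the loop exhausts max_iter without a break
        warnings.warn("Objective did not converge."
                      " Duality gap: {}, tolerance: {}".format(gap, tol))
    return gap, n_iter + 1

toy_solver(max_iter=3, tol=1e-6)   # warns: gap is still 0.125
toy_solver(max_iter=30, tol=1e-6)  # silent: breaks out once gap < tol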

sklearn/linear_model/cd_fast.pyx

Lines changed: 29 additions & 0 deletions
@@ -15,6 +15,7 @@ cimport cython
 from cpython cimport bool
 from cython cimport floating
 import warnings
+from ..exceptions import ConvergenceWarning

 from ..utils._cython_blas cimport (_axpy, _dot, _asum, _ger, _gemv, _nrm2,
                                    _copy, _scal)

@@ -246,6 +247,14 @@ def enet_coordinate_descent(floating[::1] w,
                 if gap < tol:
                     # return if we reached desired tolerance
                     break
+
+        else:
+            with gil:
+                warnings.warn("Objective did not converge."
+                              " You might want to increase the number of iterations."
+                              " Duality gap: {}, tolerance: {}".format(gap, tol),
+                              ConvergenceWarning)
+
     return w, gap, tol, n_iter + 1

@@ -456,6 +465,13 @@ def sparse_enet_coordinate_descent(floating [::1] w,
                     # return if we reached desired tolerance
                     break

+        else:
+            with gil:
+                warnings.warn("Objective did not converge."
+                              " You might want to increase the number of iterations."
+                              " Duality gap: {}, tolerance: {}".format(gap, tol),
+                              ConvergenceWarning)
+
     return w, gap, tol, n_iter + 1

@@ -604,6 +620,13 @@ def enet_coordinate_descent_gram(floating[::1] w,
                     # return if we reached desired tolerance
                     break

+
+        with gil:
+            warnings.warn("Objective did not converge."
+                          " You might want to increase the number of iterations."
+                          " Duality gap: {}, tolerance: {}".format(gap, tol),
+                          ConvergenceWarning)
+
     return np.asarray(w), gap, tol, n_iter + 1

@@ -794,5 +817,11 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
                 if gap < tol:
                     # return if we reached desired tolerance
                     break
+        else:
+            with gil:
+                warnings.warn("Objective did not converge."
+                              " You might want to increase the number of iterations."
+                              " Duality gap: {}, tolerance: {}".format(gap, tol),
+                              ConvergenceWarning)

     return np.asarray(W), gap, tol, n_iter + 1
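With the warning raised inside the solver, every caller of these four descent routines inherits it, including estimators fit on dense or sparse input. A hedged end-user sketch of the observable effect (the random data and the tiny alpha/max_iter values are arbitrary choices made here to force non-convergence, not part of this commit):

import warnings
import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import Lasso

rng = np.random.RandomState(0)
X = rng.randn(50, 20)
y = rng.randn(50)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    Lasso(alpha=1e-10, max_iter=1).fit(X, y)  # one sweep cannot close the gap

assert any(issubclass(w.category, ConvergenceWarning) for w in caught)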

sklearn/linear_model/coordinate_descent.py

Lines changed: 0 additions & 13 deletions
@@ -23,7 +23,6 @@
 from ..utils.fixes import _joblib_parallel_args
 from ..utils.validation import check_is_fitted
 from ..utils.validation import column_or_1d
-from ..exceptions import ConvergenceWarning

 from . import cd_fast

@@ -481,13 +480,6 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
         coefs[..., i] = coef_
         dual_gaps[i] = dual_gap_
         n_iters.append(n_iter_)
-        if dual_gap_ > eps_:
-            warnings.warn('Objective did not converge.' +
-                          ' You might want' +
-                          ' to increase the number of iterations.' +
-                          ' Fitting data with very small alpha' +
-                          ' may cause precision problems.',
-                          ConvergenceWarning)

         if verbose:
             if verbose > 2:
@@ -1812,11 +1804,6 @@ def fit(self, X, y):

         self._set_intercept(X_offset, y_offset, X_scale)

-        if self.dual_gap_ > self.eps_:
-            warnings.warn('Objective did not converge, you might want'
-                          ' to increase the number of iterations',
-                          ConvergenceWarning)
-
         # return self for chaining fit and predict calls
         return self
sklearn/linear_model/tests/test_coordinate_descent.py

Lines changed: 17 additions & 0 deletions
@@ -828,3 +828,20 @@ def test_warm_start_multitask_lasso():
     clf2 = MultiTaskLasso(alpha=0.1, max_iter=10)
     ignore_warnings(clf2.fit)(X, Y)
     assert_array_almost_equal(clf2.coef_, clf.coef_)
+
+
+@pytest.mark.parametrize('klass, n_classes, kwargs',
+                         [(Lasso, 1, dict(precompute=True)),
+                          (Lasso, 1, dict(precompute=False)),
+                          (MultiTaskLasso, 2, dict()),
+                          (MultiTaskLasso, 2, dict())])
+def test_enet_coordinate_descent(klass, n_classes, kwargs):
+    """Test that a warning is issued if model does not converge"""
+    clf = klass(max_iter=2, **kwargs)
+    n_samples = 5
+    n_features = 2
+    X = np.ones((n_samples, n_features)) * 1e50
+    y = np.ones((n_samples, n_classes))
+    if klass == Lasso:
+        y = y.ravel()
+    assert_warns(ConvergenceWarning, clf.fit, X, y)
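The test forces non-convergence by scaling the design matrix to 1e50 and capping max_iter at 2. The same check written with pytest.warns, assuming a reasonably recent pytest; this is a sketch, not the repo's assert_warns helper:

import numpy as np
import pytest
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import Lasso

def test_lasso_convergence_warning():
    X = np.ones((5, 2)) * 1e50  # extreme scale keeps the duality gap large
    y = np.ones(5)
    with pytest.warns(ConvergenceWarning):
        Lasso(max_iter=2).fit(X, y)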

sklearn/linear_model/tests/test_sparse_coordinate_descent.py

Lines changed: 12 additions & 0 deletions
@@ -8,6 +8,8 @@

 from sklearn.utils.testing import assert_greater
 from sklearn.utils.testing import ignore_warnings
+from sklearn.utils.testing import assert_warns
+from sklearn.exceptions import ConvergenceWarning

 from sklearn.linear_model.coordinate_descent import (Lasso, ElasticNet,
                                                      LassoCV, ElasticNetCV)

@@ -290,3 +292,13 @@ def test_same_multiple_output_sparse_dense():
     predict_sparse = l_sp.predict(sample_sparse)

     assert_array_almost_equal(predict_sparse, predict_dense)
+
+
+def test_sparse_enet_coordinate_descent():
+    """Test that a warning is issued if model does not converge"""
+    clf = Lasso(max_iter=2)
+    n_samples = 5
+    n_features = 2
+    X = sp.csc_matrix((n_samples, n_features)) * 1e50
+    y = np.ones(n_samples)
+    assert_warns(ConvergenceWarning, clf.fit, X, y)
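Passing a scipy CSC matrix routes the fit through sparse_enet_coordinate_descent. A hedged stand-alone variant using a dense block of extreme values converted to CSC (a slight variation on the all-zeros matrix the test constructs, and reusing the same testing helper this suite imports):

import numpy as np
import scipy.sparse as sp
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import Lasso
from sklearn.utils.testing import assert_warns  # the helper this test suite uses

X = sp.csc_matrix(np.ones((5, 2)) * 1e50)  # sparse input selects the sparse solver
y = np.ones(5)
assert_warns(ConvergenceWarning, Lasso(max_iter=2).fit, X, y)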
