Skip to content

Commit 8472350

Browse files
glemaitreamueller
authored andcommitted
[MRG + 1] ENH add check_inverse in FunctionTransformer (scikit-learn#9399)
* EHN add check_inverse in FunctionTransformer * Add whats new entry and short narrative doc * Sparse support * better handle sparse data * Address andreas comments * PEP8 * Absolute tolerance default * DOC fix docstring * Remove random state and make check_inverse deterministic * FIX remove random_state from init * PEP8 * DOC motivation for the inverse * make check_inverse=True default with a warning * PEP8 * FIX get back X from check_array * Andread comments * Update whats new * remove blank line * joel s comments * no check if one of forward or inverse not provided * DOC fixes and example of filterwarnings * DOC fix warningfiltering * DOC fix merge error git
1 parent 102620f commit 8472350

File tree

4 files changed

+86
-7
lines changed

4 files changed

+86
-7
lines changed

doc/modules/preprocessing.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,15 @@ a transformer that applies a log transformation in a pipeline, do::
610610
array([[ 0. , 0.69314718],
611611
[ 1.09861229, 1.38629436]])
612612

613+
You can ensure that ``func`` and ``inverse_func`` are the inverse of each other
614+
by setting ``check_inverse=True`` and calling ``fit`` before
615+
``transform``. Please note that a warning is raised and can be turned into an
616+
error with a ``filterwarnings``::
617+
618+
>>> import warnings
619+
>>> warnings.filterwarnings("error", message=".*check_inverse*.",
620+
... category=UserWarning, append=False)
621+
613622
For a full code example that demonstrates using a :class:`FunctionTransformer`
614623
to do custom feature selection,
615624
see :ref:`sphx_glr_auto_examples_preprocessing_plot_function_transformer.py`

doc/whats_new/v0.20.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ Classifiers and regressors
4040
- Added :class:`naive_bayes.ComplementNB`, which implements the Complement
4141
Naive Bayes classifier described in Rennie et al. (2003).
4242
By :user:`Michael A. Alcorn <airalcorn2>`.
43-
43+
4444
Model evaluation
4545

4646
- Added the :func:`metrics.balanced_accuracy` metric and a corresponding
@@ -65,6 +65,11 @@ Classifiers and regressors
6565
:class:`sklearn.naive_bayes.GaussianNB` to give a precise control over
6666
variances calculation. :issue:`9681` by :user:`Dmitry Mottl <Mottl>`.
6767

68+
- A parameter ``check_inverse`` was added to :class:`FunctionTransformer`
69+
to ensure that ``func`` and ``inverse_func`` are the inverse of each
70+
other.
71+
:issue:`9399` by :user:`Guillaume Lemaitre <glemaitre>`.
72+
6873
Model evaluation and meta-estimators
6974

7075
- A scorer based on :func:`metrics.brier_score_loss` is also available.

sklearn/preprocessing/_function_transformer.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from ..base import BaseEstimator, TransformerMixin
44
from ..utils import check_array
5+
from ..utils.testing import assert_allclose_dense_sparse
56
from ..externals.six import string_types
67

78

@@ -19,8 +20,6 @@ class FunctionTransformer(BaseEstimator, TransformerMixin):
1920
function. This is useful for stateless transformations such as taking the
2021
log of frequencies, doing custom scaling, etc.
2122
22-
A FunctionTransformer will not do any checks on its function's output.
23-
2423
Note: If a lambda is used as the function, then the resulting
2524
transformer will not be pickleable.
2625
@@ -59,6 +58,13 @@ class FunctionTransformer(BaseEstimator, TransformerMixin):
5958
6059
.. deprecated::0.19
6160
61+
check_inverse : bool, default=True
62+
Whether to check that or ``func`` followed by ``inverse_func`` leads to
63+
the original inputs. It can be used for a sanity check, raising a
64+
warning when the condition is not fulfilled.
65+
66+
.. versionadded:: 0.20
67+
6268
kw_args : dict, optional
6369
Dictionary of additional keyword arguments to pass to func.
6470
@@ -67,16 +73,30 @@ class FunctionTransformer(BaseEstimator, TransformerMixin):
6773
6874
"""
6975
def __init__(self, func=None, inverse_func=None, validate=True,
70-
accept_sparse=False, pass_y='deprecated',
76+
accept_sparse=False, pass_y='deprecated', check_inverse=True,
7177
kw_args=None, inv_kw_args=None):
7278
self.func = func
7379
self.inverse_func = inverse_func
7480
self.validate = validate
7581
self.accept_sparse = accept_sparse
7682
self.pass_y = pass_y
83+
self.check_inverse = check_inverse
7784
self.kw_args = kw_args
7885
self.inv_kw_args = inv_kw_args
7986

87+
def _check_inverse_transform(self, X):
88+
"""Check that func and inverse_func are the inverse."""
89+
idx_selected = slice(None, None, max(1, X.shape[0] // 100))
90+
try:
91+
assert_allclose_dense_sparse(
92+
X[idx_selected],
93+
self.inverse_transform(self.transform(X[idx_selected])))
94+
except AssertionError:
95+
warnings.warn("The provided functions are not strictly"
96+
" inverse of each other. If you are sure you"
97+
" want to proceed regardless, set"
98+
" 'check_inverse=False'.", UserWarning)
99+
80100
def fit(self, X, y=None):
81101
"""Fit transformer by checking X.
82102
@@ -92,7 +112,10 @@ def fit(self, X, y=None):
92112
self
93113
"""
94114
if self.validate:
95-
check_array(X, self.accept_sparse)
115+
X = check_array(X, self.accept_sparse)
116+
if (self.check_inverse and not (self.func is None or
117+
self.inverse_func is None)):
118+
self._check_inverse_transform(X)
96119
return self
97120

98121
def transform(self, X, y='deprecated'):

sklearn/preprocessing/tests/test_function_transformer.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import numpy as np
2+
from scipy import sparse
23

34
from sklearn.preprocessing import FunctionTransformer
4-
from sklearn.utils.testing import assert_equal, assert_array_equal
5-
from sklearn.utils.testing import assert_warns_message
5+
from sklearn.utils.testing import (assert_equal, assert_array_equal,
6+
assert_allclose_dense_sparse)
7+
from sklearn.utils.testing import assert_warns_message, assert_no_warnings
68

79

810
def _make_func(args_store, kwargs_store, func=lambda X, *a, **k: X):
@@ -126,3 +128,43 @@ def test_inverse_transform():
126128
F.inverse_transform(F.transform(X)),
127129
np.around(np.sqrt(X), decimals=3),
128130
)
131+
132+
133+
def test_check_inverse():
134+
X_dense = np.array([1, 4, 9, 16], dtype=np.float64).reshape((2, 2))
135+
136+
X_list = [X_dense,
137+
sparse.csr_matrix(X_dense),
138+
sparse.csc_matrix(X_dense)]
139+
140+
for X in X_list:
141+
if sparse.issparse(X):
142+
accept_sparse = True
143+
else:
144+
accept_sparse = False
145+
trans = FunctionTransformer(func=np.sqrt,
146+
inverse_func=np.around,
147+
accept_sparse=accept_sparse,
148+
check_inverse=True)
149+
assert_warns_message(UserWarning,
150+
"The provided functions are not strictly"
151+
" inverse of each other. If you are sure you"
152+
" want to proceed regardless, set"
153+
" 'check_inverse=False'.",
154+
trans.fit, X)
155+
156+
trans = FunctionTransformer(func=np.expm1,
157+
inverse_func=np.log1p,
158+
accept_sparse=accept_sparse,
159+
check_inverse=True)
160+
Xt = assert_no_warnings(trans.fit_transform, X)
161+
assert_allclose_dense_sparse(X, trans.inverse_transform(Xt))
162+
163+
# check that we don't check inverse when one of the func or inverse is not
164+
# provided.
165+
trans = FunctionTransformer(func=np.expm1, inverse_func=None,
166+
check_inverse=True)
167+
assert_no_warnings(trans.fit, X_dense)
168+
trans = FunctionTransformer(func=None, inverse_func=np.expm1,
169+
check_inverse=True)
170+
assert_no_warnings(trans.fit, X_dense)

0 commit comments

Comments
 (0)