Skip to content

Commit 2d7567f

Browse files
massichamueller
authored and committed
Rework warning in check_array when silent convert string to float (scikit-learn#11577)
1 parent 4b3e49f commit 2d7567f

File tree

2 files changed

+14
-35
lines changed

2 files changed

+14
-35
lines changed

sklearn/utils/tests/test_validation.py

Lines changed: 8 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -291,40 +291,17 @@ def test_check_array():
291291
assert_true(isinstance(result, np.ndarray))
292292

293293
# deprecation warning if string-like array with dtype="numeric"
294-
X_str = [['a', 'b'], ['c', 'd']]
295-
assert_warns_message(
296-
FutureWarning,
297-
"arrays of strings will be interpreted as decimal numbers if "
298-
"parameter 'dtype' is 'numeric'. It is recommended that you convert "
299-
"the array to type np.float64 before passing it to check_array.",
300-
check_array, X_str, "numeric")
301-
assert_warns_message(
302-
FutureWarning,
303-
"arrays of strings will be interpreted as decimal numbers if "
304-
"parameter 'dtype' is 'numeric'. It is recommended that you convert "
305-
"the array to type np.float64 before passing it to check_array.",
306-
check_array, np.array(X_str, dtype='U'), "numeric")
307-
assert_warns_message(
308-
FutureWarning,
309-
"arrays of strings will be interpreted as decimal numbers if "
310-
"parameter 'dtype' is 'numeric'. It is recommended that you convert "
311-
"the array to type np.float64 before passing it to check_array.",
312-
check_array, np.array(X_str, dtype='S'), "numeric")
294+
expected_warn_regex = r"converted to decimal numbers if dtype='numeric'"
295+
X_str = [['11', '12'], ['13', 'xx']]
296+
for X in [X_str, np.array(X_str, dtype='U'), np.array(X_str, dtype='S')]:
297+
with pytest.warns(FutureWarning, match=expected_warn_regex):
298+
check_array(X, dtype="numeric")
313299

314300
# deprecation warning if byte-like array with dtype="numeric"
315301
X_bytes = [[b'a', b'b'], [b'c', b'd']]
316-
assert_warns_message(
317-
FutureWarning,
318-
"arrays of strings will be interpreted as decimal numbers if "
319-
"parameter 'dtype' is 'numeric'. It is recommended that you convert "
320-
"the array to type np.float64 before passing it to check_array.",
321-
check_array, X_bytes, "numeric")
322-
assert_warns_message(
323-
FutureWarning,
324-
"arrays of strings will be interpreted as decimal numbers if "
325-
"parameter 'dtype' is 'numeric'. It is recommended that you convert "
326-
"the array to type np.float64 before passing it to check_array.",
327-
check_array, np.array(X_bytes, dtype='V1'), "numeric")
302+
for X in [X_bytes, np.array(X_bytes, dtype='V1')]:
303+
with pytest.warns(FutureWarning, match=expected_warn_regex):
304+
check_array(X, dtype="numeric")
328305

329306

330307
def test_check_array_pandas_dtype_object_conversion():

sklearn/utils/validation.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -546,10 +546,12 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
546546
# in the future np.flexible dtypes will be handled like object dtypes
547547
if dtype_numeric and np.issubdtype(array.dtype, np.flexible):
548548
warnings.warn(
549-
"Beginning in version 0.22, arrays of strings will be "
550-
"interpreted as decimal numbers if parameter 'dtype' is "
551-
"'numeric'. It is recommended that you convert the array to "
552-
"type np.float64 before passing it to check_array.",
549+
"Beginning in version 0.22, arrays of bytes/strings will be "
550+
"converted to decimal numbers if dtype='numeric'. "
551+
"It is recommended that you convert the array to "
552+
"a float dtype before using it in scikit-learn, "
553+
"for example by using "
554+
"your_array = your_array.astype(np.float64).",
553555
FutureWarning)
554556

555557
# make sure we actually converted to numeric:

0 commit comments

Comments
 (0)