@@ -1051,20 +1051,25 @@ def check_clustering(name, clusterer_orig):
10511051 assert_in (pred .dtype , [np .dtype ('int32' ), np .dtype ('int64' )])
10521052 assert_in (pred2 .dtype , [np .dtype ('int32' ), np .dtype ('int64' )])
10531053
1054+ # Add noise to X to test the possible values of the labels
1055+ rng = np .random .RandomState (7 )
1056+ X_noise = np .concatenate ([X , rng .uniform (low = - 3 , high = 3 , size = (5 , 2 ))])
1057+ labels = clusterer .fit_predict (X_noise )
1058+
10541059 # There should be at least one sample in every cluster. Equivalently
10551060 # labels_ should contain all the consecutive values between its
10561061 # min and its max.
1057- pred_sorted = np .unique (pred )
1058- assert_array_equal (pred_sorted , np .arange (pred_sorted [0 ],
1059- pred_sorted [- 1 ] + 1 ))
1062+ labels_sorted = np .unique (labels )
1063+ assert_array_equal (labels_sorted , np .arange (labels_sorted [0 ],
1064+ labels_sorted [- 1 ] + 1 ))
10601065
1061- # labels_ should be greater than -1
1062- assert_greater_equal ( pred_sorted [0 ], - 1 )
1063- # labels_ should be less than n_clusters - 1
1066+ # Labels are expected to start at 0 (no noise) or -1 (if noise)
1067+ assert_true ( labels_sorted [0 ] in [ 0 , - 1 ] )
1068+ # Labels should be less than n_clusters - 1
10641069 if hasattr (clusterer , 'n_clusters' ):
10651070 n_clusters = getattr (clusterer , 'n_clusters' )
1066- assert_greater_equal (n_clusters - 1 , pred_sorted [- 1 ])
1067- # else labels_ should be less than max(labels_) which is necessarily true
1071+ assert_greater_equal (n_clusters - 1 , labels_sorted [- 1 ])
1072+ # else labels should be less than max(labels_) which is necessarily true
10681073
10691074
10701075@ignore_warnings (category = DeprecationWarning )
0 commit comments