-
Notifications
You must be signed in to change notification settings - Fork 2
Open
Description
I'm using ILOREM on images to explain the predictions of a model built in PyTorch.
Following the notebook's example, I defined the following segmentation function:
def segmentation_fn(image):
    """Segment *image* into superpixels with scikit-image's quickshift.

    Args:
        image: The image to segment, passed straight through to ``quickshift``.

    Returns:
        The label array produced by ``quickshift`` with ``kernel_size=5`` and
        ``max_dist=10``.
    """
    superpixel_labels = quickshift(image, kernel_size=5, max_dist=10)
    return superpixel_labels


# Generating 7 superpixels on my image.
The black-box predict function is the following:
def batch_predict(images):
    """Classify a batch of images with the global PyTorch ``model``.

    Used as ILOREM's black-box predict function, which must take a list of RGB
    images and return an array of class indices of shape (-1, 1).

    Args:
        images: Iterable of RGB images, each converted with
            ``transforms.ToTensor()``.

    Returns:
        numpy.ndarray of shape (len(images), 1) holding the predicted (argmax)
        class index for each image.
    """
    # NOTE(review): ToTensor only rescales to [0, 1] for uint8 arrays / PIL
    # images; other dtypes are converted without rescaling - confirm the
    # images passed in are uint8 RGB.
    transform = transforms.ToTensor()
    batch = torch.stack(tuple(transform(i) for i in images), dim=0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch = batch.to(device)
    # NOTE(review): `model` is a global defined elsewhere; it must already live
    # on `device` and presumably be in eval() mode - confirm.
    # Inference only: no_grad avoids building an autograd graph for the whole
    # perturbed neighbourhood, which saves a lot of memory.
    with torch.no_grad():
        logits = model(batch)  # Forward pass
    # softmax is monotonic, so the argmax of the raw logits is the same class.
    pred = torch.argmax(logits, dim=1)
    # .numpy() only accepts CPU tensors; without .cpu() this raises a TypeError
    # whenever the batch was moved to CUDA above.
    return pred.cpu().numpy().reshape(-1, 1)


# Finally I tried to explain an image following the provided example:
from externals.LOREM.ilorem import ILOREM
# Function to convert from rgb2gray since the one implemented in skimage gives problems
def rgb2gray(rgb):
    """Return a grayscale version of an RGB(A) image array.

    The first three channels are weighted by 0.299 / 0.587 / 0.114 (the
    ITU-R BT.601 luma coefficients) and summed along the last axis. Any extra
    channels, such as alpha, are ignored.
    """
    luma_weights = [0.299, 0.587, 0.114]
    colour_channels = rgb[..., :3]
    return np.dot(colour_channels, luma_weights)
# The predict function must take a list of RGB images as input and return an
# array of class indices of shape (-1, 1).

# Create the explainer.
# NOTE(review): the run below reports "synthetic neighborhood class counts
# {1: 500}" - every perturbed sample was predicted as class 1. Confirm that
# class_values lists every class the model can output, and that
# batch_predict's preprocessing matches what the model was trained on.
explainer = ILOREM(
    bb_predict=batch_predict,
    segmentation_fn=segmentation_fn,
    neigh_type='lime',
    class_name='class',
    class_values=[0, 1],
    verbose=True,
)

# NOTE(review): per the traceback, the NaN/inf reaches the decision tree
# through the `weights` passed as sample_weight, which are only used with
# use_weights=True. If they are derived from cosine distances, a perturbation
# that hides every superpixel (an all-zero vector) has an undefined cosine
# distance - presumably the source. Try metric='euclidean' or
# use_weights=False to confirm.
exp = explainer.explain_instance(
    np.array(img),
    metric='cosine',
    use_weights=True,
    num_samples=500,
)
# Which gives me the following error:
generating neighborhood - lime
synthetic neighborhood class counts {1: 500}
learning local decision tree
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-78-6011bfd26369> in <module>
16 verbose=True)
17
---> 18 exp = explainer.explain_instance(np.array(img),
19 num_samples=500,
20 use_weights=True,
~/venvs/XAI-Lib-venv/XAI-Lib/externals/LOREM/ilorem.py in explain_instance(self, img, num_samples, use_weights, metric, hide_color)
90 print('learning local decision tree')
91
---> 92 dt = learn_local_decision_tree(Z, Yb, weights, self.class_values)
93 Yc = dt.predict(Z)
94
~/venvs/XAI-Lib-venv/XAI-Lib/externals/LOREM/decision_tree.py in learn_local_decision_tree(Z, Yb, weights, class_values, multi_label, one_vs_rest, cv, prune_tree)
30 prune_duplicate_leaves(dt)
31 else:
---> 32 dt.fit(Z, Yb, sample_weight=weights)
33
34 return dt
~/.local/lib/python3.8/site-packages/sklearn/tree/_classes.py in fit(self, X, y, sample_weight, check_input, X_idx_sorted)
888 """
889
--> 890 super().fit(
891 X, y,
892 sample_weight=sample_weight,
~/.local/lib/python3.8/site-packages/sklearn/tree/_classes.py in fit(self, X, y, sample_weight, check_input, X_idx_sorted)
286
287 if sample_weight is not None:
--> 288 sample_weight = _check_sample_weight(sample_weight, X, DOUBLE)
289
290 if expanded_class_weight is not None:
~/.local/lib/python3.8/site-packages/sklearn/utils/validation.py in _check_sample_weight(sample_weight, X, dtype)
1292 if dtype is None:
1293 dtype = [np.float64, np.float32]
-> 1294 sample_weight = check_array(
1295 sample_weight, accept_sparse=False, ensure_2d=False, dtype=dtype,
1296 order="C"
~/.local/lib/python3.8/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
70 FutureWarning)
71 kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 72 return f(**kwargs)
73 return inner_f
74
~/.local/lib/python3.8/site-packages/sklearn/utils/validation.py in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator)
642
643 if force_all_finite:
--> 644 _assert_all_finite(array,
645 allow_nan=force_all_finite == 'allow-nan')
646
~/.local/lib/python3.8/site-packages/sklearn/utils/validation.py in _assert_all_finite(X, allow_nan, msg_dtype)
94 not allow_nan and not np.isfinite(X).all()):
95 type_err = 'infinity' if allow_nan else 'NaN, infinity'
---> 96 raise ValueError(
97 msg_err.format
98 (type_err,
ValueError: Input contains NaN, infinity or a value too large for dtype('float64').
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels