Skip to content

ILOREM - ValueError: Input contains NaN, infinity or a value too large for dtype('float64'). #8

@CeciPani

Description

@CeciPani

I'm using ILOREM on images to explain the predictions of a PyTorch model.

Following the notebook's example, I defined the following segmentation function:

def segmentation_fn(image):
    """Split *image* into superpixels using skimage's quickshift.

    The segmentation parameters are fixed here (kernel_size=5, max_dist=10);
    the label map returned by ``quickshift`` is passed straight back to the
    caller (ILOREM), which uses it to define the superpixels it perturbs.
    """
    segments = quickshift(image, max_dist=10, kernel_size=5)
    return segments

This produces 7 superpixels on my image.

The black-box predict function is the following:

def batch_predict(images):
    """Classify a batch of RGB images with the PyTorch ``model``.

    Parameters
    ----------
    images : iterable of H x W x C arrays (or PIL images)
        The images to classify; each one is converted with
        ``transforms.ToTensor`` and the results are stacked into one batch.

    Returns
    -------
    numpy.ndarray
        Predicted class indices (argmax over the model output), reshaped to
        ``(n_images, 1)`` as ILOREM expects.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ToTensor turns H x W x C into C x H x W. It rescales to [0, 1] ONLY for
    # PIL images or uint8 arrays; other dtypes are passed through unscaled.
    # NOTE(review): confirm the dtype of the images ILOREM passes in.
    transform = transforms.ToTensor()
    batch = torch.stack([transform(image) for image in images], dim=0)
    batch = batch.to(device)

    # Inference only: skip building the autograd graph.
    # NOTE(review): assumes `model` is already in eval mode — confirm.
    with torch.no_grad():
        logits = model(batch)

    # softmax is monotonic, so the argmax of the raw logits is the same class.
    pred = torch.argmax(logits, dim=1)

    # Tensor.numpy() refuses CUDA tensors, so bring the result back to the CPU.
    return pred.cpu().numpy().reshape(-1, 1)

Finally, I tried to explain an image following the provided example:

from externals.LOREM.ilorem import ILOREM

# Convert RGB to grayscale by hand, because skimage's rgb2gray caused problems here
def rgb2gray(rgb):
    """Return a luma-weighted grayscale version of an RGB(A) array.

    Only the first three channels of the last axis are used (an alpha
    channel, if present, is ignored). They are combined with the weights
    0.299 / 0.587 / 0.114, and the result has the input's shape minus the
    channel axis.
    """
    luma_weights = np.array([0.299, 0.587, 0.114])
    color_channels = rgb[..., :3]
    return color_channels @ luma_weights

# The predict function must take a list of RGB images as input
# and return an array of class indices with shape (-1, 1).

# Wrap the black-box predictor in an ILOREM image explainer; neigh_type='lime'
# selects the LIME-style neighbourhood generator over the superpixels
# produced by segmentation_fn.
explainer = ILOREM(
    bb_predict=batch_predict,
    segmentation_fn=segmentation_fn,
    neigh_type='lime',
    class_name='class',
    class_values=[0, 1],
    verbose=True,
)

# NOTE(review): the verbose log reports every synthetic sample classified as
# 1 ({1: 500}), and the tree fit then fails on non-finite sample weights.
# Unconfirmed hypothesis (ILOREM's weighting code is not shown here): with
# metric='cosine', a neighbour with every superpixel switched off is an
# all-zero vector whose cosine distance is undefined; with only 7 superpixels
# such rows are likely among 500 random samples. To check, retry with
# metric='euclidean' or use_weights=False.
exp = explainer.explain_instance(
    np.array(img),
    num_samples=500,
    use_weights=True,
    metric='cosine',
)

This gives me the following error:

generating neighborhood - lime
synthetic neighborhood class counts {1: 500}
learning local decision tree

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-78-6011bfd26369> in <module>
     16                    verbose=True)
     17 
---> 18 exp = explainer.explain_instance(np.array(img), 
     19                                  num_samples=500,
     20                                  use_weights=True,

~/venvs/XAI-Lib-venv/XAI-Lib/externals/LOREM/ilorem.py in explain_instance(self, img, num_samples, use_weights, metric, hide_color)
     90             print('learning local decision tree')
     91 
---> 92         dt = learn_local_decision_tree(Z, Yb, weights, self.class_values)
     93         Yc = dt.predict(Z)
     94 

~/venvs/XAI-Lib-venv/XAI-Lib/externals/LOREM/decision_tree.py in learn_local_decision_tree(Z, Yb, weights, class_values, multi_label, one_vs_rest, cv, prune_tree)
     30         prune_duplicate_leaves(dt)
     31     else:
---> 32         dt.fit(Z, Yb, sample_weight=weights)
     33 
     34     return dt

~/.local/lib/python3.8/site-packages/sklearn/tree/_classes.py in fit(self, X, y, sample_weight, check_input, X_idx_sorted)
    888         """
    889 
--> 890         super().fit(
    891             X, y,
    892             sample_weight=sample_weight,

~/.local/lib/python3.8/site-packages/sklearn/tree/_classes.py in fit(self, X, y, sample_weight, check_input, X_idx_sorted)
    286 
    287         if sample_weight is not None:
--> 288             sample_weight = _check_sample_weight(sample_weight, X, DOUBLE)
    289 
    290         if expanded_class_weight is not None:

~/.local/lib/python3.8/site-packages/sklearn/utils/validation.py in _check_sample_weight(sample_weight, X, dtype)
   1292         if dtype is None:
   1293             dtype = [np.float64, np.float32]
-> 1294         sample_weight = check_array(
   1295             sample_weight, accept_sparse=False, ensure_2d=False, dtype=dtype,
   1296             order="C"

~/.local/lib/python3.8/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
     70                           FutureWarning)
     71         kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 72         return f(**kwargs)
     73     return inner_f
     74 

~/.local/lib/python3.8/site-packages/sklearn/utils/validation.py in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator)
    642 
    643         if force_all_finite:
--> 644             _assert_all_finite(array,
    645                                allow_nan=force_all_finite == 'allow-nan')
    646 

~/.local/lib/python3.8/site-packages/sklearn/utils/validation.py in _assert_all_finite(X, allow_nan, msg_dtype)
     94                 not allow_nan and not np.isfinite(X).all()):
     95             type_err = 'infinity' if allow_nan else 'NaN, infinity'
---> 96             raise ValueError(
     97                     msg_err.format
     98                     (type_err,

ValueError: Input contains NaN, infinity or a value too large for dtype('float64').

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions