
Thursday, May 18, 2023

Using LIME with Vision Transformers (ViT)

LIME (Local Interpretable Model-agnostic Explanations) has always been billed as the "model-agnostic" way to apply explainable AI (XAI) to any sort of ML model. However, if you try to use it with a Vision Transformer, it essentially fails. In fact, several sources flat-out say it can't be done.

I've found a way to use LIME with ViTs, and I'll share it in this post.

Along with its text explainers, LIME provides functions for image explainability as well. To use LIME in your vision project, import these modules:

import numpy as np
import matplotlib.pyplot as plt
from lime import lime_image
from skimage.segmentation import mark_boundaries

Later, once your model is set up and ready, create an explainer with lime_image:

explainer = lime_image.LimeImageExplainer()

Next, call explain_instance on the explainer object. It takes a single image and a prediction function; LIME perturbs that image into many variants (num_samples of them) and passes those variants to the prediction function in batches. Here we are tempted to pass in our Vision Transformer's model.predict function directly. That is wrong. Instead, you must write a helper function that reshapes the data before calling the model for the prediction:

explanation = explainer.explain_instance(image_list.astype('double'),
                                         pred_fn,
                                         top_labels=3, hide_color=0, num_samples=1000)

The first parameter is a numpy array of the image I want analyzed. *Important*: it must not contain the extra batch dimension that is usually added before calling predict. That dimension is added later, in the helper function.
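If your preprocessing pipeline already adds the batch dimension, it is easy to pass a (1, H, W, C) array by mistake. As a minimal sketch, assuming the image is a numpy array in H x W x C layout, a small helper can strip that leading dimension before the array goes to explain_instance. The name to_lime_image is hypothetical (my own, not part of LIME):

```python
import numpy as np

def to_lime_image(image):
    """Hypothetical helper: return one H x W x C float64 image for explain_instance."""
    arr = np.asarray(image, dtype='double')
    # Drop a leading batch dimension of size 1; pred_fn adds it back per image
    if arr.ndim == 4 and arr.shape[0] == 1:
        arr = arr[0]
    # This sketch assumes RGB input (LIME itself also accepts 2-D grayscale)
    if arr.ndim != 3:
        raise ValueError(f"expected one H x W x C image, got shape {arr.shape}")
    return arr

single = np.zeros((224, 224, 3))
batched = np.expand_dims(single, axis=0)  # the shape pred_fn builds
print(batched.shape)                 # (1, 224, 224, 3)
print(to_lime_image(batched).shape)  # (224, 224, 3)
```

A batch with more than one image still raises an error, since explain_instance explains exactly one image per call.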

The second parameter is the helper function itself, pred_fn. This function lets me add the extra dimension a Vision Transformer needs before each prediction.
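To see why the helper must handle many images per call, here is a simplified, hypothetical re-creation of LIME's perturbation loop in plain numpy. It is not LIME's actual source, and simulate_lime_calls and fake_classifier are my own illustrative names. It switches random superpixels off, fills them with hide_color, and sends the perturbed copies to the prediction function in batches (LIME's default batch_size for explain_instance is 10):

```python
import numpy as np

def simulate_lime_calls(image, segments, classifier_fn, num_samples=20,
                        batch_size=10, hide_color=0, seed=0):
    """Illustrative sketch of how explain_instance feeds classifier_fn.

    image:    H x W x C array (no batch dimension)
    segments: H x W integer superpixel labels
    """
    rng = np.random.default_rng(seed)
    n_features = int(segments.max()) + 1
    # Each row switches superpixels on (1) or off (0); row 0 keeps the original
    data = rng.integers(0, 2, size=(num_samples, n_features))
    data[0, :] = 1
    fudged = np.full_like(image, hide_color)  # fill for switched-off pixels
    preds, batch = [], []
    for row in data:
        temp = image.copy()
        off = np.isin(segments, np.where(row == 0)[0])
        temp[off] = fudged[off]
        batch.append(temp)
        if len(batch) == batch_size:
            # classifier_fn always receives a 4-D array: (batch, H, W, C)
            preds.extend(classifier_fn(np.array(batch)))
            batch = []
    if batch:
        preds.extend(classifier_fn(np.array(batch)))
    return data, np.array(preds)

seen_shapes = []

def fake_classifier(imgs):
    # Records the input shape and returns constant scores for 3 classes
    seen_shapes.append(imgs.shape)
    return np.tile([0.7, 0.2, 0.1], (len(imgs), 1))

img = np.random.default_rng(1).random((8, 8, 3))
segs = np.repeat(np.arange(4), 16).reshape(8, 8)  # 4 horizontal-band superpixels
data, preds = simulate_lime_calls(img, segs, fake_classifier, num_samples=25)
print(seen_shapes)  # [(10, 8, 8, 3), (10, 8, 8, 3), (5, 8, 8, 3)]
print(preds.shape)  # (25, 3)
```

The prediction function never sees a single 3-D image here, only 4-D batches, and it must return one row of scores per image in the batch.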

Here is my implementation of pred_fn. It accepts the batch of perturbed images LIME sends and returns one row of predictions per image:

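def pred_fn(imgs):
    tot_probs = []
    for img in imgs:
        # Add the batch dimension the ViT expects: (H, W, C) -> (1, H, W, C)
        exp_img = np.expand_dims(img, axis=0)
        # Make the prediction; the model returns two outputs, keep the first
        img_pred, _ = vit_model.predict(exp_img)
        # img_pred has shape (1, num_classes); keep the single row for LIME
        tot_probs.append(img_pred[0])
    # LIME expects one row of class scores per input image
    return np.array(tot_probs)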

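In my setup, vit_model.predict returns two outputs, and pred_fn keeps only the first, the class predictions. That output is 2-D, shaped (1, number_of_classes), so img_pred[0] takes the single row of predictions for this image and adds it to the return list.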
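Because LIME already hands pred_fn a 4-D batch, a per-image loop may not be strictly required if your model's predict accepts batches of any size. The sketch below compares the loop from this post with a single batched call. StandInViT is a hypothetical toy model whose predict mimics the two-output signature of vit_model; it is not a real Vision Transformer. If you try the batched variant with your own model, check that both versions agree first:

```python
import numpy as np

class StandInViT:
    """Hypothetical stand-in for vit_model: predict() takes an (N, H, W, C)
    batch and, like the model in this post, returns two outputs."""
    def __init__(self, num_classes=5, seed=0):
        rng = np.random.default_rng(seed)
        self.w = rng.normal(size=num_classes)

    def predict(self, batch):
        # Toy logits from each image's mean pixel value, then a per-image softmax
        logits = batch.mean(axis=(1, 2, 3))[:, None] * self.w
        probs = np.exp(logits - logits.max(axis=1, keepdims=True))
        probs /= probs.sum(axis=1, keepdims=True)
        return probs, None  # second output is ignored, as in pred_fn

vit_model = StandInViT()

def pred_fn_loop(imgs):
    # The post's approach: one predict() call per image
    return np.array([vit_model.predict(np.expand_dims(img, axis=0))[0][0]
                     for img in imgs])

def pred_fn_batched(imgs):
    # Alternative: predict on LIME's whole batch at once
    probs, _ = vit_model.predict(np.asarray(imgs))
    return probs

batch = np.random.default_rng(1).random((10, 16, 16, 3))
print(np.allclose(pred_fn_loop(batch), pred_fn_batched(batch)))  # True
```

The loop is the safer choice when a model's predict only works with a batch size of one; the batched call makes far fewer predict calls per explanation.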

Finally, the explanation can be used to mask off the areas of the original image that are most salient to the model's output.

temp_1, mask_1 = explanation.get_image_and_mask(
    explanation.top_labels[0], positive_only=True,
    num_features=3, hide_rest=True)
temp_2, mask_2 = explanation.get_image_and_mask(
    explanation.top_labels[0], positive_only=False,
    num_features=3, hide_rest=False)
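To make positive_only, num_features, and hide_rest more concrete, here is a simplified numpy re-creation of what get_image_and_mask returns. image_and_mask and its weights list are my own illustrative stand-ins for the explanation's superpixels and per-label weights (explanation.segments and explanation.local_exp[label] in LIME). The library version also applies a min_weight threshold and tints regions when positive_only=False, which this sketch skips:

```python
import numpy as np

def image_and_mask(image, segments, weights, positive_only=True,
                   num_features=3, hide_rest=False):
    """Simplified, illustrative version of get_image_and_mask.

    weights: (superpixel_id, weight) pairs sorted by |weight|, descending.
    """
    mask = np.zeros(segments.shape, dtype=int)
    # hide_rest=True starts from a black canvas, otherwise from the image
    temp = np.zeros(image.shape) if hide_rest else image.copy()
    if positive_only:
        # Only superpixels that push toward the label
        chosen = [(f, w) for f, w in weights if w > 0][:num_features]
    else:
        # Strongest superpixels in either direction
        chosen = weights[:num_features]
    for f, w in chosen:
        region = segments == f
        temp[region] = image[region]        # reveal this superpixel
        mask[region] = 1 if w > 0 else -1   # -1 marks evidence against the label
    return temp, mask

img = np.ones((4, 4, 3))
segs = np.array([[0, 0, 1, 1],
                 [0, 0, 1, 1],
                 [2, 2, 3, 3],
                 [2, 2, 3, 3]])
weights = [(3, 0.5), (1, -0.4), (0, 0.2), (2, 0.1)]

t1, m1 = image_and_mask(img, segs, weights, positive_only=True,
                        num_features=2, hide_rest=True)
t2, m2 = image_and_mask(img, segs, weights, positive_only=False,
                        num_features=2, hide_rest=False)
print(m1.tolist())  # [[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]
print(m2.tolist())  # [[0, 0, -1, -1], [0, 0, -1, -1], [0, 0, 1, 1], [0, 0, 1, 1]]
```

With positive_only=False, the second strongest superpixel (weight -0.4) is selected even though it argues against the label, which is why the second image below can outline more than just supporting evidence.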

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15,15))
ax1.imshow(mark_boundaries(temp_1, mask_1))
ax2.imshow(mark_boundaries(temp_2, mask_2))
ax1.axis('off')
ax2.axis('off')
ax1.set_title('')
ax2.set_title('')
fig.tight_layout()

The process outputs these beautiful image explanations:

Original image: the model correctly labeled it "golden_retriever".
The first masked image shows the areas most salient to the decision. The black background is there because the parameter "hide_rest" is set to True.

The second masked image again shows the areas most salient to the decision, but this time in context with the rest of the image. Because "positive_only" is False here, the outlined regions can also include areas that counted against the label.

I hope this is helpful in your Vision Transformer explanations!