Explain a prediction using LIME


This article explains LIME in an extremely simple manner. Basic knowledge of ML and NLP is needed to understand it. In addition to the explanation, I have implemented LIME in an accompanying Colab notebook. Please refrain from directly copying content from the post. If you want to use an excerpt or any original image, include a link to the post.


Consider this. You are on a team of data scientists developing a recommendation system for targeted ads. You collect people's data from various sources and train a model to target ads at people who are more likely to buy your products. The business objective is clear: earn more money. As a data scientist, you have your own set of metrics to validate the model's performance.

While experimenting, you find that a 100-million-parameter state-of-the-art (SOTA) model gives impressive results. Everyone is happy. It works well in production too!

What if you somehow find that the model targets people who have bipolar disorder, are entering a manic phase, and are therefore about to go on a spending spree?

Credits: Google Images


Consider another scenario. You have to classify images as pandas or polar bears. You'd expect the model to learn panda-specific features such as black fur. Your trained model achieves the best accuracy and the best F1 score you have ever seen.

Credits: Google Images



However, your model misclassifies this photo:

Credits: Google Images


Most of your panda pictures happen to have a forest background, while your polar bear pictures have snow. The model saw snow and classified the image as a polar bear. In short, you just built an amazing snow detector, not an animal classifier.

What's the point?

As a data scientist, you probably can't inspect every aspect of your data, and modern machine learning models are not easily interpretable. Try asking your model, "Why should I trust you?"

Explanation vs Interpretation

Model explanation and model interpretation are two different concepts. You can directly interpret the weights of a linear or logistic regression model. You can't do the same with the weights of a heavy neural network.
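To see why a linear model counts as interpretable, here is a toy logistic regression sentiment model. Its weights and bias are hypothetical numbers made up for illustration, not learned from any data; the point is only that each weight can be read off directly as the change in log-odds every time its word appears.

```python
import math

# Hypothetical weights of an already-trained logistic regression sentiment
# model over a tiny vocabulary (made-up values, for illustration only).
weights = {"amazing": 2.1, "enjoyed": 1.7, "boring": -2.4, "place": 0.1}
bias = -0.3


def predict_positive(words):
    """Probability that a sentence (given as a list of words) is positive."""
    # Each occurrence of a word adds its weight to the log-odds.
    score = bias + sum(weights.get(w, 0.0) for w in words)
    return 1.0 / (1.0 + math.exp(-score))


# Interpretation: sorting the weights shows which words the model relies on
# and in which direction they push the prediction.
for word, w in sorted(weights.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{word:>8}: {w:+.1f}")
```

A neural network offers no such one-number-per-word reading, which is why it needs an explainer instead.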

By explanation, we mean presenting textual or visual artifacts that provide a qualitative understanding of the relationship between an instance's components (words in text or patches in images) and the model's prediction.

For example, if the model predicts the sentiment of the sentence

Amazing place to visit. Enjoyed the rides.

as positive, the explanation should highlight the words Amazing and Enjoyed.

What should an ideal explainer look like?

Explanation becomes tough in high-dimensional spaces. In images, for example, it is hard to examine each and every pixel, and in documents with thousands of words, it is hard to examine each and every word. The explainer should also be model agnostic, because it is impractical to design a separate explainer for every model. Lastly, the explainer should be faithful. This means that if we see an absurd explanation for an instance, we should trust the explainer enough to conclude that the model is at fault.

LIME - Local Interpretable Model-agnostic Explanations

Mouthful? I know!

LIME is a framework that generates visual cues explaining a model's predictions. We will look at LIME through 3 ideas. By the end, you will have enough insight to use LIME yourself.

Idea 1: Train a simple model to explain your model

Let's say you want to explain the predictions of your classifier, which we will call the main model (an image or text classifier, for instance). The main model is hard to interpret; it is probably a heavy neural network. We want to train another model, called the explainer model, that is highly interpretable (a linear or tree-based model). We then use the explainer model to explain the main model's prediction.

Original Image


Now pay attention to the specifics. You have a dataset with inputs [X1, X2, …, Xn], and you want to explain your main model's prediction for a particular point Xk. LIME trains an explainer model specific to Xk. This means the same explainer model won't explain any other point.


Idea 2: The explainer model is local

At this point, you might be wondering what dataset we train the explainer model on. Remember that the explainer model is only supposed to explain the point Xk, so it doesn't make sense to train it on all the data.

Original Image


We build a synthetic dataset to train the explainer model on. Consider a text classification task with a bag-of-words representation of sentences. The dataset has a vocabulary of seven words: {Mary, had, a, little, lamb, cake, ice}. The sentence "Mary had a little lamb" would be represented as:

Original Image
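As a quick check of that representation, here is a minimal sketch that builds a binary bag-of-words vector. It assumes presence/absence values (not word counts), case-sensitive whitespace tokenization, and the vocabulary order listed above; `bag_of_words` is a hypothetical helper name.

```python
# Vocabulary in the order listed in the text (assumed ordering).
VOCAB = ["Mary", "had", "a", "little", "lamb", "cake", "ice"]


def bag_of_words(sentence, vocab=VOCAB):
    """Binary bag-of-words: 1 if the vocabulary word occurs in the sentence."""
    tokens = set(sentence.split())
    return [1 if word in tokens else 0 for word in vocab]


print(bag_of_words("Mary had a little lamb"))  # [1, 1, 1, 1, 1, 0, 0]
```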


Let's say this vector [1, 1, 1, 0, 0, 1, 1] is our Xk. To generate the synthetic dataset, we consider all the positions with non-zero values; there are five such positions. We sample from these positions and generate vectors in which the sampled positions retain their value from the original vector and the rest of the positions are set to zero. For example:

Original Image


Notice which positions are non-zero in these two synthetic points and compare them with the original vector.
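The sampling step can be sketched in plain Python. This is a minimal sketch, not the lime package's own sampler: the function name `perturb`, the choice to first draw how many non-zero positions to keep (at least one) and then which ones, and the fixed seed are all assumptions made for illustration.

```python
import random


def perturb(x_k, num_samples, seed=0):
    """Generate synthetic neighbours of a bag-of-words vector x_k.

    Each sample keeps a random subset of x_k's non-zero positions
    (retaining their original values) and sets every other position to zero.
    """
    rng = random.Random(seed)
    nonzero = [i for i, v in enumerate(x_k) if v != 0]
    samples = []
    for _ in range(num_samples):
        # How many of the non-zero positions to keep (at least one).
        k = rng.randint(1, len(nonzero))
        keep = set(rng.sample(nonzero, k))
        samples.append([v if i in keep else 0 for i, v in enumerate(x_k)])
    return samples


x_k = [1, 1, 1, 0, 0, 1, 1]
for z in perturb(x_k, num_samples=3):
    print(z)
```

Positions that are already zero in Xk stay zero in every sample, so the synthetic points only ever "remove words" from the original sentence.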

We have generated the synthetic inputs. What about the labels? We use the main model to generate labels for this data. Think of this as injecting the main model's behavior into the process: the main model's predictions on the synthetic data reflect what it thinks about these points. After all, our goal is to explain the main model's prediction on Xk.
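A minimal sketch of the labeling step. Here `main_model_predict` is a hypothetical stand-in for your real classifier (it returns a made-up probability that the text is about dessert); LIME only ever calls the main model, never looks inside it, which is what makes the approach model agnostic. The synthetic points are written out by hand so the example is self-contained.

```python
import math


def main_model_predict(x):
    """Hypothetical black-box main model over the 7-word bag-of-words.

    It happens to react only to "cake" (index 5) and "ice" (index 6).
    """
    score = 2.0 * x[5] + 2.0 * x[6] - 1.0
    return 1.0 / (1.0 + math.exp(-score))


synthetic = [
    [1, 1, 1, 0, 0, 1, 1],  # the original point Xk
    [1, 0, 1, 0, 0, 0, 1],
    [0, 1, 0, 0, 0, 1, 0],
    [1, 1, 0, 0, 0, 0, 0],
]

# The explainer's training labels are simply the main model's outputs.
labels = [main_model_predict(z) for z in synthetic]
for z, y in zip(synthetic, labels):
    print(z, round(y, 3))
```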

Idea 3: Local is more important

In a typical ML task, we take a batch or a mini-batch, compute the loss for every point, and add the losses up. This treats each input point equally. In our case, we don't really care about the loss on points that are not local to Xk, so LIME weights each point's loss by how close the point is to Xk.
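For reference, the original LIME paper formalizes this weighting as follows, with x playing the role of Xk: f is the main model, g an explainer from a class G of interpretable models, z a synthetic sample and z' its interpretable (e.g. bag-of-words) representation, π_x(z) an exponential kernel on a distance D (the paper suggests cosine distance for text) with width σ, and Ω(g) a penalty that keeps the explainer simple.

```latex
\xi(x) = \operatorname*{arg\,min}_{g \in G} \; \mathcal{L}(f, g, \pi_x) + \Omega(g)

\mathcal{L}(f, g, \pi_x) = \sum_{z,\, z' \in \mathcal{Z}} \pi_x(z) \, \bigl(f(z) - g(z')\bigr)^2,
\qquad
\pi_x(z) = \exp\!\left(-\frac{D(x, z)^2}{\sigma^2}\right)
```

Points close to x get π_x(z) near 1 and dominate the sum; distant points contribute almost nothing.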

Original Image


In the above image, the larger the cross, the closer the point is to Xk, and the greater its weight in the loss used to train the explainer model.
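Putting the three ideas together, here is a minimal from-scratch sketch with NumPy. It is not the lime package's implementation: the sampling scheme (each non-zero position kept with probability 1/2), the cosine-distance kernel, `kernel_width=0.25`, and the plain weighted least-squares fit (no regularization or feature selection, both of which the lime package adds) are simplifying assumptions. `explain` and `main_model_predict` are hypothetical names, the latter a stand-in for your classifier.

```python
import math

import numpy as np


def explain(x_k, predict_fn, num_samples=500, kernel_width=0.25, seed=0):
    """LIME-style sketch: explain predict_fn's output at the binary vector x_k.

    Returns one coefficient per feature of a locally weighted linear model.
    """
    rng = np.random.default_rng(seed)
    x_k = np.asarray(x_k, dtype=float)
    nonzero = np.flatnonzero(x_k)

    # Idea 2: synthetic neighbours that keep a random subset of x_k's
    # non-zero positions and zero out everything else.
    masks = rng.integers(0, 2, size=(num_samples, nonzero.size))
    Z = np.zeros((num_samples, x_k.size))
    Z[:, nonzero] = masks * x_k[nonzero]
    Z[0] = x_k  # include the original point itself

    # Labels come from the black-box main model.
    y = np.array([predict_fn(z) for z in Z], dtype=float)

    # Idea 3: proximity weights exp(-D^2 / sigma^2) with cosine distance D.
    # An all-zero sample gets cosine similarity 0 (distance 1).
    norms = np.linalg.norm(Z, axis=1) * np.linalg.norm(x_k)
    cos_sim = np.divide(Z @ x_k, norms, out=np.zeros(num_samples), where=norms > 0)
    weights = np.exp(-((1.0 - cos_sim) ** 2) / kernel_width**2)

    # Idea 1: fit a simple linear explainer by weighted least squares
    # (scale each row by sqrt(weight), then solve ordinary least squares).
    X = np.hstack([np.ones((num_samples, 1)), Z])  # column 0 is the intercept
    sw = np.sqrt(weights)
    coef, *_ = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)
    return coef[1:]  # one weight per feature, intercept dropped


VOCAB = ["Mary", "had", "a", "little", "lamb", "cake", "ice"]


def main_model_predict(x):
    """Hypothetical black-box main model (reacts only to "cake" and "ice")."""
    return 1.0 / (1.0 + math.exp(-(2.0 * x[5] + 2.0 * x[6] - 1.0)))


coef = explain([1, 1, 1, 0, 0, 1, 1], main_model_predict)
for word, c in sorted(zip(VOCAB, coef), key=lambda pair: -abs(pair[1])):
    print(f"{word:>6}: {c:+.3f}")
```

Scaling each row by the square root of its weight turns weighted least squares into ordinary least squares, which is why `np.linalg.lstsq` suffices. Words absent from Xk ("little", "lamb") never vary in the synthetic data, so they end up with a coefficient of (essentially) zero.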

Sounds good. What about implementation?

I hope the explanation made sense. For a deeper dive into the mathematics, refer to the original paper. The authors of LIME have created an easy-to-use Python package called lime, which can explain tabular, text, and image data. I have created an easy-to-run Colab notebook to experiment with lime on sentiment classification data. The notebook applies LIME to traditional ML algorithms as well as LSTM-based models built with Keras. Here is one highlight of what you can expect from the explanations.

Original Image



The image shows the explanation for a sentence from the sentiment classification task. The model is more or less able to distinguish positive words from negative ones. You can find more experiments in the Colab notebook.

