Source code for ferret.explainers.shap
import pdb
from typing import Dict, Optional, Text, Union
import logging
import numpy as np
import shap
from shap.maskers import Text as TextMasker
from . import BaseExplainer
from .explanation import Explanation
from .utils import parse_explainer_args
[docs]
class SHAPExplainer(BaseExplainer):
    """Token-attribution explainer backed by SHAP (Partition SHAP by default).

    The model is wrapped in a function that returns class probabilities and is
    explained with ``shap.Explainer`` over a ``shap.maskers.Text`` masker built
    from this explainer's tokenizer.
    """

    NAME = "Partition SHAP"

    def __init__(
        self,
        model,
        tokenizer,
        model_helper: Optional[str] = None,
        silent: bool = True,
        algorithm: str = "partition",
        seed: int = 42,
        **kwargs,
    ):
        """Create the explainer and record SHAP-specific options.

        Args:
            model: model to explain, forwarded to ``BaseExplainer``.
            tokenizer: tokenizer forwarded to ``BaseExplainer``; it is also
                used to build the SHAP text masker.
            model_helper: optional helper name, forwarded to ``BaseExplainer``.
            silent: stored in ``init_args`` and forwarded to ``shap.Explainer``
                (presumably controls SHAP's progress output -- see SHAP docs).
            algorithm: SHAP algorithm name forwarded to ``shap.Explainer``.
            seed: random seed forwarded to ``shap.Explainer``.
            **kwargs: forwarded to ``BaseExplainer.__init__``.
        """
        super().__init__(model, tokenizer, model_helper, **kwargs)
        # NOTE(review): assumes BaseExplainer.__init__ creates ``self.init_args``.
        # Every key stored here is later passed verbatim to ``shap.Explainer``.
        self.init_args["silent"] = silent
        self.init_args["algorithm"] = algorithm
        self.init_args["seed"] = seed

    @staticmethod
    def _scores_without_special_tokens(attr, special_mask, n_tokens):
        """Lay SHAP values onto ``n_tokens`` slots, forcing special tokens to 0.

        Args:
            attr: 1-D sequence of SHAP values, one per masker token.
            special_mask: 1-D boolean sequence, True for special tokens.
            n_tokens: number of tokens produced by the tokenizer.

        Returns:
            Float array of length ``n_tokens``. Positions beyond the shorter
            of ``attr`` / ``special_mask`` stay 0.
        """
        attr = np.asarray(attr, dtype=float)
        special_mask = np.asarray(special_mask, dtype=bool)
        if len(attr) != len(special_mask):
            # The SHAP masker and the tokenizer disagreed on the number of
            # tokens (e.g. truncation); surface it instead of silently
            # dropping or zero-filling attributions.
            logging.getLogger(__name__).warning(
                "SHAP returned %d attributions but the tokenizer produced %d "
                "tokens; only the first %d positions are used.",
                len(attr),
                len(special_mask),
                min(len(attr), len(special_mask)),
            )
        token_scores = np.zeros(n_tokens, dtype=float)
        n = min(len(attr), len(special_mask), n_tokens)
        token_scores[:n] = np.where(special_mask[:n], 0.0, attr[:n])
        return token_scores

    def compute_feature_importance(
        self,
        text,
        target: Union[int, Text] = 1,
        target_token: Optional[Union[int, Text]] = None,
        **kwargs,
    ):
        """Compute per-token SHAP attributions of ``target`` for ``text``.

        Args:
            text: input sample; validated against ``target_token`` first and
                then normalized by ``self.helper._check_sample``.
            target: target class (index or label), resolved by
                ``self.helper._check_target``.
            target_token: token whose prediction is explained (used for
                token-classification helpers), resolved by
                ``self.helper._check_target_token``.
            **kwargs: forwarded to the SHAP explainer call. A
                ``target_option`` entry is dropped with a warning.

        Returns:
            Explanation whose ``scores`` hold one value per tokenizer token for
            the target class; special tokens are scored 0.
        """
        logger = logging.getLogger(__name__)

        # sanity checks (target_token is checked against the raw input text)
        target_pos_idx = self.helper._check_target(target)
        target_token_pos_idx = self.helper._check_target_token(text, target_token)
        text = self.helper._check_sample(text)

        # 'target_option' has no meaning for SHAP; drop it so shap does not
        # receive an unexpected keyword.
        if "target_option" in kwargs:
            logger.warning(
                "The 'target_option' argument is not used in SHAPExplainer and will be removed."
            )
            del kwargs["target_option"]

        def func(texts: np.ndarray):
            """Map a batch of (masked) strings to class probabilities."""
            _, logits = self.helper._forward(texts.tolist())
            # Let the helper reduce logits to the target token position, if any.
            logits = self.helper._postprocess_logits(
                logits, target_token_pos_idx=target_token_pos_idx
            )
            return logits.softmax(-1).cpu().numpy()

        masker = TextMasker(self.tokenizer)
        explainer_partition = shap.Explainer(
            model=func, masker=masker, **self.init_args
        )
        shap_values = explainer_partition(text, **kwargs)
        # Row 0 is the first (presumably only) explained sample; the column
        # selects the target class.
        attr = shap_values.values[0][:, target_pos_idx]

        # Re-tokenize to align SHAP values with tokenizer tokens and to find
        # which positions are special tokens.
        item = self._tokenize(text, return_special_tokens_mask=True)
        token_ids = item["input_ids"][0].tolist()
        special_mask = item["special_tokens_mask"][0].tolist()
        token_scores = self._scores_without_special_tokens(
            attr, special_mask, len(token_ids)
        )

        target_token_text = None
        if self.helper.HELPER_TYPE == "token-classification":
            target_token_text = self.helper.tokenizer.decode(
                item["input_ids"][0, target_token_pos_idx].item()
            )

        return Explanation(
            text=text,
            tokens=self.get_tokens(text),
            scores=token_scores,
            explainer=self.NAME,
            helper_type=self.helper.HELPER_TYPE,
            target_pos_idx=target_pos_idx,
            target_token_pos_idx=target_token_pos_idx,
            target=self.helper.model.config.id2label[target_pos_idx],
            target_token=target_token_text,
        )