Source code for ferret.explainers.gradient

import pdb
from typing import Optional, Tuple, Union

import torch
from captum.attr import InputXGradient, IntegratedGradients, Saliency
from cv2 import multiply
import numpy as np

from . import BaseExplainer
from .explanation import Explanation
from .utils import parse_explainer_args


class GradientExplainer(BaseExplainer):
    """Gradient-based token attribution built on Captum.

    When ``multiply_by_inputs`` is true the attribution method is Captum's
    ``InputXGradient``; otherwise it is ``Saliency``.
    """

    NAME = "Gradient"

    def __init__(
        self,
        model,
        tokenizer,
        model_helper: Optional[str] = None,
        multiply_by_inputs: bool = True,
        **kwargs,
    ):
        """Set up the explainer.

        Args:
            model: Forwarded unchanged to ``BaseExplainer.__init__``.
            tokenizer: Forwarded unchanged to ``BaseExplainer.__init__``.
            model_helper: Forwarded unchanged to ``BaseExplainer.__init__``.
            multiply_by_inputs: Select ``InputXGradient`` (True) or
                ``Saliency`` (False) in ``compute_feature_importance``.
            **kwargs: Forwarded unchanged to ``BaseExplainer.__init__``.
        """
        super().__init__(model, tokenizer, model_helper, **kwargs)
        self.multiply_by_inputs = multiply_by_inputs

        if self.multiply_by_inputs:
            # Assigning through ``self`` gives this instance its own NAME;
            # the class-level "Gradient" constant is left as is.
            self.NAME = self.NAME + " (x Input)"

    def compute_feature_importance(
        self,
        text: Union[str, Tuple[str, str]],
        target: Union[int, str] = 1,
        target_token: Optional[Union[int, str]] = None,
        **kwargs,
    ):
        """Compute per-token attribution scores for ``text``.

        Args:
            text: Sample to explain; validated with the helper's
                ``_check_sample``.
            target: Target to explain; resolved with the helper's
                ``_check_target``.
            target_token: Optional target token; resolved with the helper's
                ``_check_target_token``.
            **kwargs: Forwarded to the Captum ``attribute`` call.

        Returns:
            An ``Explanation`` whose ``scores`` hold, for each of the first
            ``input_len`` positions (``input_len`` being the sum of the
            attention mask), the attribution summed over the last axis.
        """
        # Sanity checks
        # TODO these checks have already been conducted if used within the
        # benchmark class. Remove them here if possible.
        target_pos_idx = self.helper._check_target(target)
        target_token_pos_idx = self.helper._check_target_token(text, target_token)
        sample = self.helper._check_sample(text)

        encoded = self._tokenize(sample)
        encoded = {
            name: tensor.to(self.device) for name, tensor in encoded.items()
        }
        input_len = encoded["attention_mask"].sum().item()

        def forward_fn(input_embeds):
            # Captum calls this with embeddings; it returns the logits after
            # the helper's post-processing.
            model_out = self.helper.model(
                inputs_embeds=input_embeds,
                attention_mask=encoded["attention_mask"],
            )
            return self.helper._postprocess_logits(
                model_out.logits, target_token_pos_idx=target_token_pos_idx
            )

        attribution_cls = InputXGradient if self.multiply_by_inputs else Saliency
        attributor = attribution_cls(forward_fn, **self.init_args)

        embeddings = self.get_input_embeds(sample)
        raw_attr = attributor.attribute(embeddings, target=target_pos_idx, **kwargs)

        # Keep the first sample and its first ``input_len`` positions, then
        # pool over hidden size.
        token_scores = raw_attr[0, :input_len, :].detach().cpu().sum(-1).numpy()

        tokens = self.get_tokens(sample)
        explainer_name = self.NAME
        helper_type = self.helper.HELPER_TYPE
        target_label = self.helper.model.config.id2label[target_pos_idx]

        # The target token is only decoded for token-classification helpers.
        if self.helper.HELPER_TYPE == "token-classification":
            decoded_target_token = self.helper.tokenizer.decode(
                encoded["input_ids"][0, target_token_pos_idx].item()
            )
        else:
            decoded_target_token = None

        return Explanation(
            text=sample,
            tokens=tokens,
            scores=token_scores,
            explainer=explainer_name,
            helper_type=helper_type,
            target_pos_idx=target_pos_idx,
            target_token_pos_idx=target_token_pos_idx,
            target=target_label,
            target_token=decoded_target_token,
        )
class IntegratedGradientExplainer(BaseExplainer):
    """Token attribution using Captum's ``IntegratedGradients``.

    The baseline is the embedding of a CLS / PAD ... PAD / SEP id sequence
    with the same length as the tokenized input.
    """

    NAME = "Integrated Gradient"

    def __init__(
        self,
        model,
        tokenizer,
        model_helper: Optional[str] = None,
        multiply_by_inputs: bool = True,
        **kwargs,
    ):
        """Set up the explainer.

        Args:
            model: Forwarded unchanged to ``BaseExplainer.__init__``.
            tokenizer: Forwarded unchanged to ``BaseExplainer.__init__``.
            model_helper: Forwarded unchanged to ``BaseExplainer.__init__``.
            multiply_by_inputs: Passed to ``IntegratedGradients`` under the
                same keyword in ``compute_feature_importance``.
            **kwargs: Forwarded unchanged to ``BaseExplainer.__init__``.
        """
        super().__init__(model, tokenizer, model_helper, **kwargs)
        self.multiply_by_inputs = multiply_by_inputs

        if self.multiply_by_inputs:
            # Assigning through ``self`` gives this instance its own NAME;
            # the class-level constant is left as is.
            self.NAME = self.NAME + " (x Input)"

    def _generate_baselines(self, input_len):
        """Build the baseline embeddings for an input of ``input_len`` tokens.

        The id sequence is the tokenizer's CLS id, ``input_len - 2`` PAD ids
        and the SEP id. It is embedded with the helper's
        ``_get_input_embeds_from_ids`` and a leading axis is added.
        """
        # NOTE(review): assumes the tokenizer defines cls/pad/sep token ids
        # and that input_len >= 2 -- confirm for the supported models.
        tokenizer = self.helper.tokenizer
        baseline_ids = [
            tokenizer.cls_token_id,
            *([tokenizer.pad_token_id] * (input_len - 2)),
            tokenizer.sep_token_id,
        ]
        id_tensor = torch.tensor(baseline_ids, device=self.device)
        baseline_embeds = self.helper._get_input_embeds_from_ids(id_tensor)
        return baseline_embeds.unsqueeze(0)

    def compute_feature_importance(
        self,
        text: Union[str, Tuple[str, str]],
        target: Union[int, str] = 1,
        target_token: Optional[Union[int, str]] = None,
        show_progress: bool = False,
        **kwargs,
    ):
        """Compute per-token Integrated Gradients scores for ``text``.

        Args:
            text: Sample to explain; validated with the helper's
                ``_check_sample``.
            target: Target to explain; resolved with the helper's
                ``_check_target``.
            target_token: Optional target token; resolved with the helper's
                ``_check_target_token``.
            show_progress: Forwarded to the helper's
                ``_forward_with_input_embeds``.
            **kwargs: Forwarded to the Captum ``attribute`` call.

        Returns:
            An ``Explanation`` whose ``scores`` hold, for each of the first
            ``input_len`` positions (``input_len`` being the sum of the
            attention mask), the attribution summed over the last axis.
        """
        # Sanity checks
        # TODO these checks have already been conducted if used within the
        # benchmark class. Remove them here if possible.
        target_pos_idx = self.helper._check_target(target)
        target_token_pos_idx = self.helper._check_target_token(text, target_token)
        sample = self.helper._check_sample(text)

        def forward_fn(input_embeds):
            # An all-ones mask is built over the first two axes of whatever
            # embeddings Captum passes in.
            mask_shape = input_embeds.shape[:2]
            attention_mask = torch.ones(
                *mask_shape, dtype=torch.uint8, device=self.device
            )
            _unused, logits = self.helper._forward_with_input_embeds(
                input_embeds, attention_mask, show_progress=show_progress
            )
            return self.helper._postprocess_logits(
                logits, target_token_pos_idx=target_token_pos_idx
            )

        encoded = self._tokenize(sample)
        input_len = encoded["attention_mask"].sum().item()

        attributor = IntegratedGradients(
            forward_fn, multiply_by_inputs=self.multiply_by_inputs, **self.init_args
        )

        embeddings = self.get_input_embeds(sample)
        baselines = self._generate_baselines(input_len)
        raw_attr = attributor.attribute(
            embeddings, baselines=baselines, target=target_pos_idx, **kwargs
        )

        # Keep the first sample and its first ``input_len`` positions, then
        # pool over hidden size.
        token_scores = raw_attr[0, :input_len, :].detach().cpu().sum(-1).numpy()

        tokens = self.get_tokens(sample)
        explainer_name = self.NAME
        helper_type = self.helper.HELPER_TYPE
        target_label = self.helper.model.config.id2label[target_pos_idx]

        # The target token is only decoded for token-classification helpers.
        if self.helper.HELPER_TYPE == "token-classification":
            decoded_target_token = self.helper.tokenizer.decode(
                encoded["input_ids"][0, target_token_pos_idx].item()
            )
        else:
            decoded_target_token = None

        return Explanation(
            text=sample,
            tokens=tokens,
            scores=token_scores,
            explainer=explainer_name,
            helper_type=helper_type,
            target_pos_idx=target_pos_idx,
            target_token_pos_idx=target_token_pos_idx,
            target=target_label,
            target_token=decoded_target_token,
        )