From 558cb1e9d35c2551e1b4ad3add497124b0a48919 Mon Sep 17 00:00:00 2001
From: James Jin
Date: Fri, 12 Jul 2024 23:26:55 -0400
Subject: [PATCH] feat(LM): attributions for SequenceClassification

---
 src/ecco/attribution.py |   3 +-
 src/ecco/lm.py          | 114 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 116 insertions(+), 1 deletion(-)

diff --git a/src/ecco/attribution.py b/src/ecco/attribution.py
index 198acb6..bf1b8aa 100644
--- a/src/ecco/attribution.py
+++ b/src/ecco/attribution.py
@@ -64,7 +64,8 @@ def model_forward(input_: torch.Tensor, decoder_: torch.Tensor, model, extra_for
         output = model(inputs_embeds=input_, decoder_inputs_embeds=decoder_, **extra_forward_args)
     else:
         output = model(inputs_embeds=input_, **extra_forward_args)
-    return F.softmax(output.logits[:, -1, :], dim=-1)
+    # SequenceClassification models output a single set of logits per sequence, so output.logits is 2-dimensional
+    return F.softmax(output.logits[:, -1, :] if len(output.logits.shape) == 3 else output.logits, dim=-1)
 
 def normalize_attributes(attributes: torch.Tensor) -> torch.Tensor:
     # attributes has shape (batch, sequence size, embedding dim)
diff --git a/src/ecco/lm.py b/src/ecco/lm.py
index 97d3330..df55b58 100644
--- a/src/ecco/lm.py
+++ b/src/ecco/lm.py
@@ -504,6 +504,120 @@ def __call__(self, input_tokens: torch.Tensor):
                             'device': self.device,
                             'config': self.model_config})
 
+    def classify(self, input_str: str, attribution: Optional[List[str]] = []):
+        """
+        For the input text as a whole, select a label from the set of labels defined in the model's config.
+        Assumes the model being used is of the SequenceClassification variant.
+        Args:
+            input_str: Input prompt. # TODO: accept batch of input strings
+            attribution: List of attribution methods to calculate. By default, none are calculated.
+ """ + input_tokenized_info = self.tokenizer(input_str, return_tensors="pt") + input_tokenized_info = self.to(input_tokenized_info) + input_ids, attention_mask = input_tokenized_info['input_ids'], input_tokenized_info['attention_mask'] + n_input_tokens = len(input_ids[0]) + + if self.verbose: + viz_id = self.display_input_sequence(input_ids[0]) + + # Perform the classification process itself + encoder_input_embeds, _ = self._get_embeddings(input_ids) + self.attributions = defaultdict(list) # reset attributions dict + forward_kwargs = { + 'inputs_embeds': encoder_input_embeds, + 'return_dict': True, + 'attention_mask': attention_mask + } + self._attach_hooks(self.model) + model_outputs = self.model(**forward_kwargs) + prediction_id = torch.argmax(model_outputs.logits[0]) + + # Get primary attributions for deduced label + self._analyze_token( + encoder_input_embeds=encoder_input_embeds, + encoder_attention_mask=attention_mask, + decoder_input_embeds=None, + attribution_flags=attribution, + prediction_id=prediction_id + ) + + # Store the predicted class id together with input ids + input_ids = torch.cat([input_ids, torch.tensor([[prediction_id]], device=input_ids.device)], dim=-1) + + if self.verbose: + # Output the predicted label as tokens + label_ids = self.tokenizer.encode(self.model.config.to_dict()["id2label"][prediction_id.item()], + add_special_tokens=False, return_tensors="np").squeeze(0) + for offset, id in enumerate(label_ids): # may need to use multiple tokens to represent label name + self.display_token( + viz_id, + id.item(), + n_input_tokens + offset + ) + + # Get encoder/decoder hidden states + embedding_states = None + for attributes in ["hidden_states", "encoder_hidden_states", "decoder_hidden_states"]: + out_attr = getattr(model_outputs, attributes, None) + if out_attr is not None: + + hs_list = [] + for idx, layer_hs in enumerate(out_attr): + # in Hugging Face Transformers v4, there's an extra index for batch + if len(layer_hs.shape) == 3: # If there's a batch dimension, pick the first one + hs = layer_hs.cpu().detach()[0].unsqueeze(0) # Adding a dimension to concat to later + # Earlier versions are only 2 dimensional + # But also, in v4, for GPT2, all except the last one would have 3 dims, the last layer + # would only have two dims + else: + hs = layer_hs.cpu().detach().unsqueeze(0) + + # First hidden state is the embedding layer, skip it + if idx == 0: + embedding_states = hs + else: + hs_list.append(hs) + + hidden_states = torch.cat(hs_list, dim=0) + setattr(model_outputs, attributes, hidden_states) + + # Pass 'hidden_states' to 'decoder_hidden_states' + if getattr(model_outputs, "hidden_states", None) is not None: + assert getattr(model_outputs, "encoder_hidden_states", None) is None \ + and getattr(model_outputs, "decoder_hidden_states", None) is None, \ + "Not expected to have encoder_hidden_states/decoder_hidden_states with 'hidden_states'" + setattr(model_outputs, "decoder_hidden_states", model_outputs.hidden_states) + + encoder_hidden_states = getattr(model_outputs, "encoder_hidden_states", None) + decoder_hidden_states = getattr(model_outputs, "hidden_states", getattr(model_outputs, "decoder_hidden_states", None)) + + # Turn activations from dict to a proper array + activations_dict = self._all_activations_dict + for layer_type, activations in activations_dict.items(): + self.activations[layer_type] = activations_dict_to_array(activations) + + all_token_ids = input_ids[0] # FIXME: make OutputSeq.primary_attributions turn the prediction id into a label and not 
token + tokens = self.tokenizer.convert_ids_to_tokens(all_token_ids) + attributions = self.attributions + attn = getattr(model_outputs, "attentions", None) + + return OutputSeq(**{'tokenizer': self.tokenizer, + 'token_ids': all_token_ids.unsqueeze(0), # Add a batch dimension + 'n_input_tokens': n_input_tokens, + 'output_text': self.tokenizer.decode(all_token_ids), + 'tokens': [tokens], # Add a batch dimension + 'encoder_hidden_states': encoder_hidden_states, + 'decoder_hidden_states': decoder_hidden_states, + 'embedding_states': embedding_states, + 'attention': attn, + 'attribution': attributions, + 'activations': self.activations, + 'collect_activations_layer_nums': self.collect_activations_layer_nums, + 'model_type': self.model_type, + 'device': self.device, + 'config': self.model_config}) + + def _get_embeddings(self, input_ids) -> Tuple[torch.FloatTensor, torch.FloatTensor]: """ Get token embeddings and one-hot vector into vocab. It's done via matrix multiplication
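
Usage sketch (reviewer note, not part of the patch): a minimal example of the new classify() entry point. It assumes ecco.from_pretrained() can load the SequenceClassification checkpoint named below; the checkpoint name and the 'ig' attribution flag are illustrative, following the existing generate() API rather than anything this patch guarantees.

    import ecco

    # Hypothetical checkpoint; any SequenceClassification model whose config
    # defines an id2label mapping should behave the same way, assuming
    # ecco's model loading supports it.
    lm = ecco.from_pretrained('distilbert-base-uncased-finetuned-sst-2-english')

    # Classify the whole input and collect integrated-gradients attributions
    # for the predicted label ('ig' mirrors the generate() attribution flag).
    output = lm.classify("The movie was surprisingly good.", attribution=['ig'])

    # Renders tokens with their attributions; per the FIXME above, the final
    # position currently shows the prediction id as a token, not a label.
    output.primary_attributions()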
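The attribution.py change keys off the rank of output.logits to support both model families. A standalone, shape-only sketch of that guard (random tensors stand in for real model outputs):

    import torch
    import torch.nn.functional as F

    lm_logits = torch.randn(1, 7, 50257)   # CausalLM: (batch, sequence, vocab)
    cls_logits = torch.randn(1, 2)         # SequenceClassification: (batch, num_labels)

    for logits in (lm_logits, cls_logits):
        # 3-dimensional logits -> softmax over the last position only;
        # 2-dimensional logits -> softmax over the label scores as-is.
        probs = F.softmax(logits[:, -1, :] if len(logits.shape) == 3 else logits, dim=-1)
        print(probs.shape)  # torch.Size([1, 50257]), then torch.Size([1, 2])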
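The verbose branch decodes the winning class through config.id2label before rendering, because a label name may tokenize into several tokens. A sketch of that step using plain transformers (the checkpoint is illustrative):

    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    name = 'distilbert-base-uncased-finetuned-sst-2-english'  # illustrative checkpoint
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForSequenceClassification.from_pretrained(name)

    inputs = tokenizer("The movie was surprisingly good.", return_tensors="pt")
    prediction_id = torch.argmax(model(**inputs).logits[0])

    # id2label maps class index -> label string; the string may need multiple
    # tokens, hence the per-token display loop in classify().
    label = model.config.id2label[prediction_id.item()]
    label_ids = tokenizer.encode(label, add_special_tokens=False)
    print(label, label_ids)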
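Finally, the hidden-state loop normalizes the per-layer tuple into a single (layers, sequence, hidden) tensor, splitting off the embedding layer. A shape-only sketch of that logic under the patch's batch-size-1, Transformers-v4 assumption (the layer count and sizes are made up):

    import torch

    # Simulated output_hidden_states=True tuple: embedding layer first, then
    # one (batch, sequence, hidden) tensor per transformer layer.
    layer_outputs = tuple(torch.randn(1, 7, 768) for _ in range(5))

    embedding_states = None
    hs_list = []
    for idx, layer_hs in enumerate(layer_outputs):
        # Drop the batch dim, then re-add a leading dim to concat along.
        hs = layer_hs.detach()[0].unsqueeze(0)
        if idx == 0:
            embedding_states = hs  # first entry is the embedding layer
        else:
            hs_list.append(hs)

    hidden_states = torch.cat(hs_list, dim=0)
    print(hidden_states.shape)  # torch.Size([4, 7, 768]): (layers, sequence, hidden)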