feat(LM): attributions for SequenceClassification #116

Open · wants to merge 1 commit into base: main
3 changes: 2 additions & 1 deletion src/ecco/attribution.py
@@ -64,7 +64,8 @@ def model_forward(input_: torch.Tensor, decoder_: torch.Tensor, model, extra_for
output = model(inputs_embeds=input_, decoder_inputs_embeds=decoder_, **extra_forward_args)
else:
output = model(inputs_embeds=input_, **extra_forward_args)
-return F.softmax(output.logits[:, -1, :], dim=-1)
+# SequenceClassification models only output one set of logits per input, resulting in 2-dimensional outputs
+return F.softmax(output.logits[:, -1, :] if len(output.logits.shape) == 3 else output.logits, dim=-1)

def normalize_attributes(attributes: torch.Tensor) -> torch.Tensor:
# attributes has shape (batch, sequence size, embedding dim)
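
The new shape check lets model_forward serve both output head types: generation logits come out as (batch, sequence, vocab) and need the final position selected, while SequenceClassification logits are already (batch, num_labels). Below is a minimal sketch of that branching with dummy tensors standing in for real model outputs; the helper name and the shapes are illustrative, not part of the PR.

import torch
import torch.nn.functional as F

def softmax_over_prediction_logits(logits: torch.Tensor) -> torch.Tensor:
    # 3-D logits (batch, sequence, vocab): softmax over the last position, as for generation models.
    # 2-D logits (batch, num_labels): softmax over the label scores directly, as for SequenceClassification models.
    return F.softmax(logits[:, -1, :] if len(logits.shape) == 3 else logits, dim=-1)

gen_logits = torch.randn(1, 7, 50257)  # dummy language-modeling head output
cls_logits = torch.randn(1, 2)         # dummy two-label classification head output
print(softmax_over_prediction_logits(gen_logits).shape)  # torch.Size([1, 50257])
print(softmax_over_prediction_logits(cls_logits).shape)  # torch.Size([1, 2])
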
114 changes: 114 additions & 0 deletions src/ecco/lm.py
@@ -504,6 +504,120 @@ def __call__(self, input_tokens: torch.Tensor):
'device': self.device,
'config': self.model_config})

def classify(self, input_str: str, attribution: Optional[List[str]] = []):
"""
For the input text as a whole, select a label from the set of labels defined within the model.
Assumes the model being used is of the SequenceClassification variant.
Args:
input_str: Input prompt. # TODO: accept batch of input strings
attribution: List of attribution methods to calculate. By default, no attributions are calculated.
"""
input_tokenized_info = self.tokenizer(input_str, return_tensors="pt")
input_tokenized_info = self.to(input_tokenized_info)
input_ids, attention_mask = input_tokenized_info['input_ids'], input_tokenized_info['attention_mask']
n_input_tokens = len(input_ids[0])

if self.verbose:
viz_id = self.display_input_sequence(input_ids[0])

# Perform the classification process itself
encoder_input_embeds, _ = self._get_embeddings(input_ids)
self.attributions = defaultdict(list) # reset attributions dict
forward_kwargs = {
'inputs_embeds': encoder_input_embeds,
'return_dict': True,
'attention_mask': attention_mask
}
self._attach_hooks(self.model)
model_outputs = self.model(**forward_kwargs)
prediction_id = torch.argmax(model_outputs.logits[0])

# Get primary attributions for deduced label
self._analyze_token(
encoder_input_embeds=encoder_input_embeds,
encoder_attention_mask=attention_mask,
decoder_input_embeds=None,
attribution_flags=attribution,
prediction_id=prediction_id
)

# Store the predicted class id together with input ids
input_ids = torch.cat([input_ids, torch.tensor([[prediction_id]], device=input_ids.device)], dim=-1)

if self.verbose:
# Output the predicted label as tokens
label_ids = self.tokenizer.encode(self.model.config.to_dict()["id2label"][prediction_id.item()],
add_special_tokens=False, return_tensors="np").squeeze(0)
for offset, id in enumerate(label_ids): # may need to use multiple tokens to represent label name
self.display_token(
viz_id,
id.item(),
n_input_tokens + offset
)

# Get encoder/decoder hidden states
embedding_states = None
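# Each attribute, when present, is a tuple with one hidden-state tensor per layer, shaped
# (batch, sequence, hidden_dim) in Transformers v4 (or (sequence, hidden_dim) in earlier releases).
# The loop below detaches each layer, sets the embedding-layer output aside, and stacks the
# remaining layers into a single (layers, sequence, hidden_dim) tensor.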
for attributes in ["hidden_states", "encoder_hidden_states", "decoder_hidden_states"]:
out_attr = getattr(model_outputs, attributes, None)
if out_attr is not None:

hs_list = []
for idx, layer_hs in enumerate(out_attr):
# in Hugging Face Transformers v4, there's an extra index for batch
if len(layer_hs.shape) == 3: # If there's a batch dimension, pick the first one
hs = layer_hs.cpu().detach()[0].unsqueeze(0) # Adding a dimension to concat to later
# Earlier versions are only 2-dimensional.
# Also, in v4, for GPT2, all layers except the last have 3 dims;
# the last layer has only two.
else:
hs = layer_hs.cpu().detach().unsqueeze(0)

# The first hidden state is the embedding layer's output; keep it separate from the stacked layers
if idx == 0:
embedding_states = hs
else:
hs_list.append(hs)

hidden_states = torch.cat(hs_list, dim=0)
setattr(model_outputs, attributes, hidden_states)

# Pass 'hidden_states' to 'decoder_hidden_states'
if getattr(model_outputs, "hidden_states", None) is not None:
assert getattr(model_outputs, "encoder_hidden_states", None) is None \
and getattr(model_outputs, "decoder_hidden_states", None) is None, \
"Not expected to have encoder_hidden_states/decoder_hidden_states with 'hidden_states'"
setattr(model_outputs, "decoder_hidden_states", model_outputs.hidden_states)

encoder_hidden_states = getattr(model_outputs, "encoder_hidden_states", None)
decoder_hidden_states = getattr(model_outputs, "hidden_states", getattr(model_outputs, "decoder_hidden_states", None))

# Turn activations from dict to a proper array
activations_dict = self._all_activations_dict
for layer_type, activations in activations_dict.items():
self.activations[layer_type] = activations_dict_to_array(activations)

all_token_ids = input_ids[0] # FIXME: make OutputSeq.primary_attributions turn the prediction id into a label and not token
tokens = self.tokenizer.convert_ids_to_tokens(all_token_ids)
attributions = self.attributions
attn = getattr(model_outputs, "attentions", None)

return OutputSeq(**{'tokenizer': self.tokenizer,
'token_ids': all_token_ids.unsqueeze(0), # Add a batch dimension
'n_input_tokens': n_input_tokens,
'output_text': self.tokenizer.decode(all_token_ids),
'tokens': [tokens], # Add a batch dimension
'encoder_hidden_states': encoder_hidden_states,
'decoder_hidden_states': decoder_hidden_states,
'embedding_states': embedding_states,
'attention': attn,
'attribution': attributions,
'activations': self.activations,
'collect_activations_layer_nums': self.collect_activations_layer_nums,
'model_type': self.model_type,
'device': self.device,
'config': self.model_config})


def _get_embeddings(self, input_ids) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""
Get token embeddings and one-hot vector into vocab. It's done via matrix multiplication
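
Taken as a whole, the diff gives ecco's LM an end-to-end classification path: tokenize, run one forward pass, argmax over the classification logits, then compute primary attributions for the predicted label. Below is a usage sketch under stated assumptions: the checkpoint name, the "ig" attribution key, the premise that ecco.from_pretrained can wrap a SequenceClassification model, and the output.attribution attribute name are all illustrative rather than confirmed by this PR.

import ecco

# Illustrative checkpoint: any SequenceClassification model whose config carries an id2label mapping.
model_name = "distilbert-base-uncased-finetuned-sst-2-english"

# Assumption: ecco.from_pretrained can load and wrap this model type once the PR is merged.
lm = ecco.from_pretrained(model_name)

# classify() runs a single forward pass, argmaxes the logits into a label id,
# and computes primary attributions for that label for each requested method.
output = lm.classify("The movie was surprisingly good.", attribution=["ig"])

# The OutputSeq is built with an 'attribution' kwarg; the dict is keyed by method name.
print(output.attribution["ig"])
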