Skip to content

Commit

Permalink
Merge pull request #59 from JoaoLages/bugfix/issue_#57
Browse files Browse the repository at this point in the history
Bugfix/issue #57
  • Loading branch information
jalammar authored Jan 3, 2022
2 parents c363e14 + 555d063 commit 9aa8639
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions src/ecco/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import random
import torch
import transformers
from transformers import BatchEncoding

import ecco
import numpy as np
from IPython import display as d
from torch.nn import functional as F
from ecco.attribution import compute_primary_attributions_scores
from ecco.output import OutputSeq
from typing import Optional, Any, List, Tuple, Dict
from typing import Optional, Any, List, Tuple, Dict, Union
from operator import attrgetter
import re
from ecco.util import is_partial_token, strip_tokenizer_prefix
Expand Down Expand Up @@ -107,7 +109,7 @@ def _reset(self):
self.neurons_to_induce = {}
self._hooks = {}

def to(self, tensor: torch.Tensor):
def to(self, tensor: Union[torch.Tensor, BatchEncoding]):
if self.device == 'cuda':
return tensor.to('cuda')
return tensor
Expand Down Expand Up @@ -163,7 +165,7 @@ def generate(self, input_str: str,
do_sample: Decoding parameter. If set to False, the model always always
chooses the highest scoring candidate output
token. This may lead to repetitive text. If set to True, the model considers
consults top_k and/or top_p to generate more itneresting output.
consults top_k and/or top_p to generate more interesting output.
attribution: List of attribution methods to be calculated. By default, it does not calculate anything.
beam_size: Beam size to consider while generating
generate_kwargs: Other arguments to be passed directly to self.model.generate
Expand All @@ -181,6 +183,7 @@ def generate(self, input_str: str,

# We need this as a batch in order to collect activations.
input_tokenized_info = self.tokenizer(input_str, return_tensors="pt")
input_tokenized_info = self.to(input_tokenized_info)
input_ids, attention_mask = input_tokenized_info['input_ids'], input_tokenized_info['attention_mask']
n_input_tokens = len(input_ids[0])
cur_len = n_input_tokens
Expand Down Expand Up @@ -277,9 +280,15 @@ def generate(self, input_str: str,
# Recomputing inputs ids, attention mask and decoder input ids
if decoder_input_ids is not None:
assert len(decoder_input_ids.size()) == 2 # will break otherwise
decoder_input_ids = torch.cat([decoder_input_ids, torch.tensor([[prediction_id]])], dim=-1)
decoder_input_ids = torch.cat(
[decoder_input_ids, torch.tensor([[prediction_id]], device=decoder_input_ids.device)],
dim=-1
)
else:
input_ids = torch.cat([input_ids, torch.tensor([[prediction_id]])], dim=-1)
input_ids = torch.cat(
[input_ids, torch.tensor([[prediction_id]], device=input_ids.device)],
dim=-1
)

# Recomputing Attention Mask
if getattr(self.model, '_prepare_attention_mask_for_generation'):
Expand Down Expand Up @@ -719,12 +728,13 @@ def sample_output_token(scores, do_sample, temperature, top_k, top_p):
return prediction_id


def _one_hot(token_ids, vocab_size):
return torch.zeros(len(token_ids), vocab_size).scatter_(1, token_ids.unsqueeze(1), 1.)
def _one_hot(token_ids: torch.Tensor, vocab_size: int) -> torch.Tensor:
return torch.zeros(len(token_ids), vocab_size, device=token_ids.device).scatter_(1, token_ids.unsqueeze(1), 1.)


def _one_hot_batched(token_ids, vocab_size):
def _one_hot_batched(token_ids: torch.Tensor, vocab_size: int) -> torch.Tensor:
batch_size, num_tokens = token_ids.shape
return torch.zeros(batch_size, num_tokens, vocab_size).scatter_(-1, token_ids.unsqueeze(-1), 1.)
return torch.zeros(batch_size, num_tokens, vocab_size, device=token_ids.device).scatter_(-1, token_ids.unsqueeze(-1), 1.)


def activations_dict_to_array(activations_dict):
Expand Down

0 comments on commit 9aa8639

Please sign in to comment.