diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9c61855..446c329 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+### Fixed
+
+- Added `torch.no_grad()` around model calls in `language_model.py`
+- Prevent crashes with more robust stop token for `greedy_until` in `language_model.py`
+
 ## [v1.0.0rc0](https://github.com/allenai/catwalk/releases/tag/v1.0.0rc0) - 2023-12-19
 
 ### Added
diff --git a/catwalk/models/language_model.py b/catwalk/models/language_model.py
index e81976f..63ca1b1 100644
--- a/catwalk/models/language_model.py
+++ b/catwalk/models/language_model.py
@@ -567,7 +567,8 @@ def _run_loglikelihood_tokens(
                 for field_name, tensors in unpadded_batch.items()
             }
 
-            batch_logits = log_softmax(model(**padded_batch)[0], dim=-1)
+            with torch.no_grad():
+                batch_logits = log_softmax(model(**padded_batch)[0], dim=-1)
             z = zip(
                 batch_of_indices,
                 batch_logits,
@@ -642,8 +643,8 @@ def _run_greedy_until(
             if isinstance(untils, str):
                 untils = [untils]
             # if any of the stop phrases are single tokens we can use that for early termination
-            primary_until = None
-            for tokenized_until in tokenizer(untils)["input_ids"]:
+            primary_until = tokenizer.eos_token_id
+            for tokenized_until in tokenizer(untils, add_special_tokens=False)["input_ids"]:
                 if len(tokenized_until) == 1:
                     primary_until = tokenized_until[0]
 
@@ -652,13 +653,14 @@
                 [tokenized_context[max_gen_toks - model_max_length :]]
             ).to(model.device)
 
-            full_text_tensor = model.generate(
-                context_tensor,
-                max_length=context_tensor.shape[1] + max_gen_toks,
-                eos_token_id=primary_until,
-                do_sample=False,
-                pad_token_id=primary_until,  # temporary hack to suppress irrelevant warning until batch processing is added
-            )
+            with torch.no_grad():
+                full_text_tensor = model.generate(
+                    context_tensor,
+                    max_length=context_tensor.shape[1] + max_gen_toks,
+                    eos_token_id=primary_until,
+                    do_sample=False,
+                    pad_token_id=primary_until,  # temporary hack to suppress irrelevant warning until batch processing is added
+                )
             continuation_tensor = full_text_tensor[0, context_tensor.shape[1] :]
             continuation = tokenizer.decode(continuation_tensor.tolist())
             raw_continuation = continuation
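
Note for reviewers: both language_model.py changes follow the same pattern. Inference-only calls (the forward pass that feeds log_softmax, and model.generate) are wrapped in torch.no_grad() so no autograd graph is retained, and the stop token for greedy decoding now falls back to tokenizer.eos_token_id whenever no stop phrase tokenizes to a single token (tokenizing with add_special_tokens=False so BOS/EOS markers do not inflate a phrase's token count). The standalone sketch below illustrates the same pattern; the "gpt2" checkpoint, the pick_primary_until helper, the stop phrases, and the prompt are illustrative assumptions, not part of catwalk.

# Minimal sketch of the patterns applied in the diff above; "gpt2",
# pick_primary_until, the stop phrases, and the prompt are assumptions.
import torch
from torch.nn.functional import log_softmax
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()


def pick_primary_until(untils):
    # Fall back to EOS so generate() always receives a valid stop/pad token id.
    primary_until = tokenizer.eos_token_id
    # add_special_tokens=False keeps BOS/EOS out of the stop phrases, which
    # would otherwise make a one-token phrase look longer than one token.
    for tokenized_until in tokenizer(untils, add_special_tokens=False)["input_ids"]:
        if len(tokenized_until) == 1:
            primary_until = tokenized_until[0]
    return primary_until


context = tokenizer("The capital of France is", return_tensors="pt")["input_ids"]
primary_until = pick_primary_until(["\n", "###"])

with torch.no_grad():  # inference only: skip gradient tracking to save memory
    logits = log_softmax(model(context)[0], dim=-1)
    generated = model.generate(
        context,
        max_length=context.shape[1] + 20,
        eos_token_id=primary_until,
        do_sample=False,
        pad_token_id=primary_until,
    )

print(tokenizer.decode(generated[0, context.shape[1]:].tolist()))

On recent PyTorch releases torch.inference_mode() would likely work here as well; torch.no_grad() is what the diff uses and stays compatible with older versions.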