From 4aeca8dac6c30b15e769c0ae3df38bbaef3e7dd7 Mon Sep 17 00:00:00 2001 From: Oyvind Tafjord Date: Fri, 19 Apr 2024 09:16:38 -0700 Subject: [PATCH 1/3] Add torch.no_grad wrappers around model calls --- catwalk/models/language_model.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/catwalk/models/language_model.py b/catwalk/models/language_model.py index e81976f..89af4f0 100644 --- a/catwalk/models/language_model.py +++ b/catwalk/models/language_model.py @@ -567,7 +567,8 @@ def _run_loglikelihood_tokens( for field_name, tensors in unpadded_batch.items() } - batch_logits = log_softmax(model(**padded_batch)[0], dim=-1) + with torch.no_grad(): + batch_logits = log_softmax(model(**padded_batch)[0], dim=-1) z = zip( batch_of_indices, batch_logits, @@ -652,13 +653,14 @@ def _run_greedy_until( [tokenized_context[max_gen_toks - model_max_length :]] ).to(model.device) - full_text_tensor = model.generate( - context_tensor, - max_length=context_tensor.shape[1] + max_gen_toks, - eos_token_id=primary_until, - do_sample=False, - pad_token_id=primary_until, # temporary hack to suppress irrelevant warning until batch processing is added - ) + with torch.no_grad(): + full_text_tensor = model.generate( + context_tensor, + max_length=context_tensor.shape[1] + max_gen_toks, + eos_token_id=primary_until, + do_sample=False, + pad_token_id=primary_until, # temporary hack to suppress irrelevant warning until batch processing is added + ) continuation_tensor = full_text_tensor[0, context_tensor.shape[1] :] continuation = tokenizer.decode(continuation_tensor.tolist()) raw_continuation = continuation From 1f8c34fa5e3e9609442c8dae38fa795a2650aedc Mon Sep 17 00:00:00 2001 From: Oyvind Tafjord Date: Fri, 19 Apr 2024 09:44:06 -0700 Subject: [PATCH 2/3] Robustify greedy_until stop condition --- catwalk/models/language_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/catwalk/models/language_model.py b/catwalk/models/language_model.py index 89af4f0..63ca1b1 100644 --- a/catwalk/models/language_model.py +++ b/catwalk/models/language_model.py @@ -643,8 +643,8 @@ def _run_greedy_until( if isinstance(untils, str): untils = [untils] # if any of the stop phrases are single tokens we can use that for early termination - primary_until = None - for tokenized_until in tokenizer(untils)["input_ids"]: + primary_until = tokenizer.eos_token_id + for tokenized_until in tokenizer(untils, add_special_tokens=False)["input_ids"]: if len(tokenized_until) == 1: primary_until = tokenized_until[0] From 6f0389ce1b7e8c991c4b2a5664267dab643e4ca9 Mon Sep 17 00:00:00 2001 From: Oyvind Tafjord Date: Fri, 19 Apr 2024 09:48:54 -0700 Subject: [PATCH 3/3] Update CHANGELOG.md --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c61855..446c329 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Fixed + +- Added `torch.no_grad()` around model calls in `language_model.py` +- Prevent crashes with more robust stop token for `greedy_until` in `language_model.py` + ## [v1.0.0rc0](https://github.com/allenai/catwalk/releases/tag/v1.0.0rc0) - 2023-12-19 ### Added