From 0478b4842a2614bceed96534d39ea71b62efbff4 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Wed, 7 Aug 2024 16:47:29 +0100 Subject: [PATCH] fix: fix cache behavior for gemma2 and warn if caching breaks --- linear_relational/training/train_lre.py | 33 +++++++++++++--- tests/conftest.py | 18 ++++++++- tests/training/test_train_lre.py | 51 ++++++++++++++++++++++++- 3 files changed, 95 insertions(+), 7 deletions(-) diff --git a/linear_relational/training/train_lre.py b/linear_relational/training/train_lre.py index cb32887..6268fa6 100644 --- a/linear_relational/training/train_lre.py +++ b/linear_relational/training/train_lre.py @@ -3,12 +3,14 @@ import torch from tokenizers import Tokenizer from torch import nn +from transformers.cache_utils import HybridCache from linear_relational.lib.layer_matching import ( LayerMatcher, fix_neg_layer_num, get_layer_name, ) +from linear_relational.lib.logger import logger from linear_relational.lib.token_utils import ( find_final_word_token_index, find_prompt_answer_data, @@ -102,16 +104,36 @@ def order_1_approx( inputs = tokenizer(prompt_text, return_tensors="pt").to(device) # Precompute everything up to the subject, if there is anything before it. - past_key_values = None input_ids = inputs.input_ids + # lots of cache pos hackery to get caching to work with gemma2: https://github.com/huggingface/transformers/issues/31981 + precache_extra_params = {} + postcache_extra_params = {} + if ( + hasattr(model, "config") + and getattr(model.config, "cache_implementation", None) == "hybrid" + ): + cache = HybridCache(model.config, input_ids.shape[0], input_ids.shape[1] + 1) + cache_position = torch.arange(input_ids.shape[1], device=device) + precache_extra_params["past_key_values"] = cache + precache_extra_params["cache_position"] = cache_position[:subject_index] + postcache_extra_params["cache_position"] = cache_position[subject_index:] _subject_index = subject_index _object_pred_indices = object_pred_indices if _subject_index > 0: - outputs = model(input_ids=input_ids[:, :_subject_index], use_cache=True) + outputs = model( + input_ids=input_ids[:, :_subject_index], + use_cache=True, + **precache_extra_params, + ) past_key_values = outputs.past_key_values - input_ids = input_ids[:, _subject_index:] - _subject_index = 0 - _object_pred_indices = [i - subject_index for i in object_pred_indices] + if past_key_values is None: + logger.warn( + "Model did not return past_key_values, so the cache will not be used. This may use a lot of memory." + ) + else: + input_ids = input_ids[:, _subject_index:] + _subject_index = 0 + _object_pred_indices = [i - subject_index for i in object_pred_indices] use_cache = past_key_values is not None # Precompute initial h and z. @@ -122,6 +144,7 @@ def order_1_approx( input_ids=input_ids, use_cache=use_cache, past_key_values=past_key_values, + **postcache_extra_params, ) subject_layer_output = ret[subject_layer_name].output assert subject_layer_output is not None # keep mypy happy diff --git a/tests/conftest.py b/tests/conftest.py index 3faeb0c..a4792ad 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,10 @@ import pytest -from transformers import GPT2LMHeadModel, GPT2TokenizerFast, LlamaTokenizer +from transformers import ( + GPT2LMHeadModel, + GPT2TokenizerFast, + LlamaTokenizer, +) +from transformers.models.gemma2.modeling_gemma2 import Gemma2Config, Gemma2ForCausalLM # loading in advance so it won't reload on every test # just need to make sure not to edit these models in tests... @@ -13,6 +18,17 @@ def model() -> GPT2LMHeadModel: return _model +@pytest.fixture +def empty_gemma2_model() -> Gemma2ForCausalLM: + config = Gemma2Config( + num_hidden_layers=3, + hidden_size=1024, + intermediate_size=2752, + vocab_size=_tokenizer.vocab_size, + ) + return Gemma2ForCausalLM(config).eval() + + @pytest.fixture def tokenizer() -> GPT2TokenizerFast: return _tokenizer diff --git a/tests/training/test_train_lre.py b/tests/training/test_train_lre.py index de315b1..2b3d3df 100644 --- a/tests/training/test_train_lre.py +++ b/tests/training/test_train_lre.py @@ -1,5 +1,5 @@ import torch -from transformers import GPT2LMHeadModel, GPT2TokenizerFast +from transformers import GPT2LMHeadModel, GPT2TokenizerFast, PreTrainedModel from linear_relational.lib.extract_token_activations import ( extract_final_token_activations, @@ -92,3 +92,52 @@ def test_train_lre_on_single_prompt_perfectly_replicates_object( layers=["transformer.h.9"], )[0]["transformer.h.9"] assert torch.allclose(lre(subj_act), obj_act, atol=1e-4) + + +def test_train_lre_on_single_prompt_with_gemma2_perfectly_replicates_object( + empty_gemma2_model: PreTrainedModel, tokenizer: GPT2TokenizerFast +) -> None: + fsl_prefixes = "\n".join( + [ + "Berlin is located in the country of Germany", + "Toronto is located in the country of Canada", + "Lagos is located in the country of Nigeria", + ] + ) + prompt = create_prompt( + text=f"{fsl_prefixes}\nTokyo is located in the country of", + answer="Japan", + subject="Tokyo", + ) + prompts = [prompt] + lre = train_lre( + model=empty_gemma2_model, + tokenizer=tokenizer, + layer_matcher="model.layers.{num}", + relation="city in country", + subject_layer=1, + object_layer=2, + prompts=prompts, + object_aggregation="mean", + ) + + subj_index = ( + find_token_range(tokenizer, tokenizer.encode(prompt.text), prompt.subject)[-1] + - 1 + ) + subj_act = extract_token_activations( + model=empty_gemma2_model, + tokenizer=tokenizer, + texts=[prompt.text], + layers=["model.layers.1"], + token_indices=[subj_index], + )[0]["model.layers.1"][0] + obj_act = extract_final_token_activations( + model=empty_gemma2_model, + tokenizer=tokenizer, + texts=[prompt.text], + layers=["model.layers.2"], + )[0]["model.layers.2"] + print(lre(subj_act)) + print(obj_act) + assert torch.allclose(lre(subj_act), obj_act, atol=1e-4)