diff --git a/linear_relational/lib/torch_utils.py b/linear_relational/lib/torch_utils.py
index a1a4eeb..43049d7 100644
--- a/linear_relational/lib/torch_utils.py
+++ b/linear_relational/lib/torch_utils.py
@@ -18,6 +18,13 @@ def get_module(model: nn.Module, name: str) -> nn.Module:
     raise LookupError(name)
 
 
+def get_dtype(model: nn.Module) -> torch.dtype:
+    """
+    Returns the dtype of the model.
+    """
+    return next(model.parameters()).dtype
+
+
 def get_device(model: nn.Module) -> torch.device:
     """
     Returns the device on which the model is running.
diff --git a/linear_relational/training/train_lre.py b/linear_relational/training/train_lre.py
index 6268fa6..9be23fa 100644
--- a/linear_relational/training/train_lre.py
+++ b/linear_relational/training/train_lre.py
@@ -15,7 +15,11 @@
     find_final_word_token_index,
     find_prompt_answer_data,
 )
-from linear_relational.lib.torch_utils import get_device, untuple_tensor
+from linear_relational.lib.torch_utils import (
+    get_device,
+    get_dtype,
+    untuple_tensor,
+)
 from linear_relational.lib.TraceLayer import TraceLayer
 from linear_relational.lib.TraceLayerDict import TraceLayerDict
 from linear_relational.Lre import Lre
@@ -112,7 +116,12 @@ def order_1_approx(
         hasattr(model, "config")
         and getattr(model.config, "cache_implementation", None) == "hybrid"
     ):
-        cache = HybridCache(model.config, input_ids.shape[0], input_ids.shape[1] + 1)
+        cache = HybridCache(
+            model.config,
+            input_ids.shape[0],
+            input_ids.shape[1] + 1,
+            dtype=get_dtype(model),
+        )
         cache_position = torch.arange(input_ids.shape[1], device=device)
         precache_extra_params["past_key_values"] = cache
         precache_extra_params["cache_position"] = cache_position[:subject_index]
diff --git a/tests/conftest.py b/tests/conftest.py
index a4792ad..dcb4a35 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -22,8 +22,8 @@ def model() -> GPT2LMHeadModel:
 def empty_gemma2_model() -> Gemma2ForCausalLM:
     config = Gemma2Config(
         num_hidden_layers=3,
-        hidden_size=1024,
-        intermediate_size=2752,
+        hidden_size=64,
+        intermediate_size=128,
         vocab_size=_tokenizer.vocab_size,
     )
     return Gemma2ForCausalLM(config).eval()
diff --git a/tests/training/test_train_lre.py b/tests/training/test_train_lre.py
index 2b3d3df..23b5b58 100644
--- a/tests/training/test_train_lre.py
+++ b/tests/training/test_train_lre.py
@@ -138,6 +138,43 @@ def test_train_lre_on_single_prompt_with_gemma2_perfectly_replicates_object(
         texts=[prompt.text],
         layers=["model.layers.2"],
     )[0]["model.layers.2"]
-    print(lre(subj_act))
-    print(obj_act)
+    assert torch.allclose(lre(subj_act), obj_act, atol=1e-4)
+
+
+def test_train_lre_works_with_gemma2_and_float16(
+    empty_gemma2_model: PreTrainedModel, tokenizer: GPT2TokenizerFast
+) -> None:
+    model = empty_gemma2_model.half()
+    prompt = create_prompt(
+        text="Tokyo is located in the country of",
+        answer="Japan",
+        subject="Tokyo",
+    )
+    lre = train_lre(
+        model=model,
+        tokenizer=tokenizer,
+        layer_matcher="model.layers.{num}",
+        relation="city in country",
+        subject_layer=1,
+        object_layer=2,
+        prompts=[prompt],
+    ).float()
+
+    subj_index = (
+        find_token_range(tokenizer, tokenizer.encode(prompt.text), prompt.subject)[-1]
+        - 1
+    )
+    subj_act = extract_token_activations(
+        model=model,
+        tokenizer=tokenizer,
+        texts=[prompt.text],
+        layers=["model.layers.1"],
+        token_indices=[subj_index],
+    )[0]["model.layers.1"][0]
+    obj_act = extract_final_token_activations(
+        model=model,
+        tokenizer=tokenizer,
+        texts=[prompt.text],
+        layers=["model.layers.2"],
+    )[0]["model.layers.2"]
     assert torch.allclose(lre(subj_act), obj_act, atol=1e-4)