fix: fix cache behavior for gemma2 and warn if caching breaks
chanind committed Aug 7, 2024
1 parent c26210f commit 0478b48
Showing 3 changed files with 95 additions and 7 deletions.
33 changes: 28 additions & 5 deletions linear_relational/training/train_lre.py
@@ -3,12 +3,14 @@
 import torch
 from tokenizers import Tokenizer
 from torch import nn
+from transformers.cache_utils import HybridCache

 from linear_relational.lib.layer_matching import (
     LayerMatcher,
     fix_neg_layer_num,
     get_layer_name,
 )
+from linear_relational.lib.logger import logger
 from linear_relational.lib.token_utils import (
     find_final_word_token_index,
     find_prompt_answer_data,
@@ -102,16 +104,36 @@ def order_1_approx(
     inputs = tokenizer(prompt_text, return_tensors="pt").to(device)

     # Precompute everything up to the subject, if there is anything before it.
     past_key_values = None
     input_ids = inputs.input_ids
+    # lots of cache pos hackery to get caching to work with gemma2: https://github.com/huggingface/transformers/issues/31981
+    precache_extra_params = {}
+    postcache_extra_params = {}
+    if (
+        hasattr(model, "config")
+        and getattr(model.config, "cache_implementation", None) == "hybrid"
+    ):
+        cache = HybridCache(model.config, input_ids.shape[0], input_ids.shape[1] + 1)
+        cache_position = torch.arange(input_ids.shape[1], device=device)
+        precache_extra_params["past_key_values"] = cache
+        precache_extra_params["cache_position"] = cache_position[:subject_index]
+        postcache_extra_params["cache_position"] = cache_position[subject_index:]
     _subject_index = subject_index
     _object_pred_indices = object_pred_indices
     if _subject_index > 0:
-        outputs = model(input_ids=input_ids[:, :_subject_index], use_cache=True)
+        outputs = model(
+            input_ids=input_ids[:, :_subject_index],
+            use_cache=True,
+            **precache_extra_params,
+        )
         past_key_values = outputs.past_key_values
-        input_ids = input_ids[:, _subject_index:]
-        _subject_index = 0
-        _object_pred_indices = [i - subject_index for i in object_pred_indices]
+        if past_key_values is None:
+            logger.warn(
+                "Model did not return past_key_values, so the cache will not be used. This may use a lot of memory."
+            )
+        else:
+            input_ids = input_ids[:, _subject_index:]
+            _subject_index = 0
+            _object_pred_indices = [i - subject_index for i in object_pred_indices]
     use_cache = past_key_values is not None

     # Precompute initial h and z.
@@ -122,6 +144,7 @@ def order_1_approx(
             input_ids=input_ids,
             use_cache=use_cache,
             past_key_values=past_key_values,
+            **postcache_extra_params,
         )
     subject_layer_output = ret[subject_layer_name].output
     assert subject_layer_output is not None # keep mypy happy
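For context, the pattern the commit adopts for Gemma 2 is roughly: pre-allocate a HybridCache sized for the full prompt, pass explicit cache_position slices for the prefix pass and for the continuation pass, and fall back (with a warning) to running without a cache if the model returns no past_key_values. The sketch below is illustrative only and not part of the commit; the helper name and the fallback behaviour are assumptions, while the HybridCache and cache_position usage mirrors the diff above and the linked transformers issue #31981.

import torch
from transformers.cache_utils import HybridCache


def run_with_prefix_cache(model, input_ids, split_index):
    # Gemma 2's "hybrid" cache must be pre-allocated and told which slots each
    # forward pass fills, via cache_position (huggingface/transformers#31981).
    extra_prefix, extra_rest = {}, {}
    if getattr(model.config, "cache_implementation", None) == "hybrid":
        cache = HybridCache(model.config, input_ids.shape[0], input_ids.shape[1] + 1)
        positions = torch.arange(input_ids.shape[1], device=input_ids.device)
        extra_prefix = {"past_key_values": cache, "cache_position": positions[:split_index]}
        extra_rest = {"cache_position": positions[split_index:]}
    prefix_out = model(input_ids=input_ids[:, :split_index], use_cache=True, **extra_prefix)
    if prefix_out.past_key_values is None:
        # No cache came back: re-run the whole sequence instead (uses more memory).
        return model(input_ids=input_ids, use_cache=False)
    return model(
        input_ids=input_ids[:, split_index:],
        use_cache=True,
        past_key_values=prefix_out.past_key_values,
        **extra_rest,
    )

In order_1_approx itself the same idea is expressed with precache_extra_params / postcache_extra_params so that the existing non-hybrid code path is left untouched.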
18 changes: 17 additions & 1 deletion tests/conftest.py
@@ -1,5 +1,10 @@
 import pytest
-from transformers import GPT2LMHeadModel, GPT2TokenizerFast, LlamaTokenizer
+from transformers import (
+    GPT2LMHeadModel,
+    GPT2TokenizerFast,
+    LlamaTokenizer,
+)
+from transformers.models.gemma2.modeling_gemma2 import Gemma2Config, Gemma2ForCausalLM

 # loading in advance so it won't reload on every test
 # just need to make sure not to edit these models in tests...
@@ -13,6 +18,17 @@ def model() -> GPT2LMHeadModel:
     return _model


+@pytest.fixture
+def empty_gemma2_model() -> Gemma2ForCausalLM:
+    config = Gemma2Config(
+        num_hidden_layers=3,
+        hidden_size=1024,
+        intermediate_size=2752,
+        vocab_size=_tokenizer.vocab_size,
+    )
+    return Gemma2ForCausalLM(config).eval()
+
+
 @pytest.fixture
 def tokenizer() -> GPT2TokenizerFast:
     return _tokenizer
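As a side note on why this tiny model hits the new code path: Gemma2Config defaults cache_implementation to "hybrid" (in the transformers versions this commit targets, as far as I can tell), which is exactly the value the new guard in order_1_approx checks for. A minimal illustrative check, not part of the commit:

from transformers.models.gemma2.modeling_gemma2 import Gemma2Config

config = Gemma2Config(num_hidden_layers=3, hidden_size=1024, intermediate_size=2752)
# order_1_approx only builds a HybridCache when it sees this config value.
assert getattr(config, "cache_implementation", None) == "hybrid"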
51 changes: 50 additions & 1 deletion tests/training/test_train_lre.py
@@ -1,5 +1,5 @@
 import torch
-from transformers import GPT2LMHeadModel, GPT2TokenizerFast
+from transformers import GPT2LMHeadModel, GPT2TokenizerFast, PreTrainedModel

 from linear_relational.lib.extract_token_activations import (
     extract_final_token_activations,
@@ -92,3 +92,52 @@ def test_train_lre_on_single_prompt_perfectly_replicates_object(
         layers=["transformer.h.9"],
     )[0]["transformer.h.9"]
     assert torch.allclose(lre(subj_act), obj_act, atol=1e-4)
+
+
+def test_train_lre_on_single_prompt_with_gemma2_perfectly_replicates_object(
+    empty_gemma2_model: PreTrainedModel, tokenizer: GPT2TokenizerFast
+) -> None:
+    fsl_prefixes = "\n".join(
+        [
+            "Berlin is located in the country of Germany",
+            "Toronto is located in the country of Canada",
+            "Lagos is located in the country of Nigeria",
+        ]
+    )
+    prompt = create_prompt(
+        text=f"{fsl_prefixes}\nTokyo is located in the country of",
+        answer="Japan",
+        subject="Tokyo",
+    )
+    prompts = [prompt]
+    lre = train_lre(
+        model=empty_gemma2_model,
+        tokenizer=tokenizer,
+        layer_matcher="model.layers.{num}",
+        relation="city in country",
+        subject_layer=1,
+        object_layer=2,
+        prompts=prompts,
+        object_aggregation="mean",
+    )
+
+    subj_index = (
+        find_token_range(tokenizer, tokenizer.encode(prompt.text), prompt.subject)[-1]
+        - 1
+    )
+    subj_act = extract_token_activations(
+        model=empty_gemma2_model,
+        tokenizer=tokenizer,
+        texts=[prompt.text],
+        layers=["model.layers.1"],
+        token_indices=[subj_index],
+    )[0]["model.layers.1"][0]
+    obj_act = extract_final_token_activations(
+        model=empty_gemma2_model,
+        tokenizer=tokenizer,
+        texts=[prompt.text],
+        layers=["model.layers.2"],
+    )[0]["model.layers.2"]
+    print(lre(subj_act))
+    print(obj_act)
+    assert torch.allclose(lre(subj_act), obj_act, atol=1e-4)
