Skip to content

Commit

Permalink
feat: allow setting max prompts in LRE training
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind committed May 2, 2024
1 parent d35bc8b commit b322a94
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 5 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ v_{o} = W^{\dagger}(o - b)
$$

For more information on LREs and LRCs, check out the following papers:

- [Identifying Linear Relational Concepts in Large Language Models](https://arxiv.org/abs/2311.08968)
- [Linearity of Relation Decoding in Transformer Language Models](https://arxiv.org/abs/2308.09124)

Expand Down
18 changes: 13 additions & 5 deletions linear_relational/training/Trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,13 @@ def train_lre(
subject_layer: int,
object_layer: int,
prompts: list[Prompt],
max_lre_training_samples: int | None = 20,
object_aggregation: ObjectAggregation = "mean",
validate_prompts: bool = True,
validate_prompts_batch_size: int = 4,
move_to_cpu: bool = False,
verbose: bool = True,
seed: int | str | float = 42,
) -> Lre:
processed_prompts = self._process_relation_prompts(
relation=relation,
Expand All @@ -65,14 +67,20 @@ def train_lre(
validate_prompts_batch_size=validate_prompts_batch_size,
verbose=verbose,
)
prompts_by_object = group_items(processed_prompts, lambda p: p.object_name)
lre_train_prompts = balance_grouped_items(
items_by_group=prompts_by_object,
max_total=max_lre_training_samples,
seed=seed,
)
return train_lre(
model=self.model,
tokenizer=self.tokenizer,
layer_matcher=self.layer_matcher,
relation=relation,
subject_layer=subject_layer,
object_layer=object_layer,
prompts=processed_prompts,
prompts=lre_train_prompts,
object_aggregation=object_aggregation,
move_to_cpu=move_to_cpu,
)
Expand Down Expand Up @@ -110,15 +118,15 @@ def train_relation_concepts(
max_total=max_lre_training_samples,
seed=seed,
)
inv_lre = self.train_lre(
inv_lre = train_lre(
model=self.model,
tokenizer=self.tokenizer,
layer_matcher=self.layer_matcher,
relation=relation,
subject_layer=subject_layer,
object_layer=object_layer,
prompts=lre_train_prompts,
object_aggregation=object_aggregation,
validate_prompts=False, # we already validated the prompts above
validate_prompts_batch_size=validate_prompts_batch_size,
verbose=verbose,
).invert(inv_lre_rank)

return self.train_relation_concepts_from_inv_lre(
Expand Down
51 changes: 51 additions & 0 deletions tests/training/test_Trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,57 @@ def prompts_from_samples(samples: list[tuple[str, str]], template: str) -> list[
return prompts


def test_Trainer_train_lre(
    model: GPT2LMHeadModel, tokenizer: GPT2TokenizerFast
) -> None:
    """Train an LRE on a city -> country relation and sanity-check its shape.

    Builds a shuffled set of (city, country) prompts for Japan and China,
    trains an LRE via Trainer.train_lre, and verifies the recorded layers
    and the GPT-2 hidden-size (768) weight/bias dimensions.
    """
    template = "{} is located in the country of"
    japan_cities = [
        "Tokyo",
        "Osaka",
        "Nagoya",
        "Hiroshima",
        "Yokohama",
        "Kyoto",
        "Nagasaki",
        "Kobe",
        "Kitashima",
        "Kyushu",
    ]
    china_cities = [
        "Beijing",
        "Shanghai",
        "Nanjing",
        "Hangzhou",
        "Peking",
        "Qingdao",
        "Chongqing",
        "Changsha",
        "Wuhan",
        "Chengdu",
    ]
    # Pair every city with its country, then shuffle deterministically.
    samples: list[tuple[str, str]] = [(city, "Japan") for city in japan_cities]
    samples.extend((city, "China") for city in china_cities)
    shuffled_samples = stable_shuffle(samples)
    prompts = prompts_from_samples(shuffled_samples, template)

    trainer = Trainer(model, tokenizer)

    lre = trainer.train_lre(
        relation="located_in_country",
        subject_layer=8,
        object_layer=10,
        prompts=prompts,
    )

    assert lre.subject_layer == 8
    assert lre.object_layer == 10
    assert lre.weight.shape == (768, 768)
    assert lre.bias.shape == (768,)


def test_Trainer_train_relation_concepts(
model: GPT2LMHeadModel, tokenizer: GPT2TokenizerFast
) -> None:
Expand Down

0 comments on commit b322a94

Please sign in to comment.