Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add map_activations_fn param in ConceptMatcher #6

Merged
merged 2 commits into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion linear_relational/ConceptMatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from torch import nn

from linear_relational.Concept import Concept
from linear_relational.lib.extract_token_activations import extract_token_activations
from linear_relational.lib.extract_token_activations import (
TokenLayerActivationsList,
extract_token_activations,
)
from linear_relational.lib.layer_matching import (
LayerMatcher,
collect_matching_layers,
Expand Down Expand Up @@ -52,18 +55,25 @@ class ConceptMatcher:
tokenizer: Tokenizer
layer_matcher: LayerMatcher
layer_name_to_num: dict[str, int]
map_activations_fn: (
Callable[[TokenLayerActivationsList], TokenLayerActivationsList] | None
)

def __init__(
self,
model: nn.Module,
tokenizer: Tokenizer,
concepts: list[Concept],
layer_matcher: Optional[LayerMatcher] = None,
map_activations_fn: (
Callable[[TokenLayerActivationsList], TokenLayerActivationsList] | None
) = None,
) -> None:
self.concepts = concepts
self.model = model
self.tokenizer = tokenizer
self.layer_matcher = layer_matcher or guess_hidden_layer_matcher(model)
self.map_activations_fn = map_activations_fn
ensure_tokenizer_has_pad_token(tokenizer)
num_layers = len(collect_matching_layers(self.model, self.layer_matcher))
self.layer_name_to_num = {}
Expand Down Expand Up @@ -100,6 +110,10 @@ def _query_batch(self, queries: Sequence[ConceptMatchQuery]) -> list[QueryResult
batch_size=len(queries),
show_progress=False,
)
if self.map_activations_fn is not None:
batch_subj_token_activations = self.map_activations_fn(
batch_subj_token_activations
)

results: list[QueryResult] = []
for raw_subj_token_activations in batch_subj_token_activations:
Expand Down
6 changes: 4 additions & 2 deletions linear_relational/lib/extract_token_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from .TraceLayerDict import TraceLayerDict
from .util import batchify, tuplify

# Shared alias for the value returned by extract_token_activations: a list of
# ordered mappings from a string key to a list of tensors.
# NOTE(review): presumably one mapping per input text, keyed by layer name, with
# one tensor per requested token index — confirm against the function body.
TokenLayerActivationsList = list[OrderedDict[str, list[torch.Tensor]]]


def extract_token_activations(
model: nn.Module,
Expand All @@ -22,12 +24,12 @@ def extract_token_activations(
move_results_to_cpu: bool = True,
batch_size: int = 32,
show_progress: bool = False,
) -> list[OrderedDict[str, list[torch.Tensor]]]:
) -> TokenLayerActivationsList:
if len(texts) != len(token_indices):
raise ValueError(
f"Expected {len(texts)} texts to match {len(token_indices)} subject token indices"
)
results: list[OrderedDict[str, list[torch.Tensor]]] = []
results: TokenLayerActivationsList = []
for batch in batchify(
# need to turn the zip into a list or mypy complains
list(zip(texts, token_indices)),
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ readme = "README.md"

[tool.poetry.dependencies]
python = "^3.10"
transformers = "^4.35.2"
transformers = "<4.42.0"
tqdm = ">=4.0.0"
dataclasses-json = "^0.6.2"

Expand Down
4 changes: 4 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,7 @@ def quick_concept(
vector=concept_vec,
layer=layer,
)


def normalize(vec: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Scale ``vec`` to unit norm.

    The norm is clamped from below by ``eps`` before dividing, so an all-zero
    input yields zeros instead of NaNs from a division by zero. For any input
    whose norm exceeds ``eps`` the result is the same as ``vec / vec.norm()``.

    Args:
        vec: Tensor to normalize; the norm is taken over all elements.
        eps: Lower bound applied to the norm to keep the division finite.

    Returns:
        A new tensor with the same shape as ``vec``.
    """
    # Clamp rather than branch so the helper stays a single tensor expression.
    return vec / vec.norm().clamp(min=eps)
56 changes: 56 additions & 0 deletions tests/test_ConceptMatcher.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from collections import OrderedDict

import pytest
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

from linear_relational.Concept import Concept
from linear_relational.ConceptMatcher import ConceptMatcher, ConceptMatchQuery
from linear_relational.lib.extract_token_activations import TokenLayerActivationsList
from tests.helpers import normalize


def test_ConceptMatcher_query(
Expand Down Expand Up @@ -44,3 +49,54 @@ def test_ConceptMatcher_query_bulk(
assert len(results) == 2
for result in results:
assert concept.name in result.concept_results


def test_ConceptMatcher_query_bulk_with_map_activations_fn(
    model: GPT2LMHeadModel, tokenizer: GPT2TokenizerFast
) -> None:
    """Check that ``map_activations_fn`` is applied before concepts are scored.

    Two concepts are built at layer 10 with opposite vectors (``v`` and ``-v``).
    The mapping function replaces every extracted activation with ``v``, and
    the test then expects a score of about 1.0 for the first concept and about
    -1.0 for the second on every query.
    """
    unit_vec = normalize(torch.randn(768))
    concepts = [
        Concept(
            object="test_object1",
            relation="test_relation",
            layer=10,
            vector=unit_vec,
        ),
        Concept(
            object="test_object2",
            relation="test_relation",
            layer=10,
            vector=-1 * unit_vec,
        ),
    ]

    def map_activations(
        token_layer_acts: TokenLayerActivationsList,
    ) -> TokenLayerActivationsList:
        # Keep the structure (same keys, same list lengths) but swap every
        # activation tensor for the known unit vector.
        return [
            OrderedDict(
                (layer_name, [unit_vec for _ in layer_acts])
                for layer_name, layer_acts in per_query_acts.items()
            )
            for per_query_acts in token_layer_acts
        ]

    matcher = ConceptMatcher(
        model,
        tokenizer,
        concepts,
        layer_matcher="transformer.h.{num}",
        map_activations_fn=map_activations,
    )
    queries = [
        ConceptMatchQuery("This is a test", "test"),
        ConceptMatchQuery("This is another test", "test"),
    ]
    results = matcher.query_bulk(queries)

    expected_scores = {
        "test_relation: test_object1": 1.0,
        "test_relation: test_object2": -1.0,
    }
    assert len(results) == 2
    for result in results:
        for concept_name, expected in expected_scores.items():
            score = result.concept_results[concept_name].score
            assert score == pytest.approx(expected)