Skip to content

Commit

Permalink
feat: allow passing a map_activations_fn to ConceptMatcher
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind committed Jun 29, 2024
1 parent eb34692 commit 4ccb58e
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 3 deletions.
16 changes: 15 additions & 1 deletion linear_relational/ConceptMatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from torch import nn

from linear_relational.Concept import Concept
from linear_relational.lib.extract_token_activations import extract_token_activations
from linear_relational.lib.extract_token_activations import (
TokenLayerActivationsList,
extract_token_activations,
)
from linear_relational.lib.layer_matching import (
LayerMatcher,
collect_matching_layers,
Expand Down Expand Up @@ -52,18 +55,25 @@ class ConceptMatcher:
tokenizer: Tokenizer
layer_matcher: LayerMatcher
layer_name_to_num: dict[str, int]
map_activations_fn: (
Callable[[TokenLayerActivationsList], TokenLayerActivationsList] | None
)

def __init__(
self,
model: nn.Module,
tokenizer: Tokenizer,
concepts: list[Concept],
layer_matcher: Optional[LayerMatcher] = None,
map_activations_fn: (
Callable[[TokenLayerActivationsList], TokenLayerActivationsList] | None
) = None,
) -> None:
self.concepts = concepts
self.model = model
self.tokenizer = tokenizer
self.layer_matcher = layer_matcher or guess_hidden_layer_matcher(model)
self.map_activations_fn = map_activations_fn
ensure_tokenizer_has_pad_token(tokenizer)
num_layers = len(collect_matching_layers(self.model, self.layer_matcher))
self.layer_name_to_num = {}
Expand Down Expand Up @@ -100,6 +110,10 @@ def _query_batch(self, queries: Sequence[ConceptMatchQuery]) -> list[QueryResult
batch_size=len(queries),
show_progress=False,
)
if self.map_activations_fn is not None:
batch_subj_token_activations = self.map_activations_fn(
batch_subj_token_activations
)

results: list[QueryResult] = []
for raw_subj_token_activations in batch_subj_token_activations:
Expand Down
6 changes: 4 additions & 2 deletions linear_relational/lib/extract_token_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from .TraceLayerDict import TraceLayerDict
from .util import batchify, tuplify

# Alias for the return shape of extract_token_activations: a list of ordered
# mappings from layer name to a list of activation tensors. Presumably one
# outer entry per input text, in input order — verify against callers.
TokenLayerActivationsList = list[OrderedDict[str, list[torch.Tensor]]]


def extract_token_activations(
model: nn.Module,
Expand All @@ -22,12 +24,12 @@ def extract_token_activations(
move_results_to_cpu: bool = True,
batch_size: int = 32,
show_progress: bool = False,
) -> list[OrderedDict[str, list[torch.Tensor]]]:
) -> TokenLayerActivationsList:
if len(texts) != len(token_indices):
raise ValueError(
f"Expected {len(texts)} texts to match {len(token_indices)} subject token indices"
)
results: list[OrderedDict[str, list[torch.Tensor]]] = []
results: TokenLayerActivationsList = []
for batch in batchify(
# need to turn the zip into a list or mypy complains
list(zip(texts, token_indices)),
Expand Down
4 changes: 4 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,7 @@ def quick_concept(
vector=concept_vec,
layer=layer,
)


def normalize(vec: torch.Tensor) -> torch.Tensor:
    """Scale *vec* to unit length (Euclidean norm of 1)."""
    magnitude = vec.norm()
    return vec / magnitude
56 changes: 56 additions & 0 deletions tests/test_ConceptMatcher.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from collections import OrderedDict

import pytest
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

from linear_relational.Concept import Concept
from linear_relational.ConceptMatcher import ConceptMatcher, ConceptMatchQuery
from linear_relational.lib.extract_token_activations import TokenLayerActivationsList
from tests.helpers import normalize


def test_ConceptMatcher_query(
Expand Down Expand Up @@ -44,3 +49,54 @@ def test_ConceptMatcher_query_bulk(
assert len(results) == 2
for result in results:
assert concept.name in result.concept_results


def test_ConceptMatcher_query_bulk_with_map_activations_fn(
    model: GPT2LMHeadModel, tokenizer: GPT2TokenizerFast
) -> None:
    """Scores should reflect the mapped activations rather than the raw ones.

    The map_activations_fn below replaces every extracted activation with a
    fixed unit vector, so the asserted scores are exactly +1 for the concept
    whose vector equals that unit vector and -1 for its negation.
    """
    unit_vec = normalize(torch.randn(768))
    matching_concept = Concept(
        object="test_object1",
        relation="test_relation",
        layer=10,
        vector=unit_vec,
    )
    opposite_concept = Concept(
        object="test_object2",
        relation="test_relation",
        layer=10,
        vector=-1 * unit_vec,
    )

    def map_activations(
        token_layer_acts: TokenLayerActivationsList,
    ) -> TokenLayerActivationsList:
        # Swap every activation for unit_vec while preserving the nested
        # structure (one OrderedDict per input, one list per layer).
        return [
            OrderedDict(
                (layer_name, [unit_vec] * len(acts))
                for layer_name, acts in token_layer_act.items()
            )
            for token_layer_act in token_layer_acts
        ]

    matcher = ConceptMatcher(
        model,
        tokenizer,
        [matching_concept, opposite_concept],
        layer_matcher="transformer.h.{num}",
        map_activations_fn=map_activations,
    )
    results = matcher.query_bulk(
        [
            ConceptMatchQuery("This is a test", "test"),
            ConceptMatchQuery("This is another test", "test"),
        ]
    )
    assert len(results) == 2
    for result in results:
        assert result.concept_results[
            "test_relation: test_object1"
        ].score == pytest.approx(1.0)
        assert result.concept_results[
            "test_relation: test_object2"
        ].score == pytest.approx(-1.0)

0 comments on commit 4ccb58e

Please sign in to comment.