diff --git a/linear_relational/ConceptMatcher.py b/linear_relational/ConceptMatcher.py index 7a405f2..ebbbe81 100644 --- a/linear_relational/ConceptMatcher.py +++ b/linear_relational/ConceptMatcher.py @@ -6,7 +6,10 @@ from torch import nn from linear_relational.Concept import Concept -from linear_relational.lib.extract_token_activations import extract_token_activations +from linear_relational.lib.extract_token_activations import ( + TokenLayerActivationsList, + extract_token_activations, +) from linear_relational.lib.layer_matching import ( LayerMatcher, collect_matching_layers, @@ -52,6 +55,9 @@ class ConceptMatcher: tokenizer: Tokenizer layer_matcher: LayerMatcher layer_name_to_num: dict[str, int] + map_activations_fn: ( + Callable[[TokenLayerActivationsList], TokenLayerActivationsList] | None + ) def __init__( self, @@ -59,11 +65,15 @@ def __init__( tokenizer: Tokenizer, concepts: list[Concept], layer_matcher: Optional[LayerMatcher] = None, + map_activations_fn: ( + Callable[[TokenLayerActivationsList], TokenLayerActivationsList] | None + ) = None, ) -> None: self.concepts = concepts self.model = model self.tokenizer = tokenizer self.layer_matcher = layer_matcher or guess_hidden_layer_matcher(model) + self.map_activations_fn = map_activations_fn ensure_tokenizer_has_pad_token(tokenizer) num_layers = len(collect_matching_layers(self.model, self.layer_matcher)) self.layer_name_to_num = {} @@ -100,6 +110,10 @@ def _query_batch(self, queries: Sequence[ConceptMatchQuery]) -> list[QueryResult batch_size=len(queries), show_progress=False, ) + if self.map_activations_fn is not None: + batch_subj_token_activations = self.map_activations_fn( + batch_subj_token_activations + ) results: list[QueryResult] = [] for raw_subj_token_activations in batch_subj_token_activations: diff --git a/linear_relational/lib/extract_token_activations.py b/linear_relational/lib/extract_token_activations.py index a0e455c..e8b1083 100644 --- a/linear_relational/lib/extract_token_activations.py +++ b/linear_relational/lib/extract_token_activations.py @@ -11,6 +11,8 @@ from .TraceLayerDict import TraceLayerDict from .util import batchify, tuplify +TokenLayerActivationsList = list[OrderedDict[str, list[torch.Tensor]]] + def extract_token_activations( model: nn.Module, @@ -22,12 +24,12 @@ def extract_token_activations( move_results_to_cpu: bool = True, batch_size: int = 32, show_progress: bool = False, -) -> list[OrderedDict[str, list[torch.Tensor]]]: +) -> TokenLayerActivationsList: if len(texts) != len(token_indices): raise ValueError( f"Expected {len(texts)} texts to match {len(token_indices)} subject token indices" ) - results: list[OrderedDict[str, list[torch.Tensor]]] = [] + results: TokenLayerActivationsList = [] for batch in batchify( # need to turn the zip into a list or mypy complains list(zip(texts, token_indices)), diff --git a/tests/helpers.py b/tests/helpers.py index 8c9026c..764b8da 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -43,3 +43,7 @@ def quick_concept( vector=concept_vec, layer=layer, ) + + +def normalize(vec: torch.Tensor) -> torch.Tensor: + return vec / vec.norm() diff --git a/tests/test_ConceptMatcher.py b/tests/test_ConceptMatcher.py index 7c1ecd4..3bf6432 100644 --- a/tests/test_ConceptMatcher.py +++ b/tests/test_ConceptMatcher.py @@ -1,8 +1,13 @@ +from collections import OrderedDict + +import pytest import torch from transformers import GPT2LMHeadModel, GPT2TokenizerFast from linear_relational.Concept import Concept from linear_relational.ConceptMatcher import ConceptMatcher, ConceptMatchQuery +from linear_relational.lib.extract_token_activations import TokenLayerActivationsList +from tests.helpers import normalize def test_ConceptMatcher_query( @@ -44,3 +49,54 @@ def test_ConceptMatcher_query_bulk( assert len(results) == 2 for result in results: assert concept.name in result.concept_results + + +def test_ConceptMatcher_query_bulk_with_map_activations_fn( + model: GPT2LMHeadModel, tokenizer: GPT2TokenizerFast +) -> None: + concept_vec = normalize(torch.randn(768)) + concept1 = Concept( + object="test_object1", + relation="test_relation", + layer=10, + vector=concept_vec, + ) + concept2 = Concept( + object="test_object2", + relation="test_relation", + layer=10, + vector=-1 * concept_vec, + ) + + def map_activations( + token_layer_acts: TokenLayerActivationsList, + ) -> TokenLayerActivationsList: + mapped_token_layer_acts: TokenLayerActivationsList = [] + for token_layer_act in token_layer_acts: + mapped_token_layer_act = OrderedDict() + for layer, acts in token_layer_act.items(): + mapped_token_layer_act[layer] = [concept_vec for act in acts] + mapped_token_layer_acts.append(mapped_token_layer_act) + return mapped_token_layer_acts + + conceptifier = ConceptMatcher( + model, + tokenizer, + [concept1, concept2], + layer_matcher="transformer.h.{num}", + map_activations_fn=map_activations, + ) + results = conceptifier.query_bulk( + [ + ConceptMatchQuery("This is a test", "test"), + ConceptMatchQuery("This is another test", "test"), + ] + ) + assert len(results) == 2 + for result in results: + assert result.concept_results[ + "test_relation: test_object1" + ].score == pytest.approx(1.0) + assert result.concept_results[ + "test_relation: test_object2" + ].score == pytest.approx(-1.0)