Commit
Showing 30 changed files with 3,513 additions and 1 deletion.
@@ -0,0 +1,26 @@
name: CI
on: [push]
jobs:
  lint_test_and_build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v3
        with:
          python-version: "3.11"
      - name: Install Poetry
        uses: snok/install-poetry@v1
        with:
          version: 1.4.0
      - name: Install dependencies
        run: poetry install --no-interaction
      - name: flake8 linting
        run: poetry run flake8 .
      - name: black code formatting
        run: poetry run black . --check
      - name: mypy type checking
        run: poetry run mypy .
      - name: pytest
        run: poetry run pytest
      - name: build
        run: poetry build
@@ -0,0 +1,213 @@
from typing import Any, Literal

import torch
from torch import nn

class InvertedLre(nn.Module):
    """Low-rank inverted LRE, used for calculating subject activations from object activations"""

    relation: str
    subject_layer: int
    object_layer: int
    # store u, v, s, and bias separately to avoid storing the full weight matrix
    u: nn.Parameter
    s: nn.Parameter
    v: nn.Parameter
    bias: nn.Parameter
    object_aggregation: Literal["mean", "first_token"]
    metadata: dict[str, Any] | None = None

    def __init__(
        self,
        relation: str,
        subject_layer: int,
        object_layer: int,
        object_aggregation: Literal["mean", "first_token"],
        u: torch.Tensor,
        s: torch.Tensor,
        v: torch.Tensor,
        bias: torch.Tensor,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()
        self.relation = relation
        self.subject_layer = subject_layer
        self.object_layer = object_layer
        self.object_aggregation = object_aggregation
        self.u = nn.Parameter(u, requires_grad=False)
        self.s = nn.Parameter(s, requires_grad=False)
        self.v = nn.Parameter(v, requires_grad=False)
        self.bias = nn.Parameter(bias, requires_grad=False)
        self.metadata = metadata

    @property
    def rank(self) -> int:
        return self.s.shape[0]

    def w_inv_times_vec(self, vec: torch.Tensor) -> torch.Tensor:
        # group u.T @ vec to avoid calculating larger matrices than needed
        return self.v @ torch.diag(1 / self.s) @ (self.u.T @ vec)

    def forward(
        self,
        object_activations: torch.Tensor,  # a tensor of shape (num_activations, hidden_activation_size)
        normalize: bool = False,
    ) -> torch.Tensor:
        return self.calculate_subject_activation(
            object_activations=object_activations,
            normalize=normalize,
        )

    def calculate_subject_activation(
        self,
        object_activations: torch.Tensor,  # a tensor of shape (num_activations, hidden_activation_size)
        normalize: bool = False,
    ) -> torch.Tensor:
        # subtract the bias, apply the inverted low-rank weight, and average the results
        unbiased_acts = object_activations - self.bias.unsqueeze(0)
        vec = self.w_inv_times_vec(unbiased_acts.T).mean(dim=1)

        if normalize:
            vec = vec / vec.norm()
        return vec

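For illustration only, not part of the committed file: a minimal sketch of exercising InvertedLre on its own with synthetic low-rank factors. The import path, relation name, and layer indices are assumptions.

import torch
from linear_relational.Lre import InvertedLre  # path assumed from this commit's other imports

hidden, rank = 16, 4
u = torch.linalg.qr(torch.randn(hidden, rank)).Q  # orthonormal columns, as an SVD would give
v = torch.linalg.qr(torch.randn(hidden, rank)).Q
s = torch.rand(rank) + 0.5  # keep singular values away from zero so 1 / s stays stable
bias = torch.randn(hidden)

inv_lre = InvertedLre(
    relation="city in country",  # placeholder relation name
    subject_layer=8,
    object_layer=17,
    object_aggregation="mean",
    u=u,
    s=s,
    v=v,
    bias=bias,
)

object_acts = torch.randn(3, hidden)  # (num_activations, hidden_activation_size)
subject_direction = inv_lre.calculate_subject_activation(object_acts, normalize=True)
print(subject_direction.shape)  # torch.Size([16])
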
class LowRankLre(nn.Module):
    """Low-rank approximation of a LRE"""

    relation: str
    subject_layer: int
    object_layer: int
    # store u, v, s, and bias separately to avoid storing the full weight matrix
    u: nn.Parameter
    s: nn.Parameter
    v: nn.Parameter
    bias: nn.Parameter
    object_aggregation: Literal["mean", "first_token"]
    metadata: dict[str, Any] | None = None

    def __init__(
        self,
        relation: str,
        subject_layer: int,
        object_layer: int,
        object_aggregation: Literal["mean", "first_token"],
        u: torch.Tensor,
        s: torch.Tensor,
        v: torch.Tensor,
        bias: torch.Tensor,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()
        self.relation = relation
        self.subject_layer = subject_layer
        self.object_layer = object_layer
        self.object_aggregation = object_aggregation
        self.u = nn.Parameter(u, requires_grad=False)
        self.s = nn.Parameter(s, requires_grad=False)
        self.v = nn.Parameter(v, requires_grad=False)
        self.bias = nn.Parameter(bias, requires_grad=False)
        self.metadata = metadata

    @property
    def rank(self) -> int:
        return self.s.shape[0]

    def w_times_vec(self, vec: torch.Tensor) -> torch.Tensor:
        # group v.T @ vec to avoid calculating larger matrices than needed
        return self.u @ torch.diag(self.s) @ (self.v.T @ vec)

    def forward(
        self,
        subject_activations: torch.Tensor,  # a tensor of shape (num_activations, hidden_activation_size)
        normalize: bool = False,
    ) -> torch.Tensor:
        return self.calculate_object_activation(
            subject_activations=subject_activations,
            normalize=normalize,
        )

    def calculate_object_activation(
        self,
        subject_activations: torch.Tensor,  # a tensor of shape (num_activations, hidden_activation_size)
        normalize: bool = False,
    ) -> torch.Tensor:
        # apply the low-rank weight to each activation, add the bias, and average the results
        ws = self.w_times_vec(subject_activations.T)
        vec = (ws + self.bias.unsqueeze(-1)).mean(dim=1)
        if normalize:
            vec = vec / vec.norm()
        return vec

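A quick sanity check, not part of the committed file, that the grouped multiplication in w_times_vec matches multiplying by the explicitly reconstructed low-rank weight; all tensors here are synthetic.

import torch

hidden, rank = 16, 4
u = torch.randn(hidden, rank)
s = torch.rand(rank) + 0.5
v = torch.randn(hidden, rank)
x = torch.randn(hidden, 3)  # a batch of activations as column vectors

grouped = u @ torch.diag(s) @ (v.T @ x)  # what w_times_vec computes, never forming the full matrix
dense = (u @ torch.diag(s) @ v.T) @ x  # explicit (hidden, hidden) weight, then multiply
assert torch.allclose(grouped, dense, atol=1e-5)
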
class Lre(nn.Module):
    """Linear Relational Embedding"""

    relation: str
    subject_layer: int
    object_layer: int
    weight: nn.Parameter
    bias: nn.Parameter
    object_aggregation: Literal["mean", "first_token"]
    metadata: dict[str, Any] | None = None

    def __init__(
        self,
        relation: str,
        subject_layer: int,
        object_layer: int,
        object_aggregation: Literal["mean", "first_token"],
        weight: torch.Tensor,
        bias: torch.Tensor,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()
        self.relation = relation
        self.subject_layer = subject_layer
        self.object_layer = object_layer
        self.object_aggregation = object_aggregation
        self.weight = nn.Parameter(weight, requires_grad=False)
        self.bias = nn.Parameter(bias, requires_grad=False)
        self.metadata = metadata

    def invert(self, rank: int) -> InvertedLre:
        """Invert this LRE using a low-rank approximation"""
        u, s, v = self._low_rank_svd(rank)
        return InvertedLre(
            relation=self.relation,
            subject_layer=self.subject_layer,
            object_layer=self.object_layer,
            object_aggregation=self.object_aggregation,
            u=u.detach().clone(),
            s=s.detach().clone(),
            v=v.detach().clone(),
            bias=self.bias.detach().clone(),
            metadata=self.metadata,
        )

    def to_low_rank(self, rank: int) -> LowRankLre:
        """Create a low-rank approximation of this LRE"""
        u, s, v = self._low_rank_svd(rank)
        return LowRankLre(
            relation=self.relation,
            subject_layer=self.subject_layer,
            object_layer=self.object_layer,
            object_aggregation=self.object_aggregation,
            u=u.detach().clone(),
            s=s.detach().clone(),
            v=v.detach().clone(),
            bias=self.bias.detach().clone(),
            metadata=self.metadata,
        )

    @torch.no_grad()
    def _low_rank_svd(
        self, rank: int
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # use a float for the svd, then convert back to the original dtype
        u, s, v = torch.svd(self.weight.float())
        low_rank_u: torch.Tensor = u[:, :rank].to(self.weight.dtype)
        low_rank_v: torch.Tensor = v[:, :rank].to(self.weight.dtype)
        low_rank_s: torch.Tensor = s[:rank].to(self.weight.dtype)
        return low_rank_u, low_rank_s, low_rank_v
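A hedged end-to-end sketch of how these classes appear intended to fit together: build an Lre from a random weight, derive its low-rank and inverted forms, and map a subject activation forward and back. The import path, relation name, and layer indices are placeholders, and the round trip is only approximate because the rank is below the hidden size.

import torch
from linear_relational.Lre import Lre  # path assumed from this commit's other imports

hidden = 32
lre = Lre(
    relation="plays instrument",  # placeholder relation name
    subject_layer=10,
    object_layer=20,
    object_aggregation="mean",
    weight=torch.randn(hidden, hidden),
    bias=torch.randn(hidden),
)

low_rank = lre.to_low_rank(rank=8)
inverted = lre.invert(rank=8)

subject_acts = torch.randn(1, hidden)  # (num_activations, hidden_activation_size)
object_vec = low_rank.calculate_object_activation(subject_acts)
recovered = inverted.calculate_subject_activation(object_vec.unsqueeze(0), normalize=True)
print(object_vec.shape, recovered.shape)  # torch.Size([32]) torch.Size([32])
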
@@ -0,0 +1,12 @@
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True, slots=True)
class Prompt:
    text: str
    answer: str
    subject: str
    subject_name: Optional[str] = None
    object_name: Optional[str] = None
    relation_name: Optional[str] = None
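For illustration only: one way a Prompt might be instantiated; the text and names below are made up.

from linear_relational.Prompt import Prompt  # path taken from the imports later in this commit

prompt = Prompt(
    text="Paris is located in the country of",
    answer="France",
    subject="Paris",
    subject_name="Paris",
    object_name="France",
    relation_name="city in country",
)
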
@@ -0,0 +1,95 @@
import hashlib
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Optional

from dataclasses_json import DataClassJsonMixin
from tokenizers import Tokenizer
from torch import nn

from linear_relational.lib.torch_utils import guess_model_name
from linear_relational.lib.verify_answers_match_expected import (
    verify_answers_match_expected,
)
from linear_relational.Prompt import Prompt


@dataclass
class MatchingPromptsCache(DataClassJsonMixin):
    cache: dict[str, bool]
    model_name: str


class PromptValidator:
    """
    Helper class to filter prompts to those where the model's answer matches the expected answer.
    This class handles caching results to avoid repeating work between runs.
    """

    _cache: MatchingPromptsCache
    cache_file: str | Path | None
    model: nn.Module
    tokenizer: Tokenizer

    def __init__(
        self,
        model: nn.Module,
        tokenizer: Tokenizer,
        cache_file: Optional[str | Path] = None,
        load_saved_cache: bool = True,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        model_name = guess_model_name(model)
        self.cache_file = cache_file
        if cache_file and load_saved_cache and os.path.exists(cache_file):
            with open(cache_file, "r") as f:
                self._cache = MatchingPromptsCache.from_json(f.read())
            if self._cache.model_name != model_name:
                raise ValueError(
                    f"Cache file {cache_file} was generated by a different model"
                )
        else:
            self._cache = MatchingPromptsCache(cache={}, model_name=model_name)

    def write_cache(self, cache_file: Optional[str | Path] = None) -> None:
        _cache_file = cache_file or self.cache_file
        if _cache_file is None:
            raise ValueError("No cache file was provided")
        with open(_cache_file, "w") as f:
            f.write(self._cache.to_json())

    def _is_cached(self, prompt: Prompt) -> bool:
        key = cache_key(prompt.text, prompt.answer)
        return key in self._cache.cache

    def _prompt_matches(self, prompt: Prompt) -> bool:
        key = cache_key(prompt.text, prompt.answer)
        return self._cache.cache[key]

    def filter_prompts(
        self,
        prompts: Iterable[Prompt],
        batch_size: int = 8,
        show_progress: bool = False,
    ) -> list[Prompt]:
        # materialize the iterable so it can be traversed more than once below
        prompts = list(prompts)
        uncached_prompts = [prompt for prompt in prompts if not self._is_cached(prompt)]
        if len(uncached_prompts) > 0:
            answer_match_results = verify_answers_match_expected(
                model=self.model,
                tokenizer=self.tokenizer,
                prompts=[prompt.text for prompt in uncached_prompts],
                expected_answers=[prompt.answer for prompt in uncached_prompts],
                batch_size=batch_size,
                show_progress=show_progress,
            )
            for prompt, match_result in zip(uncached_prompts, answer_match_results):
                key = cache_key(prompt.text, prompt.answer)
                self._cache.cache[key] = match_result.answer_matches_expected
        return [prompt for prompt in prompts if self._prompt_matches(prompt)]


def cache_key(prompt_text: str, answer: str) -> str:
    # return an md5 hash of the prompt text and answer
    return hashlib.md5((prompt_text + answer).encode("utf-8")).hexdigest()[:15]
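A hedged usage sketch, assuming a Hugging Face causal LM and tokenizer are acceptable stand-ins for the nn.Module and Tokenizer the class is annotated with; the model name, cache path, and prompts are placeholders, not taken from this commit.

from transformers import AutoModelForCausalLM, AutoTokenizer

from linear_relational.Prompt import Prompt
from linear_relational.PromptValidator import PromptValidator  # module path assumed

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

validator = PromptValidator(model, tokenizer, cache_file="prompt_cache.json")
prompts = [
    Prompt(text="Paris is located in the country of", answer="France", subject="Paris"),
    Prompt(text="Tokyo is located in the country of", answer="Japan", subject="Tokyo"),
]
valid_prompts = validator.filter_prompts(prompts, batch_size=4)
validator.write_cache()  # persist results so a later run can skip the prompts checked above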