adding initial libs and classes
chanind committed Nov 16, 2023
1 parent 2ccf587 commit 2691e8f
Showing 30 changed files with 3,513 additions and 1 deletion.
26 changes: 26 additions & 0 deletions .github/workflows/ci.yaml
@@ -0,0 +1,26 @@
name: CI
on: [push]
jobs:
  lint_test_and_build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v3
        with:
          python-version: "3.11"
      - name: Install Poetry
        uses: snok/install-poetry@v1
        with:
          version: 1.4.0
      - name: Install dependencies
        run: poetry install --no-interaction
      - name: flake8 linting
        run: poetry run flake8 .
      - name: black code formatting
        run: poetry run black . --check
      - name: mypy type checking
        run: poetry run mypy .
      - name: pytest
        run: poetry run pytest
      - name: build
        run: poetry build
8 changes: 7 additions & 1 deletion .gitignore
@@ -14,7 +14,7 @@ dist/
downloads/
eggs/
.eggs/
-lib/
+# lib/
lib64/
parts/
sdist/
@@ -158,3 +158,9 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# misc
.vscode/
.DS_Store
*.pt
lightning_logs/
Empty file added README.md
213 changes: 213 additions & 0 deletions linear_relational/Lre.py
@@ -0,0 +1,213 @@
from typing import Any, Literal

import torch
from torch import nn


class InvertedLre(nn.Module):
    """Low-rank inverted LRE, used for calculating subject activations from object activations"""

    relation: str
    subject_layer: int
    object_layer: int
    # store u, v, s, and bias separately to avoid storing the full weight matrix
    u: nn.Parameter
    s: nn.Parameter
    v: nn.Parameter
    bias: nn.Parameter
    object_aggregation: Literal["mean", "first_token"]
    metadata: dict[str, Any] | None = None

    def __init__(
        self,
        relation: str,
        subject_layer: int,
        object_layer: int,
        object_aggregation: Literal["mean", "first_token"],
        u: torch.Tensor,
        s: torch.Tensor,
        v: torch.Tensor,
        bias: torch.Tensor,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()
        self.relation = relation
        self.subject_layer = subject_layer
        self.object_layer = object_layer
        self.object_aggregation = object_aggregation
        self.u = nn.Parameter(u, requires_grad=False)
        self.s = nn.Parameter(s, requires_grad=False)
        self.v = nn.Parameter(v, requires_grad=False)
        self.bias = nn.Parameter(bias, requires_grad=False)
        self.metadata = metadata

    @property
    def rank(self) -> int:
        return self.s.shape[0]

    def w_inv_times_vec(self, vec: torch.Tensor) -> torch.Tensor:
        # group u.T @ vec to avoid calculating larger matrices than needed
        return self.v @ torch.diag(1 / self.s) @ (self.u.T @ vec)

    def forward(
        self,
        object_activations: torch.Tensor,  # a tensor of shape (num_activations, hidden_activation_size)
        normalize: bool = False,
    ) -> torch.Tensor:
        # this class maps object activations back to subject activations,
        # so forward delegates to calculate_subject_activation
        return self.calculate_subject_activation(
            object_activations=object_activations,
            normalize=normalize,
        )

    def calculate_subject_activation(
        self,
        object_activations: torch.Tensor,  # a tensor of shape (num_activations, hidden_activation_size)
        normalize: bool = False,
    ) -> torch.Tensor:
        # subtract the bias, apply the inverted weight, then average over activations
        unbiased_acts = object_activations - self.bias.unsqueeze(0)
        vec = self.w_inv_times_vec(unbiased_acts.T).mean(dim=1)

        if normalize:
            vec = vec / vec.norm()
        return vec


class LowRankLre(nn.Module):
    """Low-rank approximation of an LRE"""

    relation: str
    subject_layer: int
    object_layer: int
    # store u, v, s, and bias separately to avoid storing the full weight matrix
    u: nn.Parameter
    s: nn.Parameter
    v: nn.Parameter
    bias: nn.Parameter
    object_aggregation: Literal["mean", "first_token"]
    metadata: dict[str, Any] | None = None

    def __init__(
        self,
        relation: str,
        subject_layer: int,
        object_layer: int,
        object_aggregation: Literal["mean", "first_token"],
        u: torch.Tensor,
        s: torch.Tensor,
        v: torch.Tensor,
        bias: torch.Tensor,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()
        self.relation = relation
        self.subject_layer = subject_layer
        self.object_layer = object_layer
        self.object_aggregation = object_aggregation
        self.u = nn.Parameter(u, requires_grad=False)
        self.s = nn.Parameter(s, requires_grad=False)
        self.v = nn.Parameter(v, requires_grad=False)
        self.bias = nn.Parameter(bias, requires_grad=False)
        self.metadata = metadata

    @property
    def rank(self) -> int:
        return self.s.shape[0]

    def w_times_vec(self, vec: torch.Tensor) -> torch.Tensor:
        # group v.T @ vec to avoid calculating larger matrices than needed
        return self.u @ torch.diag(self.s) @ (self.v.T @ vec)

    def forward(
        self,
        subject_activations: torch.Tensor,  # a tensor of shape (num_activations, hidden_activation_size)
        normalize: bool = False,
    ) -> torch.Tensor:
        return self.calculate_object_activation(
            subject_activations=subject_activations,
            normalize=normalize,
        )

    def calculate_object_activation(
        self,
        subject_activations: torch.Tensor,  # a tensor of shape (num_activations, hidden_activation_size)
        normalize: bool = False,
    ) -> torch.Tensor:
        # apply the low-rank weight, add the bias, then average over activations
        ws = self.w_times_vec(subject_activations.T)
        vec = (ws + self.bias.unsqueeze(-1)).mean(dim=1)
        if normalize:
            vec = vec / vec.norm()
        return vec


class Lre(nn.Module):
    """Linear Relational Embedding"""

    relation: str
    subject_layer: int
    object_layer: int
    weight: nn.Parameter
    bias: nn.Parameter
    object_aggregation: Literal["mean", "first_token"]
    metadata: dict[str, Any] | None = None

    def __init__(
        self,
        relation: str,
        subject_layer: int,
        object_layer: int,
        object_aggregation: Literal["mean", "first_token"],
        weight: torch.Tensor,
        bias: torch.Tensor,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()
        self.relation = relation
        self.subject_layer = subject_layer
        self.object_layer = object_layer
        self.object_aggregation = object_aggregation
        self.weight = nn.Parameter(weight, requires_grad=False)
        self.bias = nn.Parameter(bias, requires_grad=False)
        self.metadata = metadata

    def invert(self, rank: int) -> InvertedLre:
        """Invert this LRE using a low-rank approximation"""
        u, s, v = self._low_rank_svd(rank)
        return InvertedLre(
            relation=self.relation,
            subject_layer=self.subject_layer,
            object_layer=self.object_layer,
            object_aggregation=self.object_aggregation,
            u=u.detach().clone(),
            s=s.detach().clone(),
            v=v.detach().clone(),
            bias=self.bias.detach().clone(),
            metadata=self.metadata,
        )

    def to_low_rank(self, rank: int) -> LowRankLre:
        """Create a low-rank approximation of this LRE"""
        u, s, v = self._low_rank_svd(rank)
        return LowRankLre(
            relation=self.relation,
            subject_layer=self.subject_layer,
            object_layer=self.object_layer,
            object_aggregation=self.object_aggregation,
            u=u.detach().clone(),
            s=s.detach().clone(),
            v=v.detach().clone(),
            bias=self.bias.detach().clone(),
            metadata=self.metadata,
        )

    @torch.no_grad()
    def _low_rank_svd(
        self, rank: int
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # use a float for the svd, then convert back to the original dtype
        u, s, v = torch.svd(self.weight.float())
        low_rank_u: torch.Tensor = u[:, :rank].to(self.weight.dtype)
        low_rank_v: torch.Tensor = v[:, :rank].to(self.weight.dtype)
        low_rank_s: torch.Tensor = s[:rank].to(self.weight.dtype)
        return low_rank_u, low_rank_s, low_rank_v
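
For context on the math these three classes share: an LRE models a relation as an affine map o ≈ W s + b from subject activations to object activations, and invert applies the low-rank pseudo-inverse W⁺ = V Σ⁻¹ Uᵀ from the truncated SVD, which is what w_inv_times_vec computes. Below is a minimal usage sketch; the hidden size, layer indices, rank, relation name, and random tensors are illustrative assumptions, not values from this commit.

import torch

hidden_size = 768  # hypothetical hidden size, e.g. GPT-2 scale

lre = Lre(
    relation="city in country",  # illustrative relation name
    subject_layer=8,
    object_layer=19,
    object_aggregation="mean",
    weight=torch.randn(hidden_size, hidden_size),
    bias=torch.randn(hidden_size),
)

low_rank_lre = lre.to_low_rank(rank=50)  # maps subject -> object activations
inv_lre = lre.invert(rank=50)  # maps object -> subject activations

subject_acts = torch.randn(3, hidden_size)  # 3 sample subject activations
object_act = low_rank_lre(subject_acts, normalize=True)  # shape: (hidden_size,)
subject_vec = inv_lre(object_act.unsqueeze(0), normalize=True)  # shape: (hidden_size,)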
12 changes: 12 additions & 0 deletions linear_relational/Prompt.py
@@ -0,0 +1,12 @@
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True, slots=True)
class Prompt:
    text: str
    answer: str
    subject: str
    subject_name: Optional[str] = None
    object_name: Optional[str] = None
    relation_name: Optional[str] = None
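
A quick illustration of how these fields fit together (all values here are hypothetical):

prompt = Prompt(
    text="Paris is located in the country of",
    answer="France",
    subject="Paris",
    relation_name="city in country",
)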
95 changes: 95 additions & 0 deletions linear_relational/PromptValidator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import hashlib
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Optional

from dataclasses_json import DataClassJsonMixin
from tokenizers import Tokenizer
from torch import nn

from linear_relational.lib.torch_utils import guess_model_name
from linear_relational.lib.verify_answers_match_expected import (
    verify_answers_match_expected,
)
from linear_relational.Prompt import Prompt


@dataclass
class MatchingPromptsCache(DataClassJsonMixin):
    cache: dict[str, bool]
    model_name: str


class PromptValidator:
    """
    Helper class to filter prompts to those whose expected answer the model
    actually predicts. Results are cached to avoid repeating work between runs.
    """

    _cache: MatchingPromptsCache
    cache_file: str | Path | None
    model: nn.Module
    tokenizer: Tokenizer

    def __init__(
        self,
        model: nn.Module,
        tokenizer: Tokenizer,
        cache_file: Optional[str | Path] = None,
        load_saved_cache: bool = True,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        model_name = guess_model_name(model)
        self.cache_file = cache_file
        if cache_file and load_saved_cache and os.path.exists(cache_file):
            with open(cache_file, "r") as f:
                self._cache = MatchingPromptsCache.from_json(f.read())
            if self._cache.model_name != model_name:
                raise ValueError(
                    f"Cache file {cache_file} was generated by a different model"
                )
        else:
            self._cache = MatchingPromptsCache(cache={}, model_name=model_name)

    def write_cache(self, cache_file: Optional[str | Path] = None) -> None:
        _cache_file = cache_file or self.cache_file
        if _cache_file is None:
            raise ValueError("No cache file was provided")
        with open(_cache_file, "w") as f:
            f.write(self._cache.to_json())

    def _is_cached(self, prompt: Prompt) -> bool:
        key = cache_key(prompt.text, prompt.answer)
        return key in self._cache.cache

    def _prompt_matches(self, prompt: Prompt) -> bool:
        key = cache_key(prompt.text, prompt.answer)
        return self._cache.cache[key]

    def filter_prompts(
        self,
        prompts: Iterable[Prompt],
        batch_size: int = 8,
        show_progress: bool = False,
    ) -> list[Prompt]:
        # materialize the iterable first, since it is iterated over twice below
        prompts = list(prompts)
        uncached_prompts = [prompt for prompt in prompts if not self._is_cached(prompt)]
        if len(uncached_prompts) > 0:
            answer_match_results = verify_answers_match_expected(
                model=self.model,
                tokenizer=self.tokenizer,
                prompts=[prompt.text for prompt in uncached_prompts],
                expected_answers=[prompt.answer for prompt in uncached_prompts],
                batch_size=batch_size,
                show_progress=show_progress,
            )
            for prompt, match_result in zip(uncached_prompts, answer_match_results):
                key = cache_key(prompt.text, prompt.answer)
                self._cache.cache[key] = match_result.answer_matches_expected
        return [prompt for prompt in prompts if self._prompt_matches(prompt)]


def cache_key(prompt_text: str, answer: str) -> str:
    # return a truncated md5 hash of the prompt text and answer
    return hashlib.md5((prompt_text + answer).encode("utf-8")).hexdigest()[:15]
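
A hedged usage sketch of the validator: loading a model via Hugging Face transformers is an assumption here (this commit pins no specific model), and the model choice and cache file name are placeholders.

from transformers import AutoModelForCausalLM, AutoTokenizer

from linear_relational.Prompt import Prompt
from linear_relational.PromptValidator import PromptValidator

model = AutoModelForCausalLM.from_pretrained("gpt2")  # hypothetical model choice
tokenizer = AutoTokenizer.from_pretrained("gpt2")

validator = PromptValidator(model, tokenizer, cache_file="prompt_cache.json")
prompts = [
    Prompt(text="Paris is located in the country of", answer="France", subject="Paris"),
    Prompt(text="Tokyo is located in the country of", answer="Japan", subject="Tokyo"),
]
valid_prompts = validator.filter_prompts(prompts, batch_size=8)
validator.write_cache()  # persist match results for future runs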
Empty file added linear_relational/__init__.py