Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversion improved test coverage #886

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,33 +28,33 @@
)
def test_arithmetic_operations(operation, value, input_tensor, expected_output):
conversion = ArithmeticWeightConversion(operation, value)
output = conversion.handle_conversion(input_tensor)
output = conversion.convert(input_tensor)
assert torch.allclose(output, expected_output), f"Expected {expected_output}, but got {output}"


def test_scalar_operations():
conversion = ArithmeticWeightConversion(OperationTypes.MULTIPLICATION, 10)
assert conversion.handle_conversion(5) == 50
assert conversion.convert(5) == 50


def test_tensor_operations():
input_tensor = torch.tensor([1.0, 2.0, 3.0])
conversion = ArithmeticWeightConversion(OperationTypes.ADDITION, torch.tensor([1.0, 1.0, 1.0]))
expected_output = torch.tensor([2.0, 3.0, 4.0])
assert torch.allclose(conversion.handle_conversion(input_tensor), expected_output)
assert torch.allclose(conversion.convert(input_tensor), expected_output)


def test_input_filter():
def input_filter(x):
return x * 2 # Double the input before applying the operation

conversion = ArithmeticWeightConversion(OperationTypes.ADDITION, 3, input_filter=input_filter)
assert conversion.handle_conversion(torch.tensor(2.0)) == 7.0 # (2 * 2) + 3 = 7
assert conversion.convert(torch.tensor(2.0)) == 7.0 # (2 * 2) + 3 = 7


def test_output_filter():
def output_filter(x):
return x / 2 # Halve the result after applying the operation

conversion = ArithmeticWeightConversion(OperationTypes.ADDITION, 3, output_filter=output_filter)
assert conversion.handle_conversion(torch.tensor(2.0)) == 2.5 # (2 + 3) / 2 = 2.5
assert conversion.convert(torch.tensor(2.0)) == 2.5 # (2 + 3) / 2 = 2.5
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,44 @@ def handle_conversion(self, weight):
return weight + 5


def test_process_weight_conversion_applies_conversion():
weight_conversion = MockWeightConversion()
weight = torch.zeros(2, 2)

converted_weight = weight_conversion.process_weight_conversion(weight)

expected = torch.tensor((2, 2)) + 5

assert torch.all(converted_weight == expected)


def test_process_weight_conversion_applies_filters():
def mock_input_filter(weight):
return weight * 2

def mock_output_filter(weight):
return weight - 1

weight_conversion = MockWeightConversion(
input_filter=mock_input_filter, output_filter=mock_output_filter
)

weight = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

converted_weight = weight_conversion.process_weight_conversion(weight)

expected_weight = (weight * 2) + 5 - 1
assert torch.allclose(converted_weight, expected_weight)


def test_base_weight_conversion_convert_throws_error():
    """
    Calling convert() on a bare BaseWeightConversion is expected to raise
    NotImplementedError.
    """
    base_conversion = BaseWeightConversion()
    dummy_weight = torch.zeros(1, 4)
    with pytest.raises(NotImplementedError):
        base_conversion.convert(dummy_weight)

def test_mock_weight_conversion_adds_five():
    """
    Check that MockWeightConversion.convert() turns a zero tensor into a
    tensor holding 5.0 in every position.
    """
    shape = (1, 4)
    zeros = torch.zeros(shape, dtype=torch.float32)
    want = torch.full(shape, 5.0, dtype=torch.float32)

    got = MockWeightConversion().convert(zeros)

    # Exact comparison: 0.0 + 5 involves no rounding, so no tolerance is needed.
    assert torch.equal(got, want)
    # A tolerance-based alternative would be torch.testing.assert_close(got, want).

@pytest.mark.parametrize("shape", [(1, 4), (2, 2), (3,)])
def test_mock_weight_conversion_various_shapes(shape):
    """
    Run .convert() on zero tensors of several shapes and expect a tensor of
    the same shape filled with 5.0 each time.
    """
    converted = MockWeightConversion().convert(torch.zeros(shape))
    assert torch.equal(converted, torch.full(shape, 5.0))

def test_mock_weight_conversion_empty_tensor():
    """
    Converting a tensor with zero rows should not raise and should keep
    the (0, 4) shape.
    """
    empty_input = torch.zeros((0, 4))
    result = MockWeightConversion().convert(empty_input)
    # There are no elements to compare, so the shape is the only thing to verify.
    assert result.shape == (0, 4)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from transformer_lens.weight_conversion.conversion_utils.model_search import (
from transformer_lens.weight_conversion.conversion_utils.conversion_helpers import (
find_property,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from torch import nn

from .conversion_steps.base_weight_conversion import FIELD_SET
from .conversion_steps.types import FIELD_SET
from .conversion_steps.weight_conversion_set import WeightConversionSet


Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@


def find_property(needle: str, haystack):
needle_levels = needle.split(".")
first_key = needle_levels.pop(0)
Expand All @@ -8,3 +10,4 @@ def find_property(needle: str, haystack):
return find_property(".".join(needle_levels), current_level)

return current_level

Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .base_weight_conversion import BaseWeightConversion, FIELD_SET
from .types import CONVERSION_ACTION, CONVERSION, FIELD_SET
from .base_weight_conversion import BaseWeightConversion
from .callable_weight_conversion import CallableWeightConversion
from .arithmetic_weight_conversion import ArithmeticWeightConversion, OperationTypes
from .rearrange_weight_conversion import RearrangeWeightConversion
from .repeat_weight_conversion import RepeatWeightConversion
from .weight_conversion_set import WeightConversionSet
from .ternary_weight_conversion import TernaryWeightConversion
from .zeros_like_conversion import ZerosLikeConversion
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from .base_weight_conversion import BaseWeightConversion
from transformer_lens.weight_conversion.conversion_utils.conversion_steps.base_weight_conversion import BaseWeightConversion


class OperationTypes(Enum):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
from collections.abc import Callable
from typing import Optional

import torch

from transformer_lens.weight_conversion.conversion_utils.model_search import (
find_property,
)

CONVERSION = tuple[str, "BaseWeightConversion"]
CONVERSION_ACTION = torch.Tensor | str | CONVERSION
FIELD_SET = dict[CONVERSION_ACTION]


class BaseWeightConversion:
def __init__(
Expand All @@ -26,22 +16,6 @@ def convert(self, input_value):
output = self.handle_conversion(input_value)
return self.output_filter(output) if self.output_filter is not None else output

def process_weight_conversion(self, input_value, conversion_details: CONVERSION_ACTION):
if isinstance(conversion_details, torch.Tensor):
return conversion_details
elif isinstance(conversion_details, str):
return find_property(conversion_details, input_value)
else:
(remote_field, conversion) = conversion_details
weight = find_property(remote_field, input_value)
if isinstance(conversion, "WeightConversionSet"):
result = []
for layer in weight:
result.append(conversion.convert(layer))
return result

else:
return conversion.convert(weight)

def handle_conversion(self, input_value):
raise NotImplementedError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import torch

from .base_weight_conversion import CONVERSION_ACTION, BaseWeightConversion
from .types import CONVERSION_ACTION
from .base_weight_conversion import BaseWeightConversion

PRIMARY_CONVERSION = torch.Tensor | BaseWeightConversion | None

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch

from .base_weight_conversion import BaseWeightConversion

# This type is used to indicate the position of a field in the remote model
REMOTE_FIELD = str
# This is the typing for a weight conversion when operations are needed on the REMOTE_FIELD.
# The BaseWeightConversion will be the instructions on the operations needed to bring the field
# into TransformerLens
CONVERSION = tuple[REMOTE_FIELD, BaseWeightConversion]
# This is the full range of actions that can be taken to bring a field into TransformerLens
# These can be configured as a predefined tensor, or a direct copy of the REMOTE_FIELD into
# TransformerLens, or a more in-depth CONVERSION
CONVERSION_ACTION = torch.Tensor | REMOTE_FIELD | CONVERSION
# This type is for a full set of conversions from a remote model into TransformerLens. Each key in
# this dictionary will correspond to a field within a TransformerLens module, and each
# CONVERSION_ACTION will instruct TransformerLens on how to bring the field into a
# TransformerLens model. This type is repeated in both the root level of a model, as well as any
# layers within the model
FIELD_SET = dict[str, CONVERSION_ACTION]
Original file line number Diff line number Diff line change
@@ -1,29 +1,49 @@
from collections.abc import Callable
from typing import Optional
import torch

from .base_weight_conversion import FIELD_SET, BaseWeightConversion
from .types import FIELD_SET, CONVERSION_ACTION, CONVERSION
from transformer_lens.weight_conversion.conversion_utils.conversion_helpers import find_property
from .base_weight_conversion import BaseWeightConversion
from transformer_lens.weight_conversion.conversion_utils.weight_conversion_utils import WeightConversionUtils


class WeightConversionSet(BaseWeightConversion):
def __init__(
self,
weights: FIELD_SET,
input_filter: Optional[Callable] = None,
output_filter: Optional[Callable] = None,
):
super().__init__(input_filter=input_filter, output_filter=output_filter)
super().__init__()
self.weights = weights

def handle_conversion(self, input_value):
result = {}
for weight_name in self.weights:
result[weight_name] = super().process_weight_conversion(
result[weight_name] = self.process_conversion_action(
input_value,
conversion_details=self.weights[weight_name],
)

return result

def process_conversion_action(self, input_value, conversion_details: CONVERSION_ACTION):
    """Resolve a single CONVERSION_ACTION against ``input_value``.

    Args:
        input_value: Object the remote field paths are looked up on.
        conversion_details: One of
            * a ``torch.Tensor`` - returned as-is, ``input_value`` is not consulted;
            * a ``str`` - treated as a field path and fetched with ``find_property``;
            * anything else - unpacked as a ``(remote_field, conversion)`` pair and
              handed to ``process_conversion``.

    Returns:
        The tensor, the looked-up property, or the result of ``process_conversion``.
    """
    # Predefined tensors need no lookup at all.
    if isinstance(conversion_details, torch.Tensor):
        return conversion_details

    # A bare string is a direct copy of the named remote field.
    if isinstance(conversion_details, str):
        return find_property(conversion_details, input_value)

    remote_field, conversion = conversion_details
    return self.process_conversion(input_value, remote_field, conversion)

def process_conversion(self, input_value, remote_field: str, conversion: BaseWeightConversion):
    """Fetch ``remote_field`` from ``input_value`` and run ``conversion`` on it.

    Args:
        input_value: Object the remote field path is looked up on.
        remote_field: Field path passed to ``find_property``.
        conversion: The converter to apply. This is the second element of a
            CONVERSION tuple, i.e. the converter object itself, not the tuple.

    Returns:
        A list with one converted entry per item of the looked-up field when
        ``conversion`` is a nested ``WeightConversionSet``; otherwise the single
        result of ``conversion.convert`` on the looked-up field.
    """
    field = find_property(remote_field, input_value)

    # The branch is selected by the kind of *conversion*, not by the looked-up
    # field: a nested WeightConversionSet is applied to each item of the field
    # separately (presumably a sequence of layers - confirm against callers).
    if isinstance(conversion, WeightConversionSet):
        return [conversion.convert(layer) for layer in field]

    return conversion.convert(field)

def __repr__(self):
conversion_string = (
"Is composed of a set of nested conversions with the following details {\n\t"
Expand Down

This file was deleted.

Loading