Commit 401cc78

Bug new loading (#891)
* tested gpt2 loading
* fixed gemma 1 bugs
* ran format
Parent: fb12aff

4 files changed (+6 / -6 lines)

transformer_lens/weight_conversion/conversion_utils/conversion_steps/rearrange_weight_conversion.py

```diff
@@ -19,7 +19,7 @@ def __init__(
         self.pattern = pattern
         self.axes_lengths = axes_lengths
 
-    def handle_conversion(self, input_value: torch.Tensor) -> torch.Tensor:
+    def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
         return einops.rearrange(input_value, self.pattern, **self.axes_lengths)
 
     def __repr__(self):
```
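The only change here is the signature: `handle_conversion` now takes a trailing `*full_context`, so every conversion step can be called with the same arguments even when the extra context goes unused. A minimal, runnable sketch of the idea (the class below is a simplified stand-in, not the repo's actual module):

```python
import einops
import torch

# Simplified stand-in for the repo's rearrange step: trailing positional
# context is accepted but simply ignored, since rearrange only needs the tensor.
class RearrangeWeightConversion:
    def __init__(self, pattern: str, **axes_lengths):
        self.pattern = pattern
        self.axes_lengths = axes_lengths

    def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
        return einops.rearrange(input_value, self.pattern, **self.axes_lengths)

step = RearrangeWeightConversion("(h d) m -> h d m", h=8)
out = step.handle_conversion(torch.randn(64, 16), "extra", "context")  # extras are ignored
print(out.shape)  # torch.Size([8, 8, 16])
```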

transformer_lens/weight_conversion/conversion_utils/conversion_steps/repeat_weight_conversion.py

```diff
@@ -19,7 +19,7 @@ def __init__(
         self.pattern = pattern
         self.axes_lengths = axes_lengths
 
-    def handle_conversion(self, input_value: torch.Tensor) -> torch.Tensor:
+    def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
         return einops.repeat(input_value, self.pattern, **self.axes_lengths)
 
     def __repr__(self):
```
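The repeat step receives the identical signature change. For context, a quick illustration (not taken from the PR) of what the underlying `einops.repeat` call does with a pattern plus `axes_lengths`:

```python
import einops
import torch

# einops.repeat replicates data along axes named only on the right-hand side
# of the pattern; the new axis length comes from the keyword argument.
w = torch.randn(4, 16)
out = einops.repeat(w, "h m -> h r m", r=3)  # copy each row 3 times along a new axis
print(out.shape)  # torch.Size([4, 3, 16])
```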

transformer_lens/weight_conversion/conversion_utils/conversion_steps/weight_conversion_set.py

```diff
@@ -45,7 +45,7 @@ def process_conversion(
         self, input_value, remote_field: str, conversion: BaseWeightConversion, *full_context
     ):
         field = find_property(remote_field, input_value)
-        if isinstance(field, WeightConversionSet):
+        if isinstance(conversion, WeightConversionSet):
             result = []
             for layer in field:
                 result.append(conversion.convert(layer, input_value, *full_context))
```
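This hunk is the substantive bug fix: the `isinstance` check inspected `field` (the value fetched from the source model) rather than `conversion` (the step object), so a nested `WeightConversionSet` was never dispatched layer by layer. A hedged, self-contained sketch of the corrected dispatch, with simplified stand-ins for the repo's classes:

```python
# The classes and find_property below are simplified stand-ins for
# TransformerLens internals, kept only to make the control flow runnable.
class WeightConversionSet:
    def convert(self, layer, root, *ctx):
        return f"converted<{layer}>"

class DoublingConversion:  # hypothetical leaf step
    def handle_conversion(self, value, *ctx):
        return value * 2

def find_property(path, obj):
    for part in path.split("."):  # walk a dotted path on nested dicts
        obj = obj[part]
    return obj

def process_conversion(input_value, remote_field, conversion, *full_context):
    field = find_property(remote_field, input_value)
    # The fix: inspect the conversion *step*, not the fetched value. A tensor
    # pulled from the source model is never a WeightConversionSet, so the old
    # check silently skipped per-layer dispatch.
    if isinstance(conversion, WeightConversionSet):
        return [conversion.convert(layer, input_value, *full_context) for layer in field]
    return conversion.handle_conversion(field, *full_context)

model = {"model": {"layers": ["L0", "L1"]}}
print(process_conversion(model, "model.layers", WeightConversionSet()))
# ['converted<L0>', 'converted<L1>']
```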

transformer_lens/weight_conversion/gemma.py

```diff
@@ -13,7 +13,7 @@
 
 
 class GemmaWeightNormalizationConversion(BaseWeightConversion):
-    def convert(self, input_value):
+    def convert(self, input_value, *full_context):
         return input_value.float() + torch.ones_like(input_value, dtype=torch.float32)
 
     def __repr__(self):
@@ -24,7 +24,7 @@ class GemmaWeightConversion(ArchitectureConversion):
     def __init__(self, cfg: HookedTransformerConfig) -> None:
         super().__init__(
             {
-                "unembed.W_U": "model.lm_head.weight.T",
+                "unembed.W_U": "lm_head.weight.T",
                 "unembed.b_U": torch.zeros(cfg.d_vocab),
                 "ln_final.w": (
                     "model.norm.weight",
@@ -100,4 +100,4 @@ def normalization_before_and_after_conversions(self) -> FIELD_SET:
         }
 
     def standard_normalization_conversions(self) -> FIELD_SET:
-        return {"ln2.w": ("pre_feedforward_layernorm.weight", GemmaWeightNormalizationConversion())}
+        return {"ln2.w": ("post_attention_layernorm.weight", GemmaWeightNormalizationConversion())}
```
