Skip to content

Commit 04c63b4

Browse files
committed
Merge branch 'dev-3.x' into refactor-weight-conversion
2 parents 1e1ac75 + 9e63d2c commit 04c63b4

File tree

8 files changed

+587
-653
lines changed

8 files changed

+587
-653
lines changed

docs/make_docs.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import pandas as pd
99

10-
from transformer_lens import loading
10+
from transformer_lens import loading, supported_models
1111

1212
# Docs Directories
1313
CURRENT_DIR = Path(__file__).parent
@@ -94,10 +94,13 @@ def generate_model_table(_app: Optional[Any] = None):
9494
]
9595
df = pd.DataFrame(
9696
{
97-
name: [get_property(name, model_name) for model_name in loading.DEFAULT_MODEL_ALIASES]
97+
name: [
98+
get_property(name, model_name)
99+
for model_name in supported_models.DEFAULT_MODEL_ALIASES
100+
]
98101
for name in column_names
99102
},
100-
index=loading.DEFAULT_MODEL_ALIASES,
103+
index=supported_models.DEFAULT_MODEL_ALIASES,
101104
)
102105

103106
# Convert to markdown (with a title)

tests/unit/test_supported_models.py

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from transformer_lens.supported_models import MODEL_ALIASES, OFFICIAL_MODEL_NAMES
2+
3+
4+
def test_official_model_names_is_alphabetical():
5+
assert OFFICIAL_MODEL_NAMES == sorted(
6+
OFFICIAL_MODEL_NAMES, key=str.casefold
7+
), "OFFICIAL_MODEL_NAMES are not alphabetical"
8+
9+
10+
def test_model_aliases_is_alphabetical():
11+
# Extract the keys as they appear in the dictionary
12+
actual_keys = list(MODEL_ALIASES.keys())
13+
14+
# Create a sorted version, ignoring case
15+
expected_keys = sorted(actual_keys, key=str.casefold)
16+
17+
# Compare the actual insertion order to the expected alphabetical order
18+
assert actual_keys == expected_keys, "MODEL_ALIASES keys are not in alphabetical order. "

transformer_lens/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from . import supported_models
12
from . import utilities
23
from . import hook_points
34
from . import evals

0 commit comments

Comments
 (0)