
Commit 2fc93cb
improve comments
catwell committed Aug 26, 2024
1 parent 10dfa73 · commit 2fc93cb
Showing 2 changed files with 3 additions and 1 deletion.
src/refiners/foundationals/swin/mvanet/converter.py (2 changes: 1 addition & 1 deletion)
@@ -11,7 +11,7 @@ def convert_weights(official_state_dict: dict[str, Tensor]) -> dict[str, Tensor]
r"multifieldcrossatt.attention.5",
r"dec_blk\d+\.linear[12]",
r"dec_blk[1234]\.attention\.[4567]",
# We don't need the sideout weights
# We don't need the sideout weights for inference
r"sideout\d+",
]
state_dict = {k: v for k, v in official_state_dict.items() if not any(re.match(rm, k) for rm in rm_list)}
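For context, here is a minimal, self-contained sketch of the filtering idiom this hunk touches: keys whose names match any pattern in `rm_list` are dropped before conversion. The dummy keys and string values are made-up stand-ins for the real MVANet weights, and the shortened `rm_list` keeps only the pattern the comment refers to.

```python
import re

# Sketch of the state-dict filtering in converter.py. Only the sideout pattern
# is kept here; the keys and values below are illustrative assumptions.
rm_list = [
    r"sideout\d+",  # sideout heads are training-time supervision only
]

official_state_dict = {
    "sideout1.conv.weight": "tensor A",   # dropped: matches r"sideout\d+"
    "dec_blk1.conv.weight": "tensor B",   # kept: matches no pattern
}

# Same comprehension as in converter.py: re.match anchors at the key's start.
state_dict = {k: v for k, v in official_state_dict.items() if not any(re.match(rm, k) for rm in rm_list)}

print(sorted(state_dict))  # ['dec_blk1.conv.weight']
```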
src/refiners/foundationals/swin/mvanet/mclm.py (2 changes: 2 additions & 0 deletions)
@@ -132,6 +132,8 @@ def __init__(
         positional_embedding = PositionEmbeddingSine(num_pos_feats=emb_dim // 2, device=device)
 
         # LayerNorms in MCLM share their weights.
+        # We use the `proxy` trick below so they can be present only
+        # once in the tree but called in two different places.
 
         ln1 = fl.LayerNorm(emb_dim, device=device)
         ln2 = fl.LayerNorm(emb_dim, device=device)
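For readers unfamiliar with the trick the new comment refers to, here is a hedged, framework-free sketch of the general idea, not refiners' actual `proxy` implementation: the module that owns the weights appears once in the tree, and a lightweight proxy forwards calls to it from the second call site, so nothing is registered twice. `LayerNormStub` and `Proxy` are hypothetical stand-ins for `fl.LayerNorm` and the real proxy mechanism.

```python
# Hypothetical stand-ins illustrating weight sharing via a proxy; this is not
# the fl.LayerNorm / proxy API used in refiners.
class LayerNormStub:
    """Owns the (shared) parameters; registered once in the module tree."""

    def __init__(self, dim: int) -> None:
        self.dim = dim
        self.calls = 0

    def __call__(self, x: list[float]) -> list[float]:
        self.calls += 1
        return x  # actual normalization omitted in this sketch


class Proxy:
    """Forwards calls to its target without owning any parameters itself."""

    def __init__(self, target: LayerNormStub) -> None:
        self.target = target

    def __call__(self, x: list[float]) -> list[float]:
        return self.target(x)


ln1 = LayerNormStub(dim=64)  # present once in the tree
ln1_again = Proxy(ln1)       # second call site, no duplicated weights

ln1([0.0] * 64)
ln1_again([0.0] * 64)
print(ln1.calls)  # 2: both call sites hit the same instance
```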
