Add Multi-View Aggregation Network (MVANet)
Co-authored-by: Pierre Colle <[email protected]>
catwell and piercus committed Aug 26, 2024
1 parent 58c1cc7 commit 10dfa73
Showing 19 changed files with 1,525 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docs/reference/SUMMARY.md
@@ -9,4 +9,4 @@
* [<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> DINOv2](foundationals/dinov2.md)
* [<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> Latent Diffusion](foundationals/latent_diffusion.md)
* [<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> Segment Anything](foundationals/segment_anything.md)

* [<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> Swin Transformers](foundationals/swin.md)
2 changes: 2 additions & 0 deletions docs/reference/foundationals/swin.md
@@ -0,0 +1,2 @@
::: refiners.foundationals.swin.swin_transformer
::: refiners.foundationals.swin.mvanet
1 change: 1 addition & 0 deletions pyproject.toml
@@ -58,6 +58,7 @@ conversion = [
"segment-anything-py>=1.0",
"requests>=2.26.0",
"tqdm>=4.62.3",
"gdown>=5.2.0",
]
doc = [
# required by mkdocs to format the signatures
11 changes: 11 additions & 0 deletions requirements.lock
@@ -31,6 +31,8 @@ babel==2.15.0
    # via mkdocs-material
backports-strenum==1.3.1
    # via griffe
beautifulsoup4==4.12.3
    # via gdown
bitsandbytes==0.43.3
    # via refiners
black==24.4.2
@@ -70,6 +72,7 @@ docker-pycreds==0.4.0
filelock==3.15.4
    # via datasets
    # via diffusers
    # via gdown
    # via huggingface-hub
    # via torch
    # via transformers
@@ -85,6 +88,8 @@ fsspec==2024.5.0
    # via torch
future==1.0.0
    # via neptune
gdown==5.2.0
    # via refiners
ghp-import==2.1.0
    # via mkdocs
gitdb==4.0.11
@@ -274,6 +279,8 @@ pyjwt==2.9.0
pymdown-extensions==10.9
    # via mkdocs-material
    # via mkdocstrings
pysocks==1.7.1
    # via requests
python-dateutil==2.9.0.post0
    # via arrow
    # via botocore
@@ -311,6 +318,7 @@ requests==2.32.3
    # via bravado-core
    # via datasets
    # via diffusers
    # via gdown
    # via huggingface-hub
    # via mkdocs-material
    # via neptune
@@ -356,6 +364,8 @@ six==1.16.0
    # via rfc3339-validator
smmap==5.0.1
    # via gitdb
soupsieve==2.6
    # via beautifulsoup4
swagger-spec-validator==3.0.4
    # via bravado-core
    # via neptune
@@ -383,6 +393,7 @@ torchvision==0.19.0
    # via timm
tqdm==4.66.4
    # via datasets
    # via gdown
    # via huggingface-hub
    # via refiners
    # via transformers
40 changes: 40 additions & 0 deletions scripts/conversion/convert_mvanet.py
@@ -0,0 +1,40 @@
import argparse
from pathlib import Path

from refiners.fluxion.utils import load_tensors, save_to_safetensors
from refiners.foundationals.swin.mvanet.converter import convert_weights


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--from",
        type=str,
        required=True,
        dest="source_path",
        help="An MVANet checkpoint. One can be found at https://github.com/qianyu-dlut/MVANet",
    )
    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        default=None,
        help=(
            "Path to save the converted model. If not specified, the output path will be the source path with the"
            " extension changed to .safetensors."
        ),
    )
    parser.add_argument("--half", action="store_true", dest="half")
    args = parser.parse_args()

    src_weights = load_tensors(args.source_path)
    weights = convert_weights(src_weights)
    if args.half:
        weights = {key: value.half() for key, value in weights.items()}
    if args.output_path is None:
        args.output_path = f"{Path(args.source_path).stem}.safetensors"
    save_to_safetensors(path=args.output_path, tensors=weights)


if __name__ == "__main__":
    main()
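
Once converted, the checkpoint can presumably be loaded into the MVANet model this commit introduces. A minimal sketch, assuming the default MVANet() constructor matches the converted weights (load_from_safetensors is an existing Refiners utility):

from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.swin.mvanet import MVANet

model = MVANet()  # assumption: default arguments fit the converted checkpoint
model.load_state_dict(load_from_safetensors("mvanet.safetensors"))
model.eval()  # standard PyTorch inference mode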
32 changes: 32 additions & 0 deletions scripts/prepare_test_weights.py
@@ -11,6 +11,7 @@
import sys
from urllib.parse import urlparse

import gdown
import requests
from tqdm import tqdm

@@ -446,6 +447,25 @@ def download_ic_light():
    )


def download_mvanet():
    fn = "Model_80.pth"
    dest_folder = os.path.join(test_weights_dir, "mvanet")
    dest_filename = os.path.join(dest_folder, fn)

    if os.environ.get("DRY_RUN") == "1":
        return

    if os.path.exists(dest_filename):
        print(f"✖️ Skipping previously downloaded mvanet/{fn}")
    else:
        os.makedirs(dest_folder, exist_ok=True)
        print(f"🔽 Downloading mvanet/{fn} => '{rel(dest_filename)}'", end="\n")
        gdown.download(id="1_gabQXOF03MfXnf3EWDK1d_8wKiOemOv", output=dest_filename, quiet=True)
        print(f"{previous_line}✅ Downloaded mvanet/{fn} => '{rel(dest_filename)}' ")

    check_hash(dest_filename, "b915d492")

def printg(msg: str):
    """print in green color"""
    print("\033[92m" + msg + "\033[0m")
@@ -808,6 +828,16 @@ def convert_ic_light():
    )


def convert_mvanet():
    run_conversion_script(
        "convert_mvanet.py",
        "tests/weights/mvanet/Model_80.pth",
        "tests/weights/mvanet/mvanet.safetensors",
        half=True,
        expected_hash="bf9ae4cb",
    )


def download_all():
print(f"\nAll weights will be downloaded to {test_weights_dir}\n")
download_sd15("runwayml/stable-diffusion-v1-5")
@@ -830,6 +860,7 @@ def download_all():
    download_sdxl_lightning_base()
    download_sdxl_lightning_lora()
    download_ic_light()
    download_mvanet()


def convert_all():
@@ -850,6 +881,7 @@ def convert_all():
    convert_lcm_base()
    convert_sdxl_lightning_base()
    convert_ic_light()
    convert_mvanet()


def main():
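
For reference, the convert_mvanet() step above should be roughly equivalent to invoking the new conversion script directly. A sketch, assuming run_conversion_script (a pre-existing helper in this script, not shown in the diff) mainly adds hash checking and logging on top:

import subprocess

# Rough equivalent of convert_mvanet(); the flags come from scripts/conversion/convert_mvanet.py.
subprocess.run(
    [
        "python",
        "scripts/conversion/convert_mvanet.py",
        "--from", "tests/weights/mvanet/Model_80.pth",
        "--to", "tests/weights/mvanet/mvanet.safetensors",
        "--half",
    ],
    check=True,
)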
3 changes: 3 additions & 0 deletions src/refiners/foundationals/swin/__init__.py
@@ -0,0 +1,3 @@
from .swin_transformer import SwinTransformer

__all__ = ["SwinTransformer"]
3 changes: 3 additions & 0 deletions src/refiners/foundationals/swin/mvanet/__init__.py
@@ -0,0 +1,3 @@
from .mvanet import MVANet

__all__ = ["MVANet"]
138 changes: 138 additions & 0 deletions src/refiners/foundationals/swin/mvanet/converter.py
@@ -0,0 +1,138 @@
import re

from torch import Tensor


def convert_weights(official_state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
    rm_list = [
        # Official weights contain useless keys
        # See https://github.com/qianyu-dlut/MVANet/issues/3#issuecomment-2105650425
        r"multifieldcrossatt.linear[56]",
        r"multifieldcrossatt.attention.5",
        r"dec_blk\d+\.linear[12]",
        r"dec_blk[1234]\.attention\.[4567]",
        # We don't need the sideout weights
        r"sideout\d+",
    ]
    state_dict = {k: v for k, v in official_state_dict.items() if not any(re.match(rm, k) for rm in rm_list)}

    keys_map: dict[str, str] = {}
    for k in state_dict.keys():
        v: str = k

        def rpfx(s: str, src: str, dst: str) -> str:
            """Replace the prefix `src` of `s` with `dst` (no-op if `s` does not start with `src`)."""
            if not s.startswith(src):
                return s
            return s.replace(src, dst, 1)

        # Swin Transformer backbone

        v = rpfx(v, "backbone.patch_embed.proj.", "SwinTransformer.PatchEmbedding.Conv2d.")
        v = rpfx(v, "backbone.patch_embed.norm.", "SwinTransformer.PatchEmbedding.LayerNorm.")

        if m := re.match(r"backbone\.layers\.(\d+)\.downsample\.(.*)", v):
            s = m.group(2).replace("reduction.", "Linear.").replace("norm.", "LayerNorm.")
            v = f"SwinTransformer.Chain_{int(m.group(1)) + 1}.PatchMerging.{s}"

        if m := re.match(r"backbone\.layers\.(\d+)\.blocks\.(\d+)\.(.*)", v):
            s = m.group(3)
            s = s.replace("norm1.", "Residual_1.LayerNorm.")
            s = s.replace("norm2.", "Residual_2.LayerNorm.")

            s = s.replace("attn.qkv.", "Residual_1.WindowAttention.Linear_1.")
            s = s.replace("attn.proj.", "Residual_1.WindowAttention.Linear_2.")
            s = s.replace("attn.relative_position", "Residual_1.WindowAttention.WindowSDPA.rpb.relative_position")

            s = s.replace("mlp.fc", "Residual_2.Linear_")
            v = ".".join(
                [
                    f"SwinTransformer.Chain_{int(m.group(1)) + 1}",
                    f"BasicLayer.SwinTransformerBlock_{int(m.group(2)) + 1}",
                    s,
                ]
            )

        if m := re.match(r"backbone\.norm(\d+)\.(.*)", v):
            v = f"SwinTransformer.Chain_{int(m.group(1)) + 1}.Passthrough.LayerNorm.{m.group(2)}"

        # MVANet

        def mclm(s: str, pfx_src: str, pfx_dst: str) -> str:
            # Maps the official "multifieldcrossatt" weights onto the MCLM chain.
            pca = f"{pfx_dst}Residual.PatchwiseCrossAttention"
            s = rpfx(s, f"{pfx_src}linear1.", f"{pfx_dst}FeedForward_1.Linear_1.")
            s = rpfx(s, f"{pfx_src}linear2.", f"{pfx_dst}FeedForward_1.Linear_2.")
            s = rpfx(s, f"{pfx_src}linear3.", f"{pfx_dst}FeedForward_2.Linear_1.")
            s = rpfx(s, f"{pfx_src}linear4.", f"{pfx_dst}FeedForward_2.Linear_2.")
            s = rpfx(s, f"{pfx_src}norm1.", f"{pfx_dst}LayerNorm_1.")
            s = rpfx(s, f"{pfx_src}norm2.", f"{pfx_dst}LayerNorm_2.")
            s = rpfx(s, f"{pfx_src}attention.0.", f"{pfx_dst}GlobalAttention.Sum.Chain.MultiheadAttention.")
            s = rpfx(s, f"{pfx_src}attention.1.", f"{pca}.Concatenate.Chain_1.MultiheadAttention.")
            s = rpfx(s, f"{pfx_src}attention.2.", f"{pca}.Concatenate.Chain_2.MultiheadAttention.")
            s = rpfx(s, f"{pfx_src}attention.3.", f"{pca}.Concatenate.Chain_3.MultiheadAttention.")
            s = rpfx(s, f"{pfx_src}attention.4.", f"{pca}.Concatenate.Chain_4.MultiheadAttention.")
            return s

        def mcrm(s: str, pfx_src: str, pfx_dst: str) -> str:
            # Maps the official "dec_blk*" weights onto the MCRM chains.
            # Note: there are no linear{1,2}, see https://github.com/qianyu-dlut/MVANet/issues/3#issuecomment-2105650425
            tca = f"{pfx_dst}Parallel_3.TiledCrossAttention"
            pca = f"{tca}.Sum.Chain_2.PatchwiseCrossAttention"
            s = rpfx(s, f"{pfx_src}linear3.", f"{tca}.FeedForward.Linear_1.")
            s = rpfx(s, f"{pfx_src}linear4.", f"{tca}.FeedForward.Linear_2.")
            s = rpfx(s, f"{pfx_src}norm1.", f"{tca}.LayerNorm_1.")
            s = rpfx(s, f"{pfx_src}norm2.", f"{tca}.LayerNorm_2.")
            s = rpfx(s, f"{pfx_src}attention.0.", f"{pca}.Concatenate.Chain_1.MultiheadAttention.")
            s = rpfx(s, f"{pfx_src}attention.1.", f"{pca}.Concatenate.Chain_2.MultiheadAttention.")
            s = rpfx(s, f"{pfx_src}attention.2.", f"{pca}.Concatenate.Chain_3.MultiheadAttention.")
            s = rpfx(s, f"{pfx_src}attention.3.", f"{pca}.Concatenate.Chain_4.MultiheadAttention.")
            s = rpfx(s, f"{pfx_src}sal_conv.", f"{pfx_dst}Parallel_2.Multiply.Chain.Conv2d.")
            return s

        def cbr(s: str, pfx_src: str, pfx_dst: str, shift: int = 0) -> str:
            # CBR: a Conv2d + BatchNorm2d + PReLU sequence.
            s = rpfx(s, f"{pfx_src}{shift}.", f"{pfx_dst}Conv2d.")
            s = rpfx(s, f"{pfx_src}{shift + 1}.", f"{pfx_dst}BatchNorm2d.")
            s = rpfx(s, f"{pfx_src}{shift + 2}.", f"{pfx_dst}PReLU.")
            return s

        def cbg(s: str, pfx_src: str, pfx_dst: str) -> str:
            # CBG: a Conv2d + BatchNorm2d sequence (the activation that follows holds no weights).
            s = rpfx(s, f"{pfx_src}0.", f"{pfx_dst}Conv2d.")
            s = rpfx(s, f"{pfx_src}1.", f"{pfx_dst}BatchNorm2d.")
            return s

        v = rpfx(v, "shallow.0.", "ComputeShallow.Conv2d.")

        v = cbr(v, "output1.", "Pyramid.Sum.Chain.CBR.")
        v = cbr(v, "output2.", "Pyramid.Sum.PyramidL2.Sum.Chain.CBR.")
        v = cbr(v, "output3.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.Chain.CBR.")
        v = cbr(v, "output4.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.Sum.Chain.CBR.")
        v = cbr(v, "output5.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.Sum.PyramidL5.CBR.")

        v = cbr(v, "conv1.", "Pyramid.CBR.")
        v = cbr(v, "conv2.", "Pyramid.Sum.PyramidL2.CBR.")
        v = cbr(v, "conv3.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.CBR.")
        v = cbr(v, "conv4.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.CBR.")

        v = mclm(v, "multifieldcrossatt.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.Sum.PyramidL5.MCLM.")

        v = mcrm(v, "dec_blk1.", "Pyramid.MCRM.")
        v = mcrm(v, "dec_blk2.", "Pyramid.Sum.PyramidL2.MCRM.")
        v = mcrm(v, "dec_blk3.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.MCRM.")
        v = mcrm(v, "dec_blk4.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.MCRM.")

        v = cbr(v, "insmask_head.", "RearrangeMultiView.Chain.CBR_1.")
        v = cbr(v, "insmask_head.", "RearrangeMultiView.Chain.CBR_2.", shift=3)

        v = rpfx(v, "insmask_head.6.", "RearrangeMultiView.Chain.Conv2d.")

        v = cbg(v, "upsample1.", "ShallowUpscaler.Sum_2.Chain_1.CBG.")
        v = cbg(v, "upsample2.", "ShallowUpscaler.CBG.")

        v = rpfx(v, "output.0.", "Conv2d.")

        if v != k:
            keys_map[k] = v

    for key, new_key in keys_map.items():
        state_dict[new_key] = state_dict[key]
        state_dict.pop(key)

    return state_dict
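
To illustrate the renaming, a tiny sketch (the key follows the first mapping above; the tensor shape is hypothetical, and a real checkpoint holds many more entries):

import torch

src = {"backbone.patch_embed.proj.weight": torch.randn(96, 3, 4, 4)}  # hypothetical shape
dst = convert_weights(src)
assert "SwinTransformer.PatchEmbedding.Conv2d.weight" in dst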