Adding DPT #1079

Merged: 45 commits, Apr 7, 2025
Changes from all commits · 45 commits
78ba0e8
Initial timm vit encoder commit
vedantdalimkar Feb 28, 2025
2c38de6
Add DPT model and update logic for TimmViTEncoder class
vedantdalimkar Mar 2, 2025
5599409
Removed redundant documentation
vedantdalimkar Mar 2, 2025
c47bdfb
Added initial test and some minor code modifications
vedantdalimkar Mar 5, 2025
71e2acb
Code refactor
vedantdalimkar Mar 8, 2025
e85836d
Added weight conversion script
vedantdalimkar Mar 22, 2025
35cb060
Moved conversion script to appropriate location
vedantdalimkar Mar 22, 2025
aa84f4e
Added logic in timm table generation for adding ViT encoders for DPT
Mar 22, 2025
67c4a75
Ruff formatting
vedantdalimkar Mar 22, 2025
85f22fb
Code revision
vedantdalimkar Mar 26, 2025
ef48032
Remove unnecessary comment
vedantdalimkar Mar 27, 2025
28204ad
Simplify ViT encoder
qubvel Apr 5, 2025
1b9a6f6
Refactor ProjectionReadout
qubvel Apr 5, 2025
334cfbb
Refactor modeling DPT
qubvel Apr 6, 2025
7e1ef3b
Support more encoders
qubvel Apr 6, 2025
d65c0f7
Refactor a bit conversion, added validation
qubvel Apr 6, 2025
0a62fe0
Fixup
qubvel Apr 6, 2025
e3238ae
Split forward for timm_vit
qubvel Apr 6, 2025
df4d087
Rename readout, remove feature_dim
qubvel Apr 6, 2025
8bcb0ed
refactor + add transform
qubvel Apr 6, 2025
6ba6746
Fixup
qubvel Apr 6, 2025
8fd8c77
Refine docs a bit
qubvel Apr 6, 2025
9bf1fd2
Refine docs
qubvel Apr 6, 2025
0e9170f
Refine model size a bit and docs
qubvel Apr 6, 2025
a0aa5a8
Add to docs
qubvel Apr 6, 2025
6cfd3be
Add note
qubvel Apr 6, 2025
d4b162d
Remove txt
qubvel Apr 6, 2025
5fe80a5
Fix doc
qubvel Apr 6, 2025
0a14972
Fix docstring
qubvel Apr 6, 2025
5b28978
Fixing list in activation
qubvel Apr 6, 2025
0ed621c
Fixing list
qubvel Apr 6, 2025
6207310
Fixing list
qubvel Apr 6, 2025
19eeebe
Fixup, fix type hint
qubvel Apr 6, 2025
f2e3f89
Merge branch 'main' into pr/vedantdalimkar/1079
qubvel Apr 6, 2025
1257c4b
Add to README
qubvel Apr 6, 2025
21a164a
Add example
qubvel Apr 6, 2025
8d3ed4f
Add decoder_readout according to initial impl
qubvel Apr 7, 2025
4eb6ec3
Tests update
vedantdalimkar Apr 7, 2025
165b9c0
Fix encoder tests
qubvel Apr 7, 2025
5603707
Fix DPT tests
qubvel Apr 7, 2025
9518964
Refactor a bit
qubvel Apr 7, 2025
38cb944
Tests
qubvel Apr 7, 2025
17d3328
Update gen test models
qubvel Apr 7, 2025
83b9655
Revert gitignore
qubvel Apr 7, 2025
343fbe0
Fix test
qubvel Apr 7, 2025
4 changes: 3 additions & 1 deletion README.md
@@ -21,7 +21,7 @@ Segmentation based on [PyTorch](https://pytorch.org/).**
The main features of the library are:

- Super simple high-level API (just two lines to create a neural network)
- 11 encoder-decoder model architectures (Unet, Unet++, Segformer, ...)
- 12 encoder-decoder model architectures (Unet, Unet++, Segformer, DPT, ...)
- 800+ **pretrained** convolution- and transformer-based encoders, including [timm](https://github.com/huggingface/pytorch-image-models) support
- Popular metrics and losses for training routines (Dice, Jaccard, Tversky, ...)
- ONNX export and torch script/trace/compile friendly
@@ -105,6 +105,7 @@ Congratulations! You are done! Now you can train your model with your favorite f
| **Train** multiclass segmentation on CamVid | [Notebook](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/examples/camvid_segmentation_multiclass.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel-org/segmentation_models.pytorch/blob/main/examples/camvid_segmentation_multiclass.ipynb) |
| **Train** clothes binary segmentation by @ternaus | [Repo](https://github.com/ternaus/cloths_segmentation) | |
| **Load and inference** pretrained Segformer | [Notebook](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/examples/segformer_inference_pretrained.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/segformer_inference_pretrained.ipynb) |
| **Load and inference** pretrained DPT | [Notebook](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/examples/dpt_inference_pretrained.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/dpt_inference_pretrained.ipynb) |
| **Save and load** models locally / to HuggingFace Hub |[Notebook](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/examples/save_load_model_and_share_with_hf_hub.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/save_load_model_and_share_with_hf_hub.ipynb)
| **Export** trained model to ONNX | [Notebook](https://github.com/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb) |

@@ -123,6 +124,7 @@ Congratulations! You are done! Now you can train your model with your favorite f
- DeepLabV3+ [[paper](https://arxiv.org/abs/1802.02611)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id9)]
- UPerNet [[paper](https://arxiv.org/abs/1807.10221)] [[docs](https://smp.readthedocs.io/en/latest/models.html#upernet)]
- Segformer [[paper](https://arxiv.org/abs/2105.15203)] [[docs](https://smp.readthedocs.io/en/latest/models.html#segformer)]
- DPT [[paper](https://arxiv.org/abs/2103.13413)] [[docs](https://smp.readthedocs.io/en/latest/models.html#dpt)]

### Encoders <a name="encoders"></a>

461 changes: 461 additions & 0 deletions docs/encoders_dpt.rst

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions docs/models.rst
@@ -81,3 +81,18 @@ Segformer
~~~~~~~~~
.. autoclass:: segmentation_models_pytorch.Segformer


.. _dpt:

DPT
~~~

.. note::

See full list of DPT-compatible timm encoders in :ref:`dpt-encoders`.

.. note::

For some encoders, ``dynamic_img_size=True`` must be passed for the model to work with resolutions other than the one the encoder was trained on.

.. autoclass:: segmentation_models_pytorch.DPT
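
To make the note above concrete, here is a minimal usage sketch (the encoder name, class count, and input size are illustrative; ``dynamic_img_size`` is forwarded to the timm ViT backbone):

import torch
import segmentation_models_pytorch as smp

# Illustrative sketch: DPT with a timm ViT encoder; dynamic_img_size=True lets a
# 224-pretrained ViT accept other input resolutions (see the note above).
model = smp.DPT(
    encoder_name="tu-vit_base_patch16_224",  # assumed encoder name, for illustration
    classes=2,
    dynamic_img_size=True,
).eval()

with torch.no_grad():
    mask = model(torch.rand(1, 3, 512, 512))

print(mask.shape)  # expected: torch.Size([1, 2, 512, 512])
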
138 changes: 138 additions & 0 deletions examples/dpt_inference_pretrained.ipynb

Large diffs are not rendered by default.

59 changes: 51 additions & 8 deletions misc/generate_table_timm.py
@@ -17,30 +17,68 @@ def has_dilation_support(name):
return False


def valid_vit_encoder_for_dpt(name):
if "vit" not in name:
return False
encoder = timm.create_model(name)
feature_info = encoder.feature_info
feature_info_obj = timm.models.FeatureInfo(
feature_info=feature_info, out_indices=[0, 1, 2, 3]
)
reduction_scales = list(feature_info_obj.reduction())

if len(set(reduction_scales)) > 1:
return False

output_stride = reduction_scales[0]
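# the single reduction scale must be a power of two (exactly one set bit)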
if bin(output_stride).count("1") != 1:
return False

return True


def make_table(data):
names = data.keys()
max_len1 = max([len(x) for x in names]) + 2
max_len2 = len("support dilation") + 2
max_len3 = len("Supported for DPT") + 2

l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+\n"
l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+\n"
l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+" + "-" * max_len3 + "+\n"
l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+" + "-" * max_len3 + "+\n"
top = (
"| "
+ "Encoder name".ljust(max_len1 - 2)
+ " | "
+ "Support dilation".center(max_len2 - 2)
+ " | "
+ "Supported for DPT".center(max_len3 - 2)
+ " |\n"
)

table = l1 + top + l2

for k in sorted(data.keys()):
support = (
"✅".center(max_len2 - 3)
if data[k]["has_dilation"]
else " ".center(max_len2 - 2)
if "has_dilation" in data[k] and data[k]["has_dilation"]:
support = "✅".center(max_len2 - 3)

else:
support = " ".center(max_len2 - 2)

if "supported_only_for_dpt" in data[k]:
supported_for_dpt = "✅".center(max_len3 - 3)

else:
supported_for_dpt = " ".center(max_len3 - 2)

table += (
"| "
+ k.ljust(max_len1 - 2)
+ " | "
+ support
+ " | "
+ supported_for_dpt
+ " |\n"
)
table += "| " + k.ljust(max_len1 - 2) + " | " + support + " |\n"
table += l1

return table
@@ -55,8 +93,13 @@ def make_table(data):
check_features_and_reduction(name)
has_dilation = has_dilation_support(name)
supported_models[name] = dict(has_dilation=has_dilation)

except Exception:
continue
try:
if valid_vit_encoder_for_dpt(name):
supported_models[name] = dict(supported_only_for_dpt=True)
except Exception:
continue
Comment on lines +96 to +102

Collaborator:
Should we check only if we got an exception here?
Would it be better to make two independent checks?

Contributor Author:

If you check the behaviour of check_features_and_reduction and valid_vit_encoder_for_dpt, their outputs are mutually exclusive. To be more specific:

  1. check_features_and_reduction returns true only when a model's reduction scales are exactly [2, 4, 8, 16, 32], whereas
  2. valid_vit_encoder_for_dpt returns false if the encoder has multiple reduction scales.

In short, a model that satisfies the conditions of check_features_and_reduction will never satisfy the conditions of valid_vit_encoder_for_dpt, and vice versa.

Collaborator:

Ok, I suppose this code should be updated as well, because as far as I remember [4, 8, 16, 32] and [1, 2, 4, 8, 16, 32] reductions are also supported

Contributor Author:

> Ok, I suppose this code should be updated as well, because as far as I remember [4, 8, 16, 32] and [1, 2, 4, 8, 16, 32] reductions are also supported

Should I update this as well or will you do it from your end?


table = make_table(supported_models)
print(table)
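
To illustrate the distinction discussed in the thread above, here is a small sketch (the model name is an example; exact reduction values depend on the installed timm version):

import timm

# Sketch of the two checks: CNN-style encoders report increasing reduction scales
# (e.g. [2, 4, 8, 16, 32]) and pass check_features_and_reduction, while plain ViTs
# report one repeated, power-of-two scale (e.g. [16, 16, 16, 16]) and pass
# valid_vit_encoder_for_dpt instead; so the two sets of encoders never overlap.
encoder = timm.create_model("vit_base_patch16_224")  # example model name
info = timm.models.FeatureInfo(feature_info=encoder.feature_info, out_indices=[0, 1, 2, 3])
scales = list(info.reduction())

is_dpt_compatible = len(set(scales)) == 1 and bin(scales[0]).count("1") == 1
print(scales, is_dpt_compatible)  # e.g. [16, 16, 16, 16] True
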
45 changes: 31 additions & 14 deletions misc/generate_test_models.py
@@ -9,33 +9,50 @@

api = huggingface_hub.HfApi(token=os.getenv("HF_TOKEN"))

for model_name, model_class in smp.MODEL_ARCHITECTURES_MAPPING.items():
model = model_class(encoder_name=ENCODER_NAME)
model = model.eval()

# generate test sample
torch.manual_seed(423553)
sample = torch.rand(1, 3, 256, 256)

with torch.no_grad():
output = model(sample)

def save_and_push(model, inputs, outputs, model_name, encoder_name):
with tempfile.TemporaryDirectory() as tmpdir:
# save model
model.save_pretrained(f"{tmpdir}")

# save input and output
torch.save(sample, f"{tmpdir}/input-tensor.pth")
torch.save(output, f"{tmpdir}/output-tensor.pth")
torch.save(inputs, f"{tmpdir}/input-tensor.pth")
torch.save(outputs, f"{tmpdir}/output-tensor.pth")

# create repo
repo_id = f"{HUB_REPO}/{model_name}-{ENCODER_NAME}"
repo_id = f"{HUB_REPO}/{model_name}-{encoder_name}"
if not api.repo_exists(repo_id=repo_id):
api.create_repo(repo_id=repo_id, repo_type="model")

# upload to hub
api.upload_folder(
folder_path=tmpdir,
repo_id=f"{HUB_REPO}/{model_name}-{ENCODER_NAME}",
repo_id=f"{HUB_REPO}/{model_name}-{encoder_name}",
repo_type="model",
)


for model_name, model_class in smp.MODEL_ARCHITECTURES_MAPPING.items():
if model_name == "dpt":
encoder_name = "tu-test_vit"
model = smp.DPT(
encoder_name=encoder_name,
decoder_readout="cat",
decoder_intermediate_channels=(16, 32, 64, 64),
decoder_fusion_channels=16,
dynamic_img_size=True,
)
else:
encoder_name = ENCODER_NAME
model = model_class(encoder_name=encoder_name)

model = model.eval()

# generate test sample
torch.manual_seed(423553)
sample = torch.rand(1, 3, 256, 256)

with torch.no_grad():
output = model(sample)

save_and_push(model, sample, output, model_name, encoder_name)
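
The pushed tensors can later serve as a regression fixture. A hedged sketch of the consuming side follows (the repo id pattern matches save_and_push above; the "smp-test-models" organization is an assumption):

import torch
import huggingface_hub
import segmentation_models_pytorch as smp

# Sketch only: repo id follows the "{HUB_REPO}/{model_name}-{encoder_name}" pattern
# used above; "smp-test-models" is an assumed value for HUB_REPO.
repo_id = "smp-test-models/dpt-tu-test_vit"
model = smp.from_pretrained(repo_id).eval()

inputs = torch.load(huggingface_hub.hf_hub_download(repo_id, "input-tensor.pth"), weights_only=True)
expected = torch.load(huggingface_hub.hf_hub_download(repo_id, "output-tensor.pth"), weights_only=True)

with torch.no_grad():
    torch.testing.assert_close(model(inputs), expected, atol=1e-4, rtol=1e-4)
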
122 changes: 122 additions & 0 deletions scripts/models-conversions/dpt-original-to-smp.py
@@ -0,0 +1,122 @@
import cv2
import torch
import albumentations as A
import segmentation_models_pytorch as smp

MODEL_WEIGHTS_PATH = r"dpt_large-ade20k-b12dca68.pt"
HF_HUB_PATH = "qubvel-hf/dpt-large-ade20k"
PUSH_TO_HUB = False


def get_transform():
return A.Compose(
[
A.LongestMaxSize(max_size=480, interpolation=cv2.INTER_CUBIC),
A.Normalize(
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), max_pixel_value=255.0
),
# This is not the correct transform: ideally the image should be resized (not padded) to a multiple of 32,
# but there is no such transform in albumentations, so this is the closest one
A.PadIfNeeded(
min_height=None,
min_width=None,
pad_height_divisor=32,
pad_width_divisor=32,
border_mode=cv2.BORDER_CONSTANT,
value=0,
p=1,
),
]
)


if __name__ == "__main__":
# fmt: off
smp_model = smp.DPT(encoder_name="tu-vit_large_patch16_384", classes=150, dynamic_img_size=True)
dpt_model_dict = torch.load(MODEL_WEIGHTS_PATH, weights_only=True)

for layer_index in range(0, 4):
for param in ["running_mean", "running_var", "num_batches_tracked", "weight", "bias"]:
for block_index in [1, 2]:
for bn_index in [1, 2]:
# Assigning weights of the 4th fusion layer of the original model to the 1st layer of the SMP DPT model,
# the weights of the 3rd fusion layer to the 2nd layer, and so on,
# because the fusion layers are called in reverse order in the original DPT implementation
dpt_model_dict[f"decoder.fusion_blocks.{layer_index}.residual_conv_block{block_index}.batch_norm_{bn_index}.{param}"] = \
dpt_model_dict.pop(f"scratch.refinenet{4 - layer_index}.resConfUnit{block_index}.bn{bn_index}.{param}")

if param in ["weight", "bias"]:
if param == "weight":
for block_index in [1, 2]:
for conv_index in [1, 2]:
dpt_model_dict[f"decoder.fusion_blocks.{layer_index}.residual_conv_block{block_index}.conv_{conv_index}.{param}"] = \
dpt_model_dict.pop(f"scratch.refinenet{4 - layer_index}.resConfUnit{block_index}.conv{conv_index}.{param}")

dpt_model_dict[f"decoder.reassemble_blocks.{layer_index}.project_to_feature_dim.{param}"] = \
dpt_model_dict.pop(f"scratch.layer{layer_index + 1}_rn.{param}")

dpt_model_dict[f"decoder.fusion_blocks.{layer_index}.project.{param}"] = \
dpt_model_dict.pop(f"scratch.refinenet{4 - layer_index}.out_conv.{param}")

dpt_model_dict[f"decoder.projection_blocks.{layer_index}.project.0.{param}"] = \
dpt_model_dict.pop(f"pretrained.act_postprocess{layer_index + 1}.0.project.0.{param}")

dpt_model_dict[f"decoder.reassemble_blocks.{layer_index}.project_to_out_channel.{param}"] = \
dpt_model_dict.pop(f"pretrained.act_postprocess{layer_index + 1}.3.{param}")

if layer_index != 2:
dpt_model_dict[f"decoder.reassemble_blocks.{layer_index}.upsample.{param}"] = \
dpt_model_dict.pop(f"pretrained.act_postprocess{layer_index + 1}.4.{param}")

# Changing state dict keys for segmentation head
dpt_model_dict = {
name.replace("scratch.output_conv", "segmentation_head.head"): parameter
for name, parameter in dpt_model_dict.items()
}

# Changing state dict keys for encoder layers
dpt_model_dict = {
name.replace("pretrained.model", "encoder.model"): parameter
for name, parameter in dpt_model_dict.items()
}

# Removing key-value pairs associated with the auxiliary head
dpt_model_dict = {
name: parameter
for name, parameter in dpt_model_dict.items()
if not name.startswith("auxlayer")
}
# fmt: on

smp_model.load_state_dict(dpt_model_dict, strict=True)

# ------- DO NOT touch this section -------
smp_model.eval()

input_tensor = torch.ones((1, 3, 384, 384))
output = smp_model(input_tensor)

print(output.shape)
print(output[0, 0, :3, :3])

expected_slice = torch.tensor(
[
[3.4243, 3.4553, 3.4863],
[3.3332, 3.2876, 3.2419],
[3.2422, 3.1199, 2.9975],
]
)

torch.testing.assert_close(
output[0, 0, :3, :3], expected_slice, atol=1e-4, rtol=1e-4
)

# Saving
transform = get_transform()

transform.save_pretrained(HF_HUB_PATH)
smp_model.save_pretrained(HF_HUB_PATH, push_to_hub=PUSH_TO_HUB)

# Re-loading to make sure everything is saved correctly
smp_model = smp.from_pretrained(HF_HUB_PATH)
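
Continuing from the script above, a hedged sketch of running the converted model on an image with the same preprocessing (the image path is illustrative):

# Illustrative continuation of the conversion script; "example.jpg" is a placeholder.
image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
preprocessed = get_transform()(image=image)["image"]
tensor = torch.from_numpy(preprocessed).permute(2, 0, 1).unsqueeze(0).float()

with torch.no_grad():
    logits = smp_model(tensor)      # (1, 150, H, W) logits over the ADE20K classes
predicted = logits.argmax(dim=1)    # per-pixel class indices
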
3 changes: 3 additions & 0 deletions segmentation_models_pytorch/__init__.py
@@ -14,6 +14,7 @@
from .decoders.pan import PAN
from .decoders.upernet import UPerNet
from .decoders.segformer import Segformer
from .decoders.dpt import DPT
from .base.hub_mixin import from_pretrained

from .__version__ import __version__
@@ -34,6 +35,7 @@
PAN,
UPerNet,
Segformer,
DPT,
]
MODEL_ARCHITECTURES_MAPPING = {a.__name__.lower(): a for a in _MODEL_ARCHITECTURES}

@@ -84,6 +86,7 @@ def create_model(
"PAN",
"UPerNet",
"Segformer",
"DPT",
"from_pretrained",
"create_model",
"__version__",
6 changes: 2 additions & 4 deletions segmentation_models_pytorch/decoders/deeplabv3/model.py
@@ -34,8 +34,7 @@ class DeepLabV3(SegmentationModel):
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
activation: An activation function to apply after the final convolution layer.
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
**callable** and **None**.
Default is **None**
**callable** and **None**. Default is **None**.
upsampling: Final upsampling factor. Default is **None** to preserve input-output spatial shape identity
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
on top of encoder if **aux_params** is not **None** (default). Supported params:
@@ -159,8 +158,7 @@ class DeepLabV3Plus(SegmentationModel):
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
activation: An activation function to apply after the final convolution layer.
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
**callable** and **None**.
Default is **None**
**callable** and **None**. Default is **None**.
upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity.
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
on top of encoder if **aux_params** is not **None** (default). Supported params:
3 changes: 3 additions & 0 deletions segmentation_models_pytorch/decoders/dpt/__init__.py
@@ -0,0 +1,3 @@
from .model import DPT

__all__ = ["DPT"]