ObjectFormer #41
@Inkyl Xuekang is responsible for this part.
Sorry, the specific weights are not currently available, but I can provide you with a script to process and extract the relevant weights:

import math
from typing import List, Optional
import torch
import timm
import torch.nn.functional as F
# Load a pre-trained Vision Transformer (ViT) model
model = timm.create_model('vit_base_patch16_224', pretrained=True)
def resample_abs_pos_embed(
    posemb,
    new_size: List[int],
    old_size: Optional[List[int]] = None,
    num_prefix_tokens: int = 1,
    interpolation: str = 'bicubic',
    antialias: bool = True,
    verbose: bool = False,
):
    # Determine the old and new sizes, assuming a square grid if old_size is not provided
    num_pos_tokens = posemb.shape[1]
    num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens
    if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
        return posemb
    if old_size is None:
        hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
        old_size = hw, hw
    # Separate the prefix tokens (e.g. the class token) if any exist
    if num_prefix_tokens:
        posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
    else:
        posemb_prefix, posemb = None, posemb
    # Interpolate the 2D grid of positional embeddings to the new size
    embed_dim = posemb.shape[-1]
    orig_dtype = posemb.dtype
    posemb = posemb.float()  # Convert to float32 for interpolation
    posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
    posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
    posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
    posemb = posemb.to(orig_dtype)
    # Concatenate the prefix tokens back if they were separated earlier
    if posemb_prefix is not None:
        posemb = torch.cat([posemb_prefix, posemb], dim=1)
    return posemb
# Initialize a dictionary to store the processed weights
processed_state_dict = {}
# Extract the positional embedding, drop the class token, and resample it to the target grid
pos_embed = model.state_dict()['pos_embed'][0][1:].unsqueeze(0)
pos_embed = resample_abs_pos_embed(pos_embed, [14, 28], num_prefix_tokens=0)
processed_state_dict['pos_embed'] = pos_embed
# Copy the patch embedding projection weights
processed_state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
processed_state_dict['patch_embed.proj.bias'] = model.state_dict()['patch_embed.proj.bias']
# Process and extract weights from the first 8 transformer blocks
for i in range(8):  # Only process the first 8 transformer blocks
    block_prefix = f'blocks.{i}.'
    # Extract norm1 weights and biases
    processed_state_dict[f'{block_prefix}norm1.weight'] = model.state_dict()[f'{block_prefix}norm1.weight']
    processed_state_dict[f'{block_prefix}norm1.bias'] = model.state_dict()[f'{block_prefix}norm1.bias']
    # Split the fused qkv projection into separate q, k, v weights and biases
    qkv_weight = model.state_dict()[f'{block_prefix}attn.qkv.weight']
    qkv_bias = model.state_dict()[f'{block_prefix}attn.qkv.bias']
    dim = qkv_weight.shape[0] // 3
    processed_state_dict[f'{block_prefix}attn.q.weight'] = qkv_weight[:dim]
    processed_state_dict[f'{block_prefix}attn.k.weight'] = qkv_weight[dim:2*dim]
    processed_state_dict[f'{block_prefix}attn.v.weight'] = qkv_weight[2*dim:]
    processed_state_dict[f'{block_prefix}attn.q.bias'] = qkv_bias[:dim]
    processed_state_dict[f'{block_prefix}attn.k.bias'] = qkv_bias[dim:2*dim]
    processed_state_dict[f'{block_prefix}attn.v.bias'] = qkv_bias[2*dim:]
    # Extract the attention output projection weights and biases
    processed_state_dict[f'{block_prefix}attn.proj.weight'] = model.state_dict()[f'{block_prefix}attn.proj.weight']
    processed_state_dict[f'{block_prefix}attn.proj.bias'] = model.state_dict()[f'{block_prefix}attn.proj.bias']
    # Extract norm2 and MLP weights and biases
    processed_state_dict[f'{block_prefix}norm2.weight'] = model.state_dict()[f'{block_prefix}norm2.weight']
    processed_state_dict[f'{block_prefix}norm2.bias'] = model.state_dict()[f'{block_prefix}norm2.bias']
    processed_state_dict[f'{block_prefix}mlp.fc1.weight'] = model.state_dict()[f'{block_prefix}mlp.fc1.weight']
    processed_state_dict[f'{block_prefix}mlp.fc1.bias'] = model.state_dict()[f'{block_prefix}mlp.fc1.bias']
    processed_state_dict[f'{block_prefix}mlp.fc2.weight'] = model.state_dict()[f'{block_prefix}mlp.fc2.weight']
    processed_state_dict[f'{block_prefix}mlp.fc2.bias'] = model.state_dict()[f'{block_prefix}mlp.fc2.bias']
# Save the processed weights to a .pth file
torch.save(processed_state_dict, 'processed_model_weights.pth')
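As a quick sanity check (not part of the original script), the saved file can be reloaded and a few shapes inspected before pointing --init_weight_path at it; the expected shapes below follow from the script above.

# Reload the processed weights and verify a few entries
ckpt = torch.load('processed_model_weights.pth', map_location='cpu')
print(len(ckpt), 'tensors saved')              # 3 + 8 * 16 = 131 entries
print(ckpt['pos_embed'].shape)                 # torch.Size([1, 392, 768]) for the 14x28 grid
print(ckpt['blocks.0.attn.q.weight'].shape)    # torch.Size([768, 768])
# The resulting file is what gets passed to ObjectFormer via --init_weight_path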
Hi, I find that some parts of ObjectFormer are missing: IMDLBenCo/IMDLBenCo/model_zoo/object_former/object_former.py, lines 319 to 324 at commit 1807684.
The implementation of the label loss is missing. This method needs global average pooling followed by a linear layer as the classifier.
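For reference, here is a minimal sketch of the kind of head described above (global average pooling followed by a linear classifier) together with a binary label loss. The class name ClassifierHead, the assumed (B, C, H, W) feature layout, and the choice of BCEWithLogitsLoss are illustrative assumptions, not the actual ObjectFormer implementation.

import torch
import torch.nn as nn

class ClassifierHead(nn.Module):
    """Hypothetical image-level classifier: global average pooling + linear layer."""
    def __init__(self, embed_dim: int = 768, num_classes: int = 1):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)  # global average pooling over the spatial grid
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        # feat: (B, C, H, W) feature map from the backbone (assumed layout)
        x = self.pool(feat).flatten(1)       # (B, C)
        return self.fc(x)                    # (B, num_classes) logits

# Example label loss: binary cross-entropy on the pooled logits (assumed choice)
head = ClassifierHead(embed_dim=768, num_classes=1)
feat = torch.randn(2, 768, 14, 28)           # dummy backbone features
labels = torch.tensor([[0.0], [1.0]])        # 0 = authentic, 1 = tampered
label_loss = nn.BCEWithLogitsLoss()(head(feat), labels)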
Hello, where can I find the weight for ObjectFormer's --init_weight_path object_former/processed_model_weights.pth?
Thanks for your reply.