
Commit

ok
johndpope committed Oct 4, 2024
1 parent c6eb4da commit 3c71c75
Showing 3 changed files with 60 additions and 22 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -9,4 +9,5 @@ recon_epoch_1.png
*.png
__pycache__/vit.cpython-311.pyc
__pycache__/helper.cpython-311.pyc
actions-runner/*
imf_encoder.onnx
12 changes: 12 additions & 0 deletions model.py
@@ -477,6 +477,18 @@ def process_tokens(self, t_c, t_r):

        return m_c, m_r


class IMFEncoder(nn.Module):
    def __init__(self, model):
        super(IMFEncoder, self).__init__()
        self.model = model

    def forward(self, x_current, x_reference):
        f_r = self.model.dense_feature_encoder(x_reference)
        t_r = self.model.latent_token_encoder(x_reference)
        t_c = self.model.latent_token_encoder(x_current)
        return f_r, t_r, t_c

class MappingNetwork(nn.Module):
    def __init__(self, latent_dim, w_dim, depth):
        super().__init__()
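For context, a minimal usage sketch of the new wrapper (assumptions: IMFEncoder is imported from model.py where this commit adds it, IMFModel's default constructor is used, and inputs are 256x256 RGB as in the export script below); the exact shapes of f_r, t_r, and t_c depend on the IMFModel configuration:

import torch
from model import IMFModel, IMFEncoder

model = IMFModel()
model.eval()
encoder = IMFEncoder(model)

with torch.no_grad():
    x_current = torch.randn(1, 3, 256, 256)
    x_reference = torch.randn(1, 3, 256, 256)
    # f_r: reference appearance features; t_r / t_c: latent tokens for the
    # reference and current frames, as computed in IMFEncoder.forward above.
    f_r, t_r, t_c = encoder(x_current, x_reference)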
67 changes: 46 additions & 21 deletions onnxconv.py
@@ -1,9 +1,23 @@
import torch
import torch.onnx
from model import IMFModel
import torch.nn as nn
from PIL import Image
from torchvision import transforms

# IMFEncoder wrapper (same as the class added in model.py), kept local to the export script
class IMFEncoder(nn.Module):
    def __init__(self, model):
        super(IMFEncoder, self).__init__()
        self.model = model

    def forward(self, x_current, x_reference):
        f_r = self.model.dense_feature_encoder(x_reference)
        t_r = self.model.latent_token_encoder(x_reference)
        t_c = self.model.latent_token_encoder(x_current)
        return f_r, t_r, t_c

# Define the trace handler and utility functions
def trace_handler(module, input, output):
print(f"\nModule: {module.__class__.__name__}")
for idx, inp in enumerate(input):
@@ -14,7 +28,6 @@ def trace_handler(module, input, output):
    for idx, out in enumerate(output):
        print_tensor_info(out, f" Output[{idx}]")


def print_tensor_info(tensor, name, indent=0):
    indent_str = ' ' * indent
    print(f"{indent_str}{name}:")
@@ -39,6 +52,7 @@ def print_model_structure(model):
        if hasattr(module, 'bias') and module.bias is not None:
            print(f"  Bias shape: {module.bias.shape}")

# Export the given module to ONNX with named inputs/outputs
def export_to_onnx(model, x_current, x_reference, file_name):
    try:
        print("Model structure before tracing:")
@@ -67,21 +81,21 @@ def export_to_onnx(model, x_current, x_reference, file_name):
        torch.onnx.export(
            model,
            (x_current, x_reference),
            file_name,
            export_params=True,
            opset_version=11,
            do_constant_folding=True,
            input_names=['x_current', 'x_reference'],
            output_names=['f_r', 't_r', 't_c'],
            dynamic_axes={
                'x_current': {0: 'batch_size'},
                'x_reference': {0: 'batch_size'},
                'f_r': {0: 'batch_size'},
                't_r': {0: 'batch_size'},
                't_c': {0: 'batch_size'}
            },
            verbose=True
        )

        print(f"Model exported successfully to {file_name}")
    except Exception as e:
        print(f"Error during ONNX export: {str(e)}")
@@ -93,16 +107,27 @@ def export_to_onnx(model, x_current, x_reference, file_name):
model.eval()

# Load the checkpoint
checkpoint = torch.load("./checkpoints/checkpoint.pth", map_location=lambda storage, loc: storage)
state_dict = checkpoint['model_state_dict']

# # Adjust the weights in the state_dict
# for key in state_dict.keys():
# if 'csonv.weight' in key and state_dict[key].dim() == 5:
# state_dict[key] = state_dict[key].squeeze(0)
# Create dummy input tensors
x_current = torch.randn(1, 3, 256, 256)
x_reference = torch.randn(1, 3, 256, 256)
checkpoint = torch.load("./checkpoints/checkpoint.pth", map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])

# Create the IMFEncoder instance
encoder_model = IMFEncoder(model)
encoder_model.eval()

# Load images and preprocess
def load_image(image_path):
    transform = transforms.Compose([
        transforms.Resize((256, 256)),  # Adjust as per your model's requirements
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],  # Adjust as per your model's requirements
                             std=[0.229, 0.224, 0.225])
    ])
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0)  # Add batch dimension
    return image

x_current = load_image("x_current.png")
x_reference = load_image("x_reference.png")

# Export the model
export_to_onnx(encoder_model, x_current, x_reference, "imf_encoder.onnx")

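Once the export succeeds, a quick way to sanity-check the resulting imf_encoder.onnx is to run it with onnxruntime and inspect the output shapes. A minimal sketch (assumption: onnxruntime is installed; if dense_feature_encoder returns a list of feature maps, the exporter splits f_r into several graph outputs, so the names below are read from the graph rather than hard-coded):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("imf_encoder.onnx")

# Dummy inputs matching the traced shape (1, 3, 256, 256), float32 as ONNX expects.
feed = {
    "x_current": np.random.randn(1, 3, 256, 256).astype(np.float32),
    "x_reference": np.random.randn(1, 3, 256, 256).astype(np.float32),
}

# Passing None for the output list returns every graph output in order.
outputs = sess.run(None, feed)
for meta, out in zip(sess.get_outputs(), outputs):
    print(meta.name, out.shape)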