Skip to content

Commit

Permalink
ok
Browse files Browse the repository at this point in the history
  • Loading branch information
johndpope committed Oct 4, 2024
1 parent 3c71c75 commit 7940ef1
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
11 changes: 2 additions & 9 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,16 +478,9 @@ def process_tokens(self, t_c, t_r):
return m_c, m_r


class IMFEncoder(nn.Module):
def __init__(self, model):
super(IMFEncoder, self).__init__()
self.model = model

def forward(self, x_current, x_reference):
f_r = self.model.dense_feature_encoder(x_reference)
t_r = self.model.latent_token_encoder(x_reference)
t_c = self.model.latent_token_encoder(x_current)
return f_r, t_r, t_c



class MappingNetwork(nn.Module):
def __init__(self, latent_dim, w_dim, depth):
Expand Down
11 changes: 11 additions & 0 deletions onnxconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@
from PIL import Image
from torchvision import transforms


class IMFDecoder(nn.Module):
def __init__(self, model):
super(IMFDecoder, self).__init__()
self.model = model

def decode_latent_tokens(self,f_r,t_r,t_c):
return self.model.decode_latent_tokens(f_r,t_r,t_c)



# Define the IMFEncoder class
class IMFEncoder(nn.Module):
def __init__(self, model):
Expand Down

0 comments on commit 7940ef1

Please sign in to comment.