Commit 92c664c

lazily create up_descending after state dict is already loaded, but only do it once

technillogue committed Oct 15, 2024
1 parent f910e01 commit 92c664c
Showing 1 changed file with 25 additions and 22 deletions.
47 changes: 25 additions & 22 deletions flux/modules/autoencoder.py
@@ -106,16 +106,13 @@ def forward(self, x: Tensor):
         return x
 
 class DownBlock(nn.Module):
-    def __init__(self, block, downsample=None):
+    def __init__(self, block: list, downsample: nn.Module) -> None:
         super().__init__()
         # we're doing this instead of a flat nn.Sequential to preserve the keys "block" "downsample"
         self.block = nn.Sequential(*block)
-        # attn
-        if downsample:
-            self.downsample = downsample
-        else:
-            self.downsample = nn.Identity()
+        self.downsample = downsample
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.downsample(self.block(x))

@@ -143,8 +140,9 @@ def __init__(
         self.in_ch_mult = in_ch_mult
         down_layers = []
         block_in = self.ch
+        # ideally, this would all append to a single flat nn.Sequential
+        # we cannot do this due to the existing state dict keys
         for i_level in range(self.num_resolutions):
-            # attn = nn.ModuleList()
             block_in = ch * in_ch_mult[i_level]
             block_out = ch * ch_mult[i_level]
             block_layers = []

@@ -156,7 +154,7 @@ def __init__(
                 downsample = Downsample(block_in)
                 curr_res = curr_res // 2
             else:
-                downsample = None
+                downsample = nn.Identity()
             down_layers.append(DownBlock(block_layers, downsample))
         self.down = nn.Sequential(*down_layers)

@@ -186,17 +184,13 @@ def forward(self, x: Tensor) -> Tensor:
         return h
 
 class UpBlock(nn.Module):
-    def __init__(self, block, upsample=None):
+    def __init__(self, block: list, upsample: nn.Module) -> None:
         super().__init__()
         self.block = nn.Sequential(*block)
-        if upsample:
-            self.upsample = upsample
-        else:
-            self.upsample = nn.Identity()
+        self.upsample = upsample
 
-    def forward(self, x):
-        x = self.block(x)
-        return self.upsample(x)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.upsample(self.block(x))
 
 
 class Decoder(nn.Module):

@@ -234,7 +228,7 @@ def __init__(
 
         # upsampling
         up_blocks = []
-        # 3, 2, 1, 0
+        # 3, 2, 1, 0, descending order
         for i_level in reversed(range(self.num_resolutions)):
             level_blocks = []
             block_out = ch * ch_mult[i_level]

@@ -245,16 +239,25 @@ def __init__(
                 upsample = Upsample(block_in)
                 curr_res = curr_res * 2
             else:
-                upsample = None
-            # ??? gross
-            # 0, 1, 2, 3
+                upsample = nn.Identity()
+            # 0, 1, 2, 3, ascending order
             up_blocks.insert(0, UpBlock(level_blocks, upsample)) # prepend to get consistent order
         self.up = nn.Sequential(*up_blocks)
 
         # end
         self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
         self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
 
+    # this is a hack to get something like property but only evaluate it once
+    # we're doing it like this so that up_descending isn't in the state_dict keys
+    # without adding anything conditional to the main flow
+    def __getattr__(self, name):
+        if name == "up_descending":
+            self.up_descending = nn.Sequential(*reversed(self.up))
+            Decoder.__getattr__ = nn.Module.__getattr__
+            return self.up_descending
+        return super().__getattr__(name)
+
     def forward(self, z: Tensor) -> Tensor:
         # z to block_in
         h = self.conv_in(z)

@@ -265,7 +268,7 @@ def forward(self, z: Tensor) -> Tensor:
         h = self.mid.block_2(h)
 
         # upsampling
-        h = self.up(h)
+        h = self.up_descending(h)
 
         # end
         h = self.norm_out(h)
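
The core of this change is the __getattr__ override added to Decoder. Below is a minimal, self-contained sketch of the same lazy, build-once attribute pattern; the Toy class and its layer sizes are invented for illustration, and only the __getattr__ mechanics mirror the commit. The derived Sequential is built on first access, which happens after load_state_dict has already run against the original "up.*" keys, and the class-level __getattr__ is then swapped back to nn.Module's so the check never runs again.

# Minimal sketch of the lazy, build-once attribute pattern used by Decoder above.
# "Toy" and its layers are hypothetical; only the __getattr__ mechanics follow the commit.
import torch
from torch import nn


class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.up = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # first use of up_descending triggers __getattr__ below
        return self.up_descending(x)

    def __getattr__(self, name):
        if name == "up_descending":
            # build the reversed view once; it reuses the existing submodules,
            # so no new parameters are created
            self.up_descending = nn.Sequential(*reversed(self.up))
            # restore normal nn.Module attribute lookup so this branch
            # never runs again (the "only do it once" part)
            Toy.__getattr__ = nn.Module.__getattr__
            return self.up_descending
        return super().__getattr__(name)


model = Toy()
print([name for name, _ in model.named_children()])  # ['up']: nothing extra registered yet
model.load_state_dict(model.state_dict())            # old-style "up.*" keys load cleanly
model(torch.randn(1, 4))                              # first forward lazily builds up_descending
print([name for name, _ in model.named_children()])  # ['up', 'up_descending']

Note that the patch is applied at class level rather than per instance, so the lazy branch fires at most once per process; that is presumably fine when a single Decoder is constructed, and matches the "only do it once" in the commit message.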
