diff --git a/flux/modules/autoencoder.py b/flux/modules/autoencoder.py
index 23214b9..dc9a08c 100644
--- a/flux/modules/autoencoder.py
+++ b/flux/modules/autoencoder.py
@@ -106,16 +106,13 @@ def forward(self, x: Tensor):
         return x
 
 
 class DownBlock(nn.Module):
-    def __init__(self, block, downsample=None):
+    def __init__(self, block: list, downsample: nn.Module) -> None:
         super().__init__()
+        # we do this instead of a flat nn.Sequential to preserve the keys "block" and "downsample"
         self.block = nn.Sequential(*block)
-        # attn
-        if downsample:
-            self.downsample = downsample
-        else:
-            self.downsample = nn.Identity()
+        self.downsample = downsample
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         return self.downsample(self.block(x))
 
@@ -143,8 +140,9 @@ def __init__(
         self.in_ch_mult = in_ch_mult
         down_layers = []
         block_in = self.ch
+        # ideally, this would all append to a single flat nn.Sequential,
+        # but we cannot do that without breaking the existing state dict keys
         for i_level in range(self.num_resolutions):
-            # attn = nn.ModuleList()
             block_in = ch * in_ch_mult[i_level]
             block_out = ch * ch_mult[i_level]
             block_layers = []
@@ -156,7 +154,7 @@ def __init__(
                 downsample = Downsample(block_in)
                 curr_res = curr_res // 2
             else:
-                downsample = None
+                downsample = nn.Identity()
             down_layers.append(DownBlock(block_layers, downsample))
 
         self.down = nn.Sequential(*down_layers)
@@ -186,17 +184,13 @@ def forward(self, x: Tensor) -> Tensor:
         return h
 
 
 class UpBlock(nn.Module):
-    def __init__(self, block, upsample=None):
+    def __init__(self, block: list, upsample: nn.Module) -> None:
         super().__init__()
         self.block = nn.Sequential(*block)
-        if upsample:
-            self.upsample = upsample
-        else:
-            self.upsample = nn.Identity()
+        self.upsample = upsample
 
-    def forward(self, x):
-        x = self.block(x)
-        return self.upsample(x)
+    def forward(self, x: Tensor) -> Tensor:
+        return self.upsample(self.block(x))
 
 class Decoder(nn.Module):
@@ -234,7 +228,7 @@ def __init__(
 
         # upsampling
         up_blocks = []
-        # 3, 2, 1, 0
+        # 3, 2, 1, 0, descending order
         for i_level in reversed(range(self.num_resolutions)):
             level_blocks = []
             block_out = ch * ch_mult[i_level]
@@ -245,9 +239,8 @@ def __init__(
                 upsample = Upsample(block_in)
                 curr_res = curr_res * 2
             else:
-                upsample = None
-                # ??? gross
-            # 0, 1, 2, 3
+                upsample = nn.Identity()
+            # 0, 1, 2, 3, ascending order
             up_blocks.insert(0, UpBlock(level_blocks, upsample))  # prepend to get consistent order
 
         self.up = nn.Sequential(*up_blocks)
@@ -255,6 +248,16 @@ def __init__(
         self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
         self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
 
+    # a hack to get something like a cached property, evaluated lazily and only
+    # once, with nothing conditional in the main flow. object.__setattr__ keeps
+    # the cache in __dict__, so nn.Module.__setattr__ never registers it as a
+    # submodule and "up_descending" stays out of the state dict keys
+    def __getattr__(self, name):
+        if name == "up_descending":
+            object.__setattr__(self, "up_descending", nn.Sequential(*reversed(self.up)))
+            return self.up_descending
+        return super().__getattr__(name)
+
     def forward(self, z: Tensor) -> Tensor:
         # z to block_in
         h = self.conv_in(z)
@@ -265,7 +268,7 @@ def forward(self, z: Tensor) -> Tensor:
         h = self.mid.block_2(h)
 
         # upsampling
-        h = self.up(h)
+        h = self.up_descending(h)
 
         # end
         h = self.norm_out(h)
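
For reference, a minimal standalone sketch of the key-preservation idea behind `DownBlock` (and, symmetrically, `UpBlock`). The stub `nn.Conv2d` layers stand in for the real `ResnetBlock` modules; this is an illustration of the naming behavior, not part of the patch:

```python
from torch import Tensor, nn


class DownBlock(nn.Module):
    # same structure as in the patch: named attributes rather than one flat
    # nn.Sequential, so the child keys stay "block.*" and "downsample.*"
    def __init__(self, block: list, downsample: nn.Module) -> None:
        super().__init__()
        self.block = nn.Sequential(*block)
        self.downsample = downsample

    def forward(self, x: Tensor) -> Tensor:
        return self.downsample(self.block(x))


# stub conv layers stand in for the real ResnetBlock modules
blocks = [nn.Conv2d(4, 4, kernel_size=3, padding=1) for _ in range(2)]
down = DownBlock(blocks, nn.Identity())
print(list(down.state_dict().keys()))
# ['block.0.weight', 'block.0.bias', 'block.1.weight', 'block.1.bias']
```

Because `block` and `downsample` are named attributes instead of positions in one flat `nn.Sequential`, existing checkpoints keyed as `down.N.block.M.*` and `down.N.downsample.*` should still load unchanged, while the wrapper itself is callable as a single sequential step.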
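And a self-contained sketch of the lazy `up_descending` trick, with a toy module in place of `Decoder` (the `Toy` class and the shapes are mine, for illustration). Stashing the cache in `__dict__` via `object.__setattr__` means `nn.Module.__setattr__` never registers it as a submodule, so `state_dict()` keeps only the original `up.*` keys and the reversed view shares the same child modules:

```python
import torch
from torch import Tensor, nn


class Toy(nn.Module):
    # stand-in for Decoder: "up" holds the blocks in ascending key order
    def __init__(self) -> None:
        super().__init__()
        self.up = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))

    def __getattr__(self, name):
        # only reached when normal attribute lookup fails, i.e. before the
        # cache exists; afterwards the __dict__ entry wins and this is skipped
        if name == "up_descending":
            object.__setattr__(self, "up_descending", nn.Sequential(*reversed(self.up)))
            return self.up_descending
        return super().__getattr__(name)

    def forward(self, z: Tensor) -> Tensor:
        return self.up_descending(z)


m = Toy()
keys_before = set(m.state_dict().keys())
m(torch.zeros(1, 2))                                # first forward builds the cache
assert set(m.state_dict().keys()) == keys_before    # no "up_descending.*" keys appear
assert m.up_descending[0] is m.up[1]                # same modules, reversed order
```

`__getattr__` is only invoked when normal lookup fails, so after the first access the cached attribute is found directly in `__dict__` and the hook adds no per-call overhead to the forward pass.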