From 2234c4149cee8674426e905534f66a24a2a4738a Mon Sep 17 00:00:00 2001 From: cabralpinto Date: Sat, 26 Aug 2023 15:32:56 +0100 Subject: [PATCH] Remove EfficientUNet, change params to match docs --- diffusion/net.py | 144 +++-------------------------------------------- 1 file changed, 7 insertions(+), 137 deletions(-) diff --git a/diffusion/net.py b/diffusion/net.py index 17f90bf..4618780 100644 --- a/diffusion/net.py +++ b/diffusion/net.py @@ -1,4 +1,4 @@ -from typing import Sequence, Optional +from typing import Sequence import torch from einops.layers.torch import Rearrange @@ -61,7 +61,7 @@ def __init__( self, channels: Sequence[int], labels: int = 0, - ratio: int = 1, + parameters: int = 1, hidden: int = 256, heads: int = 8, groups: int = 16, @@ -102,8 +102,8 @@ def __init__( ]) for channels_ in zip(channels[:1:-1], channels[-2::-1]) ]) self.output = Sequential( - Conv2d(channels[1], ratio * channels[0], 3, 1, 1), - Rearrange("b (r c) h w -> r b c h w", r=ratio), + Conv2d(channels[1], parameters * channels[0], 3, 1, 1), + Rearrange("b (p c) h w -> p b c h w", p=parameters), ) def forward(self, x: Tensor, y: Tensor, t: Tensor) -> Tensor: @@ -131,136 +131,6 @@ def forward(self, x: Tensor, y: Tensor, t: Tensor) -> Tensor: return x -class EfficientUNet(Net): - """Implementation of Efficient U-Net from https://arxiv.org/abs/2205.11487v1""" - - class Block(Module): - - def __init__(self, channels: int, groups: int) -> None: - super().__init__() - self.conv = Conv2d(channels, channels, 1) - self.blocks = Sequential( - GroupNorm(groups, channels), - Swish(), - Conv2d(channels, channels, 3, 1, 1), - GroupNorm(groups, channels), - Swish(), - Conv2d(channels, channels, 3, 1, 1), - ) - - def forward(self, x: Tensor) -> Tensor: - return self.conv(x) + self.blocks(x) - - class Condition(Module): - - def __init__(self, labels: int, channels: int, hidden: int) -> None: - super().__init__() - self.label = Embedding(labels + 1, channels) - self.time = SinusoidalPositionalEmbedding(channels) - self.linear = Linear(hidden, 2 * channels) - - class DBlock(Module): - - def __init__( - self, - channels: tuple[int, int], - groups: int, - blocks: int, - heads: int, - stride: int = 1, - ) -> None: - super().__init__() - self.conv = Conv2d(channels[0], channels[1], 3, stride, 1) - self.condition = EfficientUNet.Condition(0, channels[1], 256) - self.blocks = ModuleList( - [EfficientUNet.Block(channels[1], groups) for _ in range(blocks)]) - self.attention = (MultiheadAttention(2 * channels, heads, batch_first=True) - if heads > 0 else None) - - def forward(self, x: Tensor, y: Tensor, t: Tensor) -> Tensor: - x = self.conv(x) - x = self.condition(x, y, t) - for block in self.blocks: - x = block(x) - if self.attention is not None: - x, _ = self.attention(x, x, x, need_weights=False) - return x - - class UBlock(Module): - - def __init__( - self, - channels: tuple[int, int], - groups: int, - blocks: int, - heads: int, - stride: int = 1, - ) -> None: - super().__init__() - self.condition = EfficientUNet.Condition(0, channels[1], 256) - self.blocks = ModuleList( - [EfficientUNet.Block(channels[1], groups) for _ in range(blocks)]) - self.attention = MultiheadAttention(2 * channels, heads, batch_first=True) - self.conv = Conv2d(channels[0], channels[1], 3, stride, 1) - - def forward(self, x: Tensor, y: Tensor, t: Tensor) -> Tensor: - x = self.condition(x, y, t) - for block in self.blocks: - x = block(x) - x, _ = self.attention(x, x, x, need_weights=False) - x = self.conv(x) - return x - - def __init__( - self, - channels: Sequence[int], - strides: int | Sequence[int] = 2, - blocks: int | Sequence[int] = 2, - groups: int | Sequence[int] = 16, - heads: int | Sequence[int] = 8, - labels: int = 0, - ratio: int = 1, - hidden: int = 256, - ) -> None: - super().__init__() - if isinstance(strides, int): - strides = [strides] * (len(channels) - 1) - if isinstance(blocks, int): - blocks = [blocks] * (len(channels) - 1) - if isinstance(groups, int): - groups = [groups] * (len(channels) - 1) - if isinstance(heads, int): - heads = [heads] * (len(channels) - 1) - - self.conv = Conv2d(channels[0], channels[1], 3, 1, 1) - self.encoder = ModuleList([ - EfficientUNet.DBlock(channels, groups, blocks, heads, stride) - for channels, stride, blocks, groups, heads - in zip(zip(channels[:-1], channels[1:]), strides, blocks, groups, heads) - ]) # yapf: disable - self.bottleneck = ModuleList([ - EfficientUNet.DBlock(channels[0], groups, 1, heads, (2, 2)), - EfficientUNet.UBlock(channels[0], groups, 1, heads, (2, 2)) - ]) - self.encoder = ModuleList([ - EfficientUNet.DBlock(channels, groups, blocks, heads, stride) - for channels, stride, blocks, groups, heads - in zip(zip(channels[::-1], channels[-2::-1]), strides, blocks, groups, heads) - ]) # yapf: disable - self.linear = Linear(hidden, 2 * channels[1]) - self.rearrange = Rearrange("b (r c) h w -> r b c h w", r=ratio) - - def forward(self, x: Tensor, y: Tensor, t: Tensor) -> Tensor: - x = self.conv(x) - h = [x := block(x, y, t) for block in self.encoder] - for block in self.bottleneck: - x = self.bottleneck(x, y, t) - for block in self.decoder: - x = x + h.pop() - x = block(x, y, t) - return x - - class Transformer(Net): class Block(Module): @@ -287,7 +157,7 @@ def __init__( self, input: int, labels: int = 0, - ratio: int = 1, + parameters: int = 1, depth: int = 6, width: int = 256, heads: int = 8, @@ -299,8 +169,8 @@ def __init__( self.time = SinusoidalPositionalEmbedding(width) self.blocks = Sequential( *[Transformer.Block(width, heads) for _ in range(depth)]) - self.linear2 = Linear(width, input * ratio) - self.rearrange = Rearrange("b l (r e) -> r b l e", r=ratio) + self.linear2 = Linear(width, input * parameters) + self.rearrange = Rearrange("b l (p e) -> p b l e", p=parameters) def forward(self, x: Tensor, y: Tensor, t: Tensor) -> Tensor: x = self.linear1(x) + self.position(torch.arange(x.shape[1], device=x.device))