Commit 2234c41
Remove EfficientUNet, change params to match docs
cabralpinto committed Aug 26, 2023
1 parent c9f2079 commit 2234c41
Showing 1 changed file with 7 additions and 137 deletions.

diffusion/net.py
@@ -1,4 +1,4 @@
-from typing import Sequence, Optional
+from typing import Sequence
 
 import torch
 from einops.layers.torch import Rearrange
@@ -61,7 +61,7 @@ def __init__(
         self,
         channels: Sequence[int],
         labels: int = 0,
-        ratio: int = 1,
+        parameters: int = 1,
         hidden: int = 256,
         heads: int = 8,
         groups: int = 16,
@@ -102,8 +102,8 @@ def __init__(
             ]) for channels_ in zip(channels[:1:-1], channels[-2::-1])
         ])
         self.output = Sequential(
-            Conv2d(channels[1], ratio * channels[0], 3, 1, 1),
-            Rearrange("b (r c) h w -> r b c h w", r=ratio),
+            Conv2d(channels[1], parameters * channels[0], 3, 1, 1),
+            Rearrange("b (p c) h w -> p b c h w", p=parameters),
         )
 
     def forward(self, x: Tensor, y: Tensor, t: Tensor) -> Tensor:
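
Note on the hunk above: the renamed parameters argument (previously ratio) sets how many parameter sets the output head emits per data channel. The final Conv2d produces parameters * channels[0] feature maps, and the Rearrange splits them into a leading axis of size parameters. A minimal standalone sketch of that reshape, assuming only torch and einops; the sizes are hypothetical, not taken from the repository:

import torch
from einops.layers.torch import Rearrange
from torch.nn import Conv2d, Sequential

parameters, channels = 2, 3  # hypothetical: e.g. two predicted quantities per pixel
head = Sequential(
    Conv2d(64, parameters * channels, 3, 1, 1),
    Rearrange("b (p c) h w -> p b c h w", p=parameters),
)
x = torch.randn(8, 64, 32, 32)  # batch of 8 feature maps with 64 channels
print(head(x).shape)            # torch.Size([2, 8, 3, 32, 32])
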
@@ -131,136 +131,6 @@ def forward(self, x: Tensor, y: Tensor, t: Tensor) -> Tensor:
         return x
 
 
-class EfficientUNet(Net):
-    """Implementation of Efficient U-Net from https://arxiv.org/abs/2205.11487v1"""
-
-    class Block(Module):
-
-        def __init__(self, channels: int, groups: int) -> None:
-            super().__init__()
-            self.conv = Conv2d(channels, channels, 1)
-            self.blocks = Sequential(
-                GroupNorm(groups, channels),
-                Swish(),
-                Conv2d(channels, channels, 3, 1, 1),
-                GroupNorm(groups, channels),
-                Swish(),
-                Conv2d(channels, channels, 3, 1, 1),
-            )
-
-        def forward(self, x: Tensor) -> Tensor:
-            return self.conv(x) + self.blocks(x)
-
-    class Condition(Module):
-
-        def __init__(self, labels: int, channels: int, hidden: int) -> None:
-            super().__init__()
-            self.label = Embedding(labels + 1, channels)
-            self.time = SinusoidalPositionalEmbedding(channels)
-            self.linear = Linear(hidden, 2 * channels)
-
-    class DBlock(Module):
-
-        def __init__(
-            self,
-            channels: tuple[int, int],
-            groups: int,
-            blocks: int,
-            heads: int,
-            stride: int = 1,
-        ) -> None:
-            super().__init__()
-            self.conv = Conv2d(channels[0], channels[1], 3, stride, 1)
-            self.condition = EfficientUNet.Condition(0, channels[1], 256)
-            self.blocks = ModuleList(
-                [EfficientUNet.Block(channels[1], groups) for _ in range(blocks)])
-            self.attention = (MultiheadAttention(2 * channels, heads, batch_first=True)
-                              if heads > 0 else None)
-
-        def forward(self, x: Tensor, y: Tensor, t: Tensor) -> Tensor:
-            x = self.conv(x)
-            x = self.condition(x, y, t)
-            for block in self.blocks:
-                x = block(x)
-            if self.attention is not None:
-                x, _ = self.attention(x, x, x, need_weights=False)
-            return x
-
-    class UBlock(Module):
-
-        def __init__(
-            self,
-            channels: tuple[int, int],
-            groups: int,
-            blocks: int,
-            heads: int,
-            stride: int = 1,
-        ) -> None:
-            super().__init__()
-            self.condition = EfficientUNet.Condition(0, channels[1], 256)
-            self.blocks = ModuleList(
-                [EfficientUNet.Block(channels[1], groups) for _ in range(blocks)])
-            self.attention = MultiheadAttention(2 * channels, heads, batch_first=True)
-            self.conv = Conv2d(channels[0], channels[1], 3, stride, 1)
-
-        def forward(self, x: Tensor, y: Tensor, t: Tensor) -> Tensor:
-            x = self.condition(x, y, t)
-            for block in self.blocks:
-                x = block(x)
-            x, _ = self.attention(x, x, x, need_weights=False)
-            x = self.conv(x)
-            return x
-
-    def __init__(
-        self,
-        channels: Sequence[int],
-        strides: int | Sequence[int] = 2,
-        blocks: int | Sequence[int] = 2,
-        groups: int | Sequence[int] = 16,
-        heads: int | Sequence[int] = 8,
-        labels: int = 0,
-        ratio: int = 1,
-        hidden: int = 256,
-    ) -> None:
-        super().__init__()
-        if isinstance(strides, int):
-            strides = [strides] * (len(channels) - 1)
-        if isinstance(blocks, int):
-            blocks = [blocks] * (len(channels) - 1)
-        if isinstance(groups, int):
-            groups = [groups] * (len(channels) - 1)
-        if isinstance(heads, int):
-            heads = [heads] * (len(channels) - 1)
-
-        self.conv = Conv2d(channels[0], channels[1], 3, 1, 1)
-        self.encoder = ModuleList([
-            EfficientUNet.DBlock(channels, groups, blocks, heads, stride)
-            for channels, stride, blocks, groups, heads
-            in zip(zip(channels[:-1], channels[1:]), strides, blocks, groups, heads)
-        ]) # yapf: disable
-        self.bottleneck = ModuleList([
-            EfficientUNet.DBlock(channels[0], groups, 1, heads, (2, 2)),
-            EfficientUNet.UBlock(channels[0], groups, 1, heads, (2, 2))
-        ])
-        self.encoder = ModuleList([
-            EfficientUNet.DBlock(channels, groups, blocks, heads, stride)
-            for channels, stride, blocks, groups, heads
-            in zip(zip(channels[::-1], channels[-2::-1]), strides, blocks, groups, heads)
-        ]) # yapf: disable
-        self.linear = Linear(hidden, 2 * channels[1])
-        self.rearrange = Rearrange("b (r c) h w -> r b c h w", r=ratio)
-
-    def forward(self, x: Tensor, y: Tensor, t: Tensor) -> Tensor:
-        x = self.conv(x)
-        h = [x := block(x, y, t) for block in self.encoder]
-        for block in self.bottleneck:
-            x = self.bottleneck(x, y, t)
-        for block in self.decoder:
-            x = x + h.pop()
-            x = block(x, y, t)
-        return x
-
-
 class Transformer(Net):
 
     class Block(Module):
@@ -287,7 +157,7 @@ def __init__(
         self,
         input: int,
         labels: int = 0,
-        ratio: int = 1,
+        parameters: int = 1,
         depth: int = 6,
         width: int = 256,
         heads: int = 8,
@@ -299,8 +169,8 @@ def __init__(
         self.time = SinusoidalPositionalEmbedding(width)
         self.blocks = Sequential(
             *[Transformer.Block(width, heads) for _ in range(depth)])
-        self.linear2 = Linear(width, input * ratio)
-        self.rearrange = Rearrange("b l (r e) -> r b l e", r=ratio)
+        self.linear2 = Linear(width, input * parameters)
+        self.rearrange = Rearrange("b l (p e) -> p b l e", p=parameters)
 
     def forward(self, x: Tensor, y: Tensor, t: Tensor) -> Tensor:
         x = self.linear1(x) + self.position(torch.arange(x.shape[1], device=x.device))
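
After this commit, callers pass parameters instead of ratio to both remaining networks, matching the documentation. A hedged usage sketch: UNet is an assumed name for the first changed class (only Transformer's name is visible in this diff), and the diffusion.net import path is inferred from the changed file.

from diffusion.net import Transformer, UNet  # assumed exports, inferred from diffusion/net.py

# parameters=2 makes each net stack two predicted tensors along dim 0,
# e.g. for objectives that consume two quantities per input element.
unet = UNet(channels=[3, 64, 128], parameters=2)
transformer = Transformer(input=16, parameters=2)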
