Skip to content

Commit

Permalink
🧗
Browse files Browse the repository at this point in the history
  • Loading branch information
johndpope committed Aug 12, 2024
1 parent af2a39f commit 26111c4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 40 deletions.
2 changes: 1 addition & 1 deletion config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ training:
lambda_gp: 10 # Gradient penalty coefficient
lambda_mse: 1.0
n_critic: 2 # Number of discriminator updates per generator update
clip_grad_norm: 0.5 # Maximum norm for gradient clipping
clip_grad_norm: 0.75 # Maximum norm for gradient clipping
r1_gamma: 10
r1_interval: 16
label_smoothing: 0.1
Expand Down
47 changes: 8 additions & 39 deletions resblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,6 @@ def debug_print(*args, **kwargs):
if DEBUG:
print(*args, **kwargs)

# class UpConvResBlock(nn.Module):
# def __init__(self, in_channels, out_channels):
# super().__init__()
# self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
# self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# self.bn1 = nn.BatchNorm2d(out_channels)
# self.relu = nn.ReLU(inplace=True)
# self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
# self.feat_res_block1 = FeatResBlock(out_channels)
# self.feat_res_block2 = FeatResBlock(out_channels)

# def forward(self, x):
# x = self.upsample(x)
# x = self.conv1(x)
# x = self.bn1(x)
# x = self.relu(x)
# x = self.conv2(x)
# x = self.feat_res_block1(x)
# x = self.feat_res_block2(x)
# return x


# class DownConvResBlock(nn.Module):
Expand Down Expand Up @@ -131,36 +111,25 @@ def debug_print(*args, **kwargs):


class UpConvResBlock(nn.Module):
def __init__(self, in_channels, out_channels, dropout_rate=0.1):
def __init__(self, in_channels, out_channels):
super().__init__()
self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
self.dropout = nn.Dropout2d(dropout_rate)

self.feat_res_block1 = FeatResBlock(out_channels)
self.feat_res_block2 = FeatResBlock(out_channels)

self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

def forward(self, x):
x = self.upsample(x)
residual = self.residual_conv(x)

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = out + residual
out = self.relu(out)
out = self.dropout(out)
out = self.feat_res_block1(out)
out = self.feat_res_block2(out)
return out
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.feat_res_block1(x)
x = self.feat_res_block2(x)
return x



Expand Down

0 comments on commit 26111c4

Please sign in to comment.