
Is the U-Net already trained? I don't see anywhere in the code where the U-Net is trained #31

Open
dream-in-night opened this issue Nov 24, 2023 · 1 comment

@dream-in-night

Why does the code only contain the noise-prediction loss for training? How does the U-Net actually get trained?

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def extract(v, t, x_shape):
    # Helper (defined elsewhere in the repo): gather the coefficient at each
    # timestep t and reshape it to [B, 1, 1, 1] so it broadcasts over images.
    out = torch.gather(v, index=t, dim=0).float()
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))


class GaussianDiffusionTrainer(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()
        # beta_1 : 1e-4
        # beta_T : 0.02
        # T      : 1000
        self.model = model
        self.T = T

        self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double())
        # self.betas : [1000], linearly spaced from 1e-4 to 0.02, double type
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        # alphas_bar : [1000], cumulative product of alphas;
        # decays monotonically from ~0.9999 towards ~0 (not linear)

        # calculations for diffusion q(x_t | x_0) and others
        self.register_buffer(
            'sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer(
            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))

    def forward(self, x_0):
        """
        Algorithm 1 (training) from the DDPM paper.
        """
        # x_0 : [80, 3, 32, 32]
        t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
        # t : [80], one random timestep per image
        noise = torch.randn_like(x_0)
        # closed-form forward process:
        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        x_t = (
            extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
            extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
        x = self.model(x_t, t)
        # x : the U-Net's predicted noise eps_theta(x_t, t), [80, 3, 32, 32]
        loss = F.mse_loss(x, noise, reduction='none')
        return loss
```
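For reference, the `forward()` above implements the closed-form forward process x_t = sqrt(ᾱ_t)·x_0 + sqrt(1 − ᾱ_t)·ε. Here is a minimal standalone sketch of the same computation with plain tensors (the batch size, shapes, and variable names are illustrative, not taken from the repository):

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T).double()
alphas_bar = torch.cumprod(1. - betas, dim=0)        # \bar{alpha}_t

x_0 = torch.randn(4, 3, 32, 32)                      # fake batch of images
t = torch.randint(T, size=(4,))                      # one random timestep per image
noise = torch.randn_like(x_0)                        # epsilon ~ N(0, I)

# broadcastable per-sample coefficients (this is what extract() does)
coef1 = torch.sqrt(alphas_bar[t]).float().view(-1, 1, 1, 1)
coef2 = torch.sqrt(1. - alphas_bar[t]).float().view(-1, 1, 1, 1)

x_t = coef1 * x_0 + coef2 * noise                    # noisy image fed to the U-Net
# the U-Net is then trained to predict `noise` from (x_t, t) with an MSE loss
```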


mileret commented Nov 29, 2023

In train.py,

```python
optimizer.zero_grad()
x_0 = images.to(device)
loss = trainer(x_0).sum() / 1000.
loss.backward()
```

is exactly where the U-Net is being trained.
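For context, here is a hedged end-to-end sketch of such a training loop (TinyUNet, the Adam learning rate, and the fake batch are placeholders of mine, not the repository's actual train.py): the optimizer is constructed over the U-Net's parameters, so `loss.backward()` followed by `optimizer.step()` is what updates the U-Net.

```python
import torch
import torch.nn as nn
import torch.optim as optim

class TinyUNet(nn.Module):
    # Placeholder noise predictor with the same (x_t, t) call signature as
    # the real U-Net; it only exists to make the sketch self-contained.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3, padding=1)

    def forward(self, x_t, t):
        return self.conv(x_t)

device = "cuda" if torch.cuda.is_available() else "cpu"
net_model = TinyUNet().to(device)
trainer = GaussianDiffusionTrainer(net_model, beta_1=1e-4, beta_T=0.02, T=1000).to(device)
optimizer = optim.Adam(net_model.parameters(), lr=2e-4)   # optimizer over the U-Net's weights

images = torch.randn(8, 3, 32, 32)           # fake batch standing in for a dataloader
optimizer.zero_grad()
x_0 = images.to(device)
loss = trainer(x_0).sum() / 1000.
loss.backward()                              # gradients flow back into net_model
optimizer.step()                             # this step is what trains the U-Net
```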
