Why does the code only contain a training loss on the noise? Where does the U-Net come in?
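```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def extract(v, t, x_shape):
    # Helper assumed from the same project (it is not shown in this issue).
    # Picks the coefficient v[t] for each sample in the batch and reshapes it
    # to [batch_size, 1, 1, 1] so it broadcasts over the image dimensions.
    out = torch.gather(v, index=t, dim=0).float()
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))


class GaussianDiffusionTrainer(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()
        # beta_1 : 1e-4
        # beta_T : 0.02
        # T : 1000
        self.model = model  # the noise-prediction network (the U-Net)
        self.T = T

        # self.betas : [1000], linearly spaced from 1e-4 to 0.02, double type
        self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        # alphas_bar : [1000], cumulative product of alphas, double type;
        # decreases (not linearly) from 0.9999 to about 4e-5
        alphas_bar = torch.cumprod(alphas, dim=0)

        # coefficients for the closed-form diffusion q(x_t | x_0)
        self.register_buffer(
            'sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer(
            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))

    def forward(self, x_0):
        """
        Algorithm 1.
        """
        # x_0 : [80, 3, 32, 32]
        t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
        # t : [80], one random timestep per image
        noise = torch.randn_like(x_0)
        # noise x_0 directly to step t using the per-sample coefficients
        x_t = (
            extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
            extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
        # x : [80, 3, 32, 32], the U-Net's prediction of the added noise
        x = self.model(x_t, t)
        loss = F.mse_loss(x, noise, reduction='none')
        return loss
```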
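Written out, `forward` computes the DDPM simplified training objective (the docstring's "Algorithm 1"). Here $\epsilon_\theta$ stands for `self.model`, i.e. the U-Net with weights $\theta$, and $\bar\alpha_t$ for `alphas_bar`; this is a sketch of the standard formulation, with `reduction='none'` returning the per-element squared errors inside the norm:

```math
\bar\alpha_t = \prod_{s=1}^{t} (1 - \beta_s), \qquad
x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon, \qquad
\epsilon \sim \mathcal{N}(0, I)
```

```math
L(\theta) = \mathbb{E}_{t,\, x_0,\, \epsilon}\left[ \left\lVert \epsilon - \epsilon_\theta(x_t, t) \right\rVert^2 \right]
```

The schedule terms are fixed buffers, so $\theta$ is the only thing this loss can change: minimizing the "noise loss" is the U-Net's training.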
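In train.py, this part is exactly what trains the U-Net:

```python
optimizer.zero_grad()
x_0 = images.to(device)
loss = trainer(x_0).sum() / 1000.
loss.backward()
```

`trainer` wraps the U-Net as `self.model`, and the noise loss is computed from the U-Net's output, so `loss.backward()` produces gradients for the U-Net's weights, which the optimizer then updates. There is no separate U-Net loss: learning to predict the added noise is how the U-Net is trained.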
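A minimal, self-contained sketch of why those lines train the U-Net. It assumes a tiny hypothetical `TinyEpsModel` in place of the real U-Net, random tensors in place of `images`, and Adam with `lr=1e-3` as the optimizer (all illustrative choices, not the project's settings). One backward pass and optimizer step on the noise loss is enough to change the network's weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def extract(v, t, x_shape):
    # Per-sample coefficient v[t], reshaped to [B, 1, 1, 1] for broadcasting.
    out = torch.gather(v, index=t, dim=0).float()
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))


class GaussianDiffusionTrainer(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()
        self.model = model  # registered as a submodule
        self.T = T
        self.register_buffer("betas", torch.linspace(beta_1, beta_T, T).double())
        alphas_bar = torch.cumprod(1.0 - self.betas, dim=0)
        self.register_buffer("sqrt_alphas_bar", torch.sqrt(alphas_bar))
        self.register_buffer("sqrt_one_minus_alphas_bar", torch.sqrt(1.0 - alphas_bar))

    def forward(self, x_0):
        t = torch.randint(self.T, size=(x_0.shape[0],), device=x_0.device)
        noise = torch.randn_like(x_0)
        x_t = (extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0
               + extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
        return F.mse_loss(self.model(x_t, t), noise, reduction="none")


class TinyEpsModel(nn.Module):
    """Hypothetical stand-in for the U-Net: any module with forward(x_t, t) -> noise."""

    def __init__(self, channels=3):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x_t, t):
        # A real U-Net also embeds the timestep t; this toy ignores it.
        return self.conv(x_t)


torch.manual_seed(0)
unet = TinyEpsModel()
trainer = GaussianDiffusionTrainer(unet, beta_1=1e-4, beta_T=0.02, T=1000)
optimizer = torch.optim.Adam(unet.parameters(), lr=1e-3)  # optimizes the U-Net's weights

before = unet.conv.weight.detach().clone()
images = torch.randn(8, 3, 32, 32)  # placeholder batch instead of real data

optimizer.zero_grad()
loss = trainer(images).sum() / 1000.0
loss.backward()   # gradients flow from the noise loss into the U-Net
optimizer.step()  # the U-Net's weights are updated
after = unet.conv.weight.detach()

print("weights changed:", not torch.equal(before, after))
```

The only trainable parameters in this setup belong to the U-Net; the schedule tensors are buffers, so the optimizer never touches them.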