Hi, I really love your work! I am wondering whether the following is a bug related to the EMAs and EDM2 in '''ect/training/ct_training_loop.py'''.
In the imgnet branch, ECT builds on top of the self-normalizing EDM2 architecture. In particular, I am referring to the following code in '''ect/training/networks_edm2.py''':
@persistence.persistent_class
class MPConv(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel):
        super().__init__()
        self.out_channels = out_channels
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))
        self.force_wn = True

    def forward(self, x, gain=1):
        w = self.weight.to(torch.float32)
        if self.training and self.force_wn:
            with torch.no_grad():
                self.weight.copy_(normalize(w)) # forced weight normalization
        w = normalize(w) # traditional weight normalization
        w = w * (gain / np.sqrt(w[0].numel())) # magnitude-preserving scaling
        w = w.to(x.dtype)
        if w.ndim == 2:
            return x @ w.t()
        assert w.ndim == 4
        return torch.nn.functional.conv2d(x, w, padding=w.shape[-1]//2)
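To make the failure mode concrete, here is a minimal, self-contained sketch (ForcedWNLinear is illustrative, not the repo's code; it mimics the forced-weight-normalization branch of MPConv above). A forward pass in train() mode rewrites the parameter in place, while eval() leaves it untouched:

```python
import torch

# Illustrative layer (assumption: simplified stand-in for MPConv): like MPConv,
# it overwrites its own parameter with a normalized copy on every *training-mode*
# forward pass, so calling it through an EMA copy left in train() mode mutates
# the EMA weights as a side effect.
class ForcedWNLinear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x):
        w = self.weight.to(torch.float32)
        if self.training:
            with torch.no_grad():
                # forced weight normalization: in-place update of the parameter
                self.weight.copy_(torch.nn.functional.normalize(w, dim=1))
        w = torch.nn.functional.normalize(w, dim=1)
        return x @ w.t()

layer = ForcedWNLinear(4, 4)
before = layer.weight.detach().clone()
layer.train()                      # default mode of a freshly copied net
layer(torch.randn(2, 4))
changed = not torch.equal(before, layer.weight.detach())

layer2 = ForcedWNLinear(4, 4)
before2 = layer2.weight.detach().clone()
layer2.eval()                      # proposed fix: keep EMA copies in eval mode
layer2(torch.randn(2, 4))
unchanged = torch.equal(before2, layer2.weight.detach())

print(changed, unchanged)          # weights mutate only in train mode
```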
Therefore, I think we should make sure that all EMAs in '''ect/training/ct_training_loop.py''' are in eval() mode. Otherwise, the EMA weights will be accidentally updated whenever they are used to sample images. For instance, the following section causes the EMA weights to self-correct/self-normalize:
if dist.get_rank() == 0:
    dist.print0('Exporting sample images...')
    grid_size, images, labels = setup_snapshot_image_grid(training_set=dataset_obj)
    save_image_grid(images, os.path.join(run_dir, 'data.png'), drange=[0,255], grid_size=grid_size)
    grid_z = torch.randn([labels.shape[0], net.img_channels, net.img_resolution, net.img_resolution], device=device)
    grid_z = grid_z.split(batch_gpu)
    grid_c = torch.from_numpy(labels).to(device)
    grid_c = grid_c.split(batch_gpu)
    ema_list = ema.get()
    for ema_net, ema_suffix in ema_list:
        images = [generator_fn(ema_net, z, c).cpu() for z, c in zip(grid_z, grid_c)]
        images = torch.cat(images).numpy()
        save_image_grid(images, os.path.join(run_dir, f'model_init{ema_suffix}.png'), drange=[-1,1], grid_size=grid_size)
        del images
and in the training loop:
# Sample Img
if (sample_ticks is not None) and (done or cur_tick % sample_ticks == 0) and dist.get_rank() == 0:
    dist.print0('Exporting sample images...')
    ema_list = ema.get()
    for ema_net, ema_suffix in ema_list:
        images = [generator_fn(ema_net, z, c).cpu() for z, c in zip(grid_z, grid_c)]
        images = torch.cat(images).numpy()
        save_image_grid(images, os.path.join(run_dir, f'1_step_{cur_tick:06d}_{cur_nimg//1000:07d}{ema_suffix}.png'), drange=[-1,1], grid_size=grid_size)

        few_step_fn = functools.partial(generator_fn, mid_t=mid_t)
        images = [few_step_fn(ema_net, z, c).cpu() for z, c in zip(grid_z, grid_c)]
        images = torch.cat(images).numpy()
        save_image_grid(images, os.path.join(run_dir, f'2_step_{cur_tick:06d}_{cur_nimg//1000:07d}{ema_suffix}.png'), drange=[-1,1], grid_size=grid_size)
        del images
I discovered this potential issue while fine-tuning a pre-trained ECM: the model produced different "model_init{ema_suffix}.png" images at the very beginning. I believe they shouldn't look different before any training has taken place. But please feel free to correct me if I am wrong. Maybe I missed something crucial in the original EDM2 paper? Or perhaps this is trivial?
This is not an issue for evaluation: at FID evaluation time, the network is deep-copied and set to eval mode. It seems to me that only these sampling sections in the training code can lead to undefined/unwanted behaviour, since they change the EMAs unintentionally.
The following is my proposed solution. Add self.ema.eval() at the end of initialization in '''ect/training/phema.py''', for both the traditional and power-function EMAs.
#----------------------------------------------------------------------------
# Class for tracking traditional EMA during training.

class TraditionalEMA:
    @torch.no_grad()
    def __init__(self, net, ema_beta=0.9999, halflife_Mimg=None, rampup_ratio=None):
        self.net = net
        self.ema_beta = ema_beta
        self.halflife_Mimg = halflife_Mimg
        self.rampup_ratio = rampup_ratio
        self.ema = copy.deepcopy(net)
        # NOTE(CH) The EDM2 UNet normalizes its weights on every forward pass
        # in training mode! Set to eval() mode for consistent behaviour.
        self.ema.eval()

#----------------------------------------------------------------------------
# Class for tracking power function EMA during training.

class PowerFunctionEMA:
    @torch.no_grad()
    def __init__(self, net, stds=[0.010, 0.050, 0.100]):
        self.net = net
        self.stds = stds
        self.emas = [copy.deepcopy(net) for _std in stds]
        # NOTE(CH) The EDM2 UNet normalizes its weights on every forward pass
        # in training mode! Set to eval() mode for consistent behaviour.
        for ema in self.emas:
            ema.eval()
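As a quick sanity check of why the explicit eval() call is needed (a hypothetical snippet, not part of the ECT code): copy.deepcopy preserves a module's training flag, so an EMA copy of a net in train() mode stays in train() mode until eval() is called.

```python
import copy
import torch

# deepcopy copies the module's state, including its training flag, so the
# EMA copy inherits train() mode from the source network.
net = torch.nn.Linear(3, 3)
net.train()

ema = copy.deepcopy(net)
was_training = ema.training   # True: deepcopy alone does not leave train mode

ema.eval()                    # the proposed fix
print(was_training, ema.training)
```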
Again, great work on ECT! I am a huge fan of your work! Looking forward to further discussion.
Thanks for catching the potential issue! I believe you're right: my CIFAR-10 codebase turns on eval mode, but the ImgNet codebase doesn't.
Could you please create a pull request to fix it? Or I can fix it later, too.