
(Confirmation required) Possible bug in imgnet branch: EMAs not in eval() mode during training #21

swimmincatt35 commented Feb 19, 2025

Hi, I really love your work! I am wondering whether the following is an error involving the EMAs in `ect/training/ct_training_loop.py` and the EDM2 architecture.

In the imgnet branch, ECT builds on top of the self-normalizing EDM2 architecture. In particular, I am referring to the following code in `ect/training/networks_edm2.py`:

@persistence.persistent_class
class MPConv(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel):
        super().__init__()
        self.out_channels = out_channels
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))
        self.force_wn = True

    def forward(self, x, gain=1):
        w = self.weight.to(torch.float32)
        if self.training and self.force_wn:
            with torch.no_grad():
                self.weight.copy_(normalize(w)) # forced weight normalization
        w = normalize(w) # traditional weight normalization
        w = w * (gain / np.sqrt(w[0].numel())) # magnitude-preserving scaling
        w = w.to(x.dtype)
        if w.ndim == 2:
            return x @ w.t()
        assert w.ndim == 4
        return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1]//2,))
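To make the side effect concrete: whenever `self.training` is True, the forward pass itself rewrites `self.weight` in place via `copy_`. A minimal check (a hypothetical snippet, assuming the ect repo root is on `PYTHONPATH` so the `MPConv` class above is importable): a forward pass in train mode changes the parameter, while the same pass in eval mode leaves it untouched.

import torch
from training.networks_edm2 import MPConv  # the class shown above

torch.manual_seed(0)
layer = MPConv(in_channels=3, out_channels=8, kernel=[3, 3])
x = torch.randn(1, 3, 16, 16)

# Train mode: forward() copies normalize(w) back into self.weight.
layer.train()
w_before = layer.weight.detach().clone()
layer(x)
print(torch.equal(layer.weight, w_before))  # False: the parameter was rewritten

# Eval mode: the in-place copy is skipped and the weights survive.
layer.eval()
w_before = layer.weight.detach().clone()
layer(x)
print(torch.equal(layer.weight, w_before))  # True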

Therefore, I think we should make sure that all EMAs in `ect/training/ct_training_loop.py` are in eval() mode. Otherwise, the EMA weights will be accidentally updated whenever they are used to sample images. For instance, the following section would cause the EMA weights to self-correct/self-normalize:

if dist.get_rank() == 0:
    dist.print0('Exporting sample images...')
    grid_size, images, labels = setup_snapshot_image_grid(training_set=dataset_obj)
    save_image_grid(images, os.path.join(run_dir, 'data.png'), drange=[0,255], grid_size=grid_size)

    grid_z = torch.randn([labels.shape[0], net.img_channels, net.img_resolution, net.img_resolution], device=device)
    grid_z = grid_z.split(batch_gpu)

    grid_c = torch.from_numpy(labels).to(device)
    grid_c = grid_c.split(batch_gpu)

    ema_list = ema.get()
    for ema_net, ema_suffix in ema_list:
        images = [generator_fn(ema_net, z, c).cpu() for z, c in zip(grid_z, grid_c)]
        images = torch.cat(images).numpy()
        save_image_grid(images, os.path.join(run_dir, f'model_init{ema_suffix}.png'), drange=[-1,1], grid_size=grid_size)
    del images

and in the training loop:

# Sample Img
if (sample_ticks is not None) and (done or cur_tick % sample_ticks == 0) and dist.get_rank() == 0:
    dist.print0('Exporting sample images...')
    ema_list = ema.get()
    for ema_net, ema_suffix in ema_list:
        images = [generator_fn(ema_net, z, c).cpu() for z, c in zip(grid_z, grid_c)]
        images = torch.cat(images).numpy()
        save_image_grid(images, os.path.join(run_dir, f'1_step_{cur_tick:06d}_{cur_nimg//1000:07d}{ema_suffix}.png'), drange=[-1,1], grid_size=grid_size)

        few_step_fn = functools.partial(generator_fn, mid_t=mid_t)
        images = [few_step_fn(ema_net, z, c).cpu() for z, c in zip(grid_z, grid_c)]
        images = torch.cat(images).numpy()
        save_image_grid(images, os.path.join(run_dir, f'2_step_{cur_tick:06d}_{cur_nimg//1000:07d}{ema_suffix}.png'), drange=[-1,1], grid_size=grid_size)
    del images
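For completeness, an alternative stopgap would be to flip each EMA net to eval mode inside the sampling blocks themselves. A hypothetical variant of the one-step loop above (assuming, as in EDM2, that the EMA update only touches parameters under no_grad, so the copies never need to return to train mode):

ema_list = ema.get()
for ema_net, ema_suffix in ema_list:
    ema_net.eval()  # prevent MPConv's in-place weight renormalization
    images = [generator_fn(ema_net, z, c).cpu() for z, c in zip(grid_z, grid_c)]
    images = torch.cat(images).numpy()
    save_image_grid(images, os.path.join(run_dir, f'1_step_{cur_tick:06d}_{cur_nimg//1000:07d}{ema_suffix}.png'), drange=[-1,1], grid_size=grid_size)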

I discovered this potential issue while fine-tuning a pre-trained ECM: the model produced different `model_init{ema_suffix}.png` images at the very start, and I believe they shouldn't look different before any training has happened. Please feel free to correct me if I am wrong; maybe I missed something crucial in the original EDM2 paper, or perhaps this is trivial?

This is not an issue for evaluation: at FID evaluation time, the network is deep-copied and set to eval mode. It seems to me that only these sampling sections in the training code can lead to unwanted behaviour, as they change the EMAs unintentionally.

The following is my proposed solution: add an eval() call at the end of initialization in `ect/training/phema.py`, for both the traditional and the power-function EMAs.

#----------------------------------------------------------------------------
# Class for tracking traditional EMA during training.

class TraditionalEMA:
    @torch.no_grad()
    def __init__(self, net, ema_beta=0.9999, halflife_Mimg=None, rampup_ratio=None):
        self.net = net
        self.ema_beta = ema_beta
        self.halflife_Mimg = halflife_Mimg
        self.rampup_ratio = rampup_ratio
        self.ema = copy.deepcopy(net)
        # NOTE(CH) EDM2 unet normalizes for every forward pass at training! 
        # Set to eval() mode for consistent behaviour
        self.ema.eval()

#----------------------------------------------------------------------------
# Class for tracking power function EMA during the training.

class PowerFunctionEMA:
    @torch.no_grad()
    def __init__(self, net, stds=[0.010, 0.050, 0.100]):
        self.net = net
        self.stds = stds
        self.emas = [copy.deepcopy(net) for _std in stds]
        # NOTE(CH) EDM2 unet normalizes for every forward pass at training! 
        # Set to eval() mode for consistent behaviour
        for ema in self.emas:
            ema.eval()
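A side note on why the explicit eval() call is needed at all: `copy.deepcopy(net)` also duplicates the module's `training` flag, so EMA copies made while `net` is in train mode inherit train mode. A quick self-contained sanity check of that behaviour (using a stand-in network, not the real one):

import copy
import torch

net = torch.nn.Sequential(torch.nn.Linear(4, 4))  # stand-in for the real net
net.train()

ema_copy = copy.deepcopy(net)
print(ema_copy.training)  # True: deepcopy preserves train mode

ema_copy.eval()
print(all(not m.training for m in ema_copy.modules()))  # True: eval() recurses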

Again, great work on ECT! I am a huge fan of your work! Looking forward to further discussion.

@Gsunshine
Member

Hi @swimmincatt35 ,

Thanks for catching the potential issue! I believe you're right: my CIFAR-10 codebase turns on eval mode, but the ImgNet codebase does not.

Could you please create a pull request to fix it? Or I can fix it later, too.

Thanks,
Zhengyang
