
Can this code be used for image super-resolution or restoration? #27

Open
wendeyy opened this issue Sep 22, 2024 · 6 comments

Comments

wendeyy commented Sep 22, 2024

Hi, I would like to apply this model to image super-resolution or restoration. Specifically, I want to try enhancing images that are blurred due to adverse weather conditions. I think this should be feasible — what adjustments should I make in the code? Thank you so much!

explainingai-code (Owner) commented
Hello @wendeyy ,
I think you can use the mask-conditioned generation code to perform super-resolution without requiring too many changes.
So say you want to train a model which, given a 32x32 image, generates an 8x-upscaled (256x256) image.

Here are the steps that I believe should enable you to have something work out of the box (with the CelebA-HQ dataset as an example).

  1. Start with the config created for CelebA-HQ: https://github.com/explainingai-code/StableDiffusion-PyTorch/blob/main/config/celebhq_text_image_cond.yaml
  2. Remove the text conditioning parameters (and keep the mask conditioning ones) from the config: https://github.com/explainingai-code/StableDiffusion-PyTorch/blob/main/config/celebhq_text_image_cond.yaml#L24
  3. Modify the input channels to be 3 in the mask config: https://github.com/explainingai-code/StableDiffusion-PyTorch/blob/main/config/celebhq_text_image_cond.yaml#L32
  4. In the celeb_dataset file, ensure the actual (256x256) images are added here: https://github.com/explainingai-code/StableDiffusion-PyTorch/blob/main/dataset/celeb_dataset.py#L84
  5. In the get_mask method, simply return a resized 32x32 version of the image here: https://github.com/explainingai-code/StableDiffusion-PyTorch/blob/main/dataset/celeb_dataset.py#L102
  6. Train the LDM with this config using the train_ddpm_cond script.
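Step 5 above could be sketched as a small helper (a hypothetical sketch, not code from the repo; the function name and the choice of bicubic resampling are my assumptions):

```python
from PIL import Image

def to_lowres_condition(im, size=(32, 32)):
    # For super-resolution training, the "mask" condition is not a
    # segmentation mask but a low-resolution copy of the full image.
    return im.resize(size, Image.BICUBIC)
```

During training, the model then learns to generate the high-resolution image conditioned on this 32x32 input.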

Could you please try these and let me know if you run into any issues with them.

In case of any confusion: I also talk about mask conditioning, super-resolution, and the exact inputs the repo uses for mask conditioning in the conditional LDM video mentioned in the README. Those parts may help you get a better understanding of the repo's code:
Mask conditioning @ 18:47
Super-resolution @ 25:50

wendeyy commented Sep 24, 2024

@explainingai-code Thank you for your thorough explanation. I realize I may not have described my task clearly. My goal is to deblur the image, or enhance it to make it clearer than the original. That may be a little different from mask-conditioned generation and super-resolution, but I will try the mask-conditioned generation code first to see the results. If you have any suggestions or further ideas, I would greatly appreciate them!

wendeyy commented Sep 26, 2024

Hi @explainingai-code, I've made changes to the code following your instructions and tried to run it. Here’s what I did.

  1. Removed the text conditioning, kept the mask conditioning parameters, and changed the input channels to 3:

     condition_config:
       condition_types: ['image']
       image_condition_config:
         image_condition_input_channels: 3
         image_condition_output_channels: 3
         image_condition_h: 512
         image_condition_w: 512
         cond_drop_prob: 0.1

  2. My data file structure is like the following:

     StableDiffusion-PyTorch
     -> data
        -> CelebAMask-HQ
           -> CelebA-HQ-img
           -> CelebA-HQ-img copy
           -> CelebAMask-HQ-mask-anno
              -> 0/1/2/3.../14
                 -> *.png
           -> CelebAMask-HQ-mask
              -> *.png

     I loaded some CelebA-HQ images and resized them to 32x32, saving them in CelebA-HQ-img, and also resized them to 256x256, saving them in CelebAMask-HQ-mask, to make sure this code:

     if 'image' in self.condition_types:
         im_name = int(os.path.split(fname)[1].split('.')[0])
         masks.append(os.path.join(im_path, 'CelebAMask-HQ-mask', '{}.png'.format(im_name)))

     reads the actual 256x256 images.

  3. In the get_mask method, simply return a resized 32x32 image:

     def get_mask(self, index):
         mask_im = Image.open(self.masks[index])
         mask_im = mask_im.resize((32, 32))
         return mask_im

I'm not sure if these changes are correct, and an error came out — it seems I have to train the VAE first? I would very much appreciate any suggestions you have.
[error screenshot]

explainingai-code (Owner) commented
Yes, since this is a latent diffusion model, we would need to train a VAE (but a VAE on CelebA-HQ should not require more than 4-5 epochs to get a decent result).
I have a trained VAE checkpoint, but that is for 128x128 images with a downscale factor of 4 (the latent size will be 32x32).
If it helps, I can share that checkpoint as well as the config changes you would need to work with it.
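As a quick sanity check on those numbers, the latent size is just the image size divided by the VAE downscale factor — a minimal sketch:

```python
def latent_size(image_size, downscale_factor):
    # Spatial size of the latent produced by the VAE encoder:
    # e.g. 128x128 images with a downscale factor of 4 give 32x32 latents.
    return image_size // downscale_factor
```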

Regarding the conditioning changes you have mentioned, they seem fine.
But if you are looking specifically for restoration, wouldn't it be better, rather than resizing 256x256 directly to 32x32, to first resize it to 16x16 and then back up to 32x32? That should blur the condition image. You would then be training the model to denoise a latent image conditioned on a blurry image (32x32) in pixel space.
That way, during inference, when you pass in your blurry image with a random noise sample, the model should be able to generate a latent which, when passed to the decoder of the VAE, gives us a deblurred version of the blurry condition image.

As a disclaimer, I haven't ever trained a deblurring model or read papers on that topic, so this is just something that I think should intuitively work (but I'm not sure).
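The down-then-up resize suggested above could be sketched like this (a hypothetical helper assuming PIL; the function name and resampling filter are my choices, not repo code):

```python
from PIL import Image

def make_blurry_condition(im, low=(16, 16), cond=(32, 32)):
    # Downsample to 16x16, then upsample back to 32x32, so the
    # conditioning image is a blurred version of the original.
    return im.resize(low, Image.BICUBIC).resize(cond, Image.BICUBIC)
```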

wendeyy commented Oct 8, 2024

I ran train_vqvae.py for 5 epochs; I don't know if that is OK. Then I ran train_ddpm_cond.py, but an error occurred stating that it couldn't find the images in CelebAMask-HQ-mask. I have made sure the images are in CelebA-HQ-img and CelebAMask-HQ-mask.
[error screenshot]

explainingai-code (Owner) commented
For the autoencoder, there should be a folder created, vqvae_autoencoder_samples (inside /mnt/StableDiffusion-PyTorch-main/celebhq), that holds reconstructions of images generated during training; just check the last image in that folder to see the quality of reconstruction.

Regarding the error: as long as the path from which the code is trying to load the image is valid, this error should not occur.
Can you double-check that the mask folder indeed has '.png' files? I am thinking that since you would have simply copied CelebA-HQ images (which are '.jpg') into the mask folder, the masks would also be '.jpg', whereas the script assumes masks are '.png' files. If that is indeed the issue, then just changing the get_mask method to use the '.jpg' extension instead of '.png' should work.
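One way to make the mask lookup robust to either extension is a small fallback (a hedged sketch; `resolve_mask_path` is a hypothetical helper, not a function from the repo):

```python
import os

def resolve_mask_path(mask_dir, im_name):
    # Try '.png' first (what the script assumes), then fall back to
    # '.jpg' in case the mask folder holds copied CelebA-HQ JPEGs.
    for ext in ('png', 'jpg'):
        candidate = os.path.join(mask_dir, '{}.{}'.format(im_name, ext))
        if os.path.exists(candidate):
            return candidate
    raise FileNotFoundError('no mask found for {} in {}'.format(im_name, mask_dir))
```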
