How to train/load S/M/L CNN models in tensorflow? #220
Yes, you would need to load the original models in TensorFlow.
Someone has converted the original Anime4K models into Core ML models; I can provide you the link. The ones you're looking for are under [...]. You would need to convert them to TensorFlow, then create a Python or Jupyter notebook script to load the weights and models. You can use them to fine-tune and train your own, better model.

Note: I have not converted or trained the models myself, so I cannot guarantee success. I can only provide general steps, and you will need to do your own research. Supposedly, converting between Core ML and TensorFlow should be relatively straightforward, and the training process itself should be more or less the same as training any other TensorFlow or ESRGAN model.

Sample Python script:
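As a rough sketch of the weight-loading step, assuming the converted weights have already been exported to a NumPy `.npz` archive (the archive name, layer names, and shapes below are hypothetical, not the real Anime4K layout):

```python
import numpy as np

# Hypothetical layer names/shapes standing in for exported Anime4K conv weights;
# a real conversion would dump these from the Core ML model instead.
np.savez("anime4k_weights.npz",
         conv1_kernel=np.zeros((3, 3, 3, 4), dtype=np.float32),
         conv1_bias=np.zeros((4,), dtype=np.float32))

# Load the archive and inspect each tensor's shape before rebuilding the
# network in TensorFlow/Keras layer by layer.
weights = np.load("anime4k_weights.npz")
for name in weights.files:
    print(name, weights[name].shape)
```

Inspecting the shapes first makes it much easier to match each tensor to the right layer when rebuilding the model.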
@Tama47 Is the training code located in the `\tensorflow` dir for the restore or the upscale model? And if it is the restore one, is it easy to change it to train the "upscale" model?
@Tama47 From what I've researched so far, there is no way to convert the current version of MLModel to TF2 or ONNX. However, I managed to get Netron working and to load the weights:
I can convert some GLSL files to PyTorch now, but I'm still stuck on converting the weights. Here is the code if anyone is interested:
It contains both. In Gen_Shader.ipynb, SR1Model generates restore models while SR2Model generates upscale ones.
Thanks. Apparently the models there are for M-size shaders. Do you happen to know what parameter values (block_depth, etc.) I should use to get the S/L sizes? Also, the code uses epochs=1 (3 times). Should I change that to something like 100? I noticed the loss doesn't really decrease.
I guess you can figure out the block_depth from a model's components:
Size S:
Size M:
Size L:
Size VL:
Size UL:
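One way to match a `block_depth`/width guess against a given shader size is to count parameters; the layer layout assumed here (an input conv, a chain of hidden 3x3 convs, and a final projection) is my own simplification, not the exact Anime4K architecture:

```python
def conv_params(c_in, c_out, k=3):
    # weights (c_in * c_out * k * k) plus one bias per output channel
    return c_in * c_out * k * k + c_out

def cnn_param_count(block_depth, channels, c_in=3, c_out=3):
    """Rough parameter count for a chain of 3x3 convs at a given width.

    Assumed layout: input conv, (block_depth - 1) hidden convs, output conv.
    """
    total = conv_params(c_in, channels)
    for _ in range(block_depth - 1):
        total += conv_params(channels, channels)
    total += conv_params(channels, c_out)
    return total

print(cnn_param_count(block_depth=2, channels=4))
```

Comparing these counts against the number of weights dumped from a `.glsl` file can reveal which size a given shader corresponds to.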
Thanks, I have those architectures. But do you know what to pass to this function to get each of those S, L, VL sizes? I need it for training.
@arianaa30 My main library is PyTorch, so I don't know, to be honest.
@Fannovel16 Btw, do you know how to measure SSIM/PSNR of what the Anime4K shaders give me (the upscaled version of a low-res image) vs. the original high-resolution image? Is there a way to measure them?
@arianaa30 You can pass images to mpv or ffmpeg (with libplacebo compiled in) from the command line and save the upscaled images. I'm not sure exactly how, though, since I'm not very familiar with ffmpeg or mpv. Alternatively, you can try Anime4K-GPU and use [...]
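Once the upscaled frames are saved, PSNR against the original high-res image is easy to compute by hand with NumPy (for SSIM, `skimage.metrics.structural_similarity` is the usual choice); a minimal sketch:

```python
import numpy as np

def psnr(ref, test, max_val=255.0):
    """Peak signal-to-noise ratio in dB between two same-shape images."""
    mse = np.mean((ref.astype(np.float64) - test.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * np.log10(max_val ** 2 / mse)

ref = np.zeros((8, 8), dtype=np.uint8)
noisy = np.full((8, 8), 16, dtype=np.uint8)  # constant error of 16 -> MSE = 256
print(round(psnr(ref, noisy), 2))
```

Both images must be the same resolution, so the low-res input has to be upscaled (by the shader pipeline) before comparing it to the original.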
Btw, I've had no luck converting GLSL shaders into actual models in PyTorch. I believe the problem is in transforming the weights.
Hmm, okay, thanks. The problem is that we apply multiple Anime4K shaders (restore, upscale, restore, ...). Not sure if we can do that...
@arianaa30 It's possible: mpv-player/mpv#9589. But now that you mention it, I kinda wonder how the A4K shaders were actually trained.
@Fannovel16 Yeah, the training has some unknowns. Using the TensorFlow script, I trained a model/shader by calling the SR2Model() function, and it works. But when I train SR1Model (which should be the Restore model), the h5 model trains fine, yet converting it with Gen_Shader.py gives me a "Shape Mismatch" error. Have you experienced that before?
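A "Shape Mismatch" at conversion time usually means one layer's saved weights disagree with what the generator expects. Before digging into Gen_Shader.py itself, it can help to diff the checkpoint's shapes against the expected ones layer by layer; a plain-Python sketch (the layer names and shapes below are made up for illustration):

```python
def diff_shapes(expected, actual):
    """Return (name, expected_shape, actual_shape) for every layer that disagrees."""
    mismatches = []
    for name, shape in expected.items():
        if name in actual and tuple(actual[name]) != tuple(shape):
            mismatches.append((name, tuple(shape), tuple(actual[name])))
    return mismatches

# Hypothetical shapes: e.g. an SR1 (restore) checkpoint fed to an SR2 (upscale) generator
expected = {"conv0": (3, 3, 3, 4), "conv_out": (3, 3, 4, 12)}
actual = {"conv0": (3, 3, 3, 4), "conv_out": (3, 3, 4, 3)}
print(diff_shapes(expected, actual))
```

In Keras, the two shape dictionaries can be built from `layer.name` and the shapes of `layer.get_weights()` for each model.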
@Fannovel16
Is the displayed image the upscaled output? Can we apply multiple shaders as well? |
Of course, just string the models together like this: |
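Stringing models together is just function composition; a tiny sketch of the idea (the lambda "stages" are placeholders standing in for real restore/upscale modules):

```python
def chain(*stages):
    """Compose stages left to right: chain(f, g)(x) == g(f(x))."""
    def run(x):
        for stage in stages:
            x = stage(x)
        return x
    return run

# Placeholder stages; in practice these would be the restore/upscale models
pipeline = chain(lambda img: img + 1, lambda img: img * 2)
print(pipeline(1))
```

With PyTorch modules, `torch.nn.Sequential` plays the same role when every stage takes and returns a single tensor.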
@kato-megumi Thanks! It seems like I got the CreLU formula wrong |
Great, thanks. Can we simply add other shaders to the list as well? I want to use [...]
@arianaa30 Here it is:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def get_luma(x):
    # BT.601 luma from an NCHW RGB tensor
    x = x[:, 0] * 0.299 + x[:, 1] * 0.587 + x[:, 2] * 0.114
    return x.unsqueeze(1)

class MaxPoolKeepShape(nn.Module):
    """Max pooling padded so the output keeps the input's spatial shape."""
    def __init__(self, kernel_size, stride=None):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride if stride is not None else kernel_size

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        kernel_height, kernel_width = self.kernel_size
        pad_height = (((height - 1) // self.stride + 1) - 1) * self.stride + kernel_height - height
        pad_width = (((width - 1) // self.stride + 1) - 1) * self.stride + kernel_width - width
        x = F.pad(x, (pad_width // 2, pad_width - pad_width // 2,
                      pad_height // 2, pad_height - pad_height // 2))
        return F.max_pool2d(x, kernel_size=self.kernel_size, stride=self.stride)

class ClampHighlight(nn.Module):
    """Clamp the shader output's luma to the local max luma of the original image."""
    def __init__(self):
        super().__init__()
        self.max_pool = MaxPoolKeepShape(kernel_size=(5, 5), stride=1)

    def forward(self, shader_img, orig_img):
        curr_luma = get_luma(shader_img)
        statsmax = self.max_pool(get_luma(orig_img))
        if statsmax.shape != curr_luma.shape:
            statsmax = F.interpolate(statsmax, curr_luma.shape[2:4])
        new_luma = torch.min(curr_luma, statsmax)
        return shader_img - (curr_luma - new_luma)

# `out`, `image2`, `to_pil`, and `display` come from the surrounding notebook
new_img = ClampHighlight()(out[None], image2)
display(to_pil(new_img[0]))
```
Oh, so the first block iterates over the x-axis while the second iterates over the y-axis.
What is PREKERNEL? I assumed it is the same as MAIN, as the mpv doc is a bit ambiguous.
Yeah, it reduces the computation cost compared to finding the max of 25 pixels in a single pass.
In the Anime4K doc about ClampHighlight: "Computes and saves image statistics at the location it is placed in the shader stage, then clamps the image highlights at the end after all the shaders to prevent overshoot and reduce ringing."
I think it refers to the image right before mpv performs internal scaling.
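The saving comes from the max filter being separable: a 5x5 window max equals a 1x5 horizontal pass followed by a 5x1 vertical pass, so each pixel needs two small windows instead of one 25-pixel one. A NumPy sketch of the idea (edge padding here is my choice, not necessarily mpv's exact border handling):

```python
import numpy as np

def max_filter_1d(x, size, axis):
    """Running max over a 1-D window along one axis, edge-padded to keep shape."""
    pad = [(0, 0)] * x.ndim
    pad[axis] = (size // 2, size // 2)
    xp = np.pad(x, pad, mode="edge")
    shifted = [np.take(xp, range(i, i + x.shape[axis]), axis=axis)
               for i in range(size)]
    return np.max(np.stack(shifted), axis=0)

def separable_max_5x5(img):
    # two 1-D passes give the same result as one full 5x5 window max
    return max_filter_1d(max_filter_1d(img, 5, axis=1), 5, axis=0)

a = np.arange(49, dtype=np.float32).reshape(7, 7)
print(separable_max_5x5(a))
```

This works because `max` is associative: the max over a rectangle equals the max of the per-row maxima.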
I added ClampHighlight, AutoDownscalePre, automatic GLSL downloading, and a pipeline class for convenience:
Great, I will try it.
I recommend using https://github.com/muslll/neosr/ to train models.
@Fannovel16 Btw, do you have training code for the PyTorch models? Would you be able to share it?
@arianaa30 No, but you can use my notebook to get the model and randomize its parameters for training.
Should I fine-tune it (train only the last layers) or train the whole network?
@arianaa30 I forgot to test 😅. It works now.
Anime4K's CNNs are pretty small, so training from scratch is a better choice, imo.
@Fannovel16 @kato-megumi Bumping this thread: have any of you had any success training the PyTorch models, even just as a test? It is so weird.
@arianaa30
Is there a way to train/load the S/M/L CNN models in TensorFlow? I am interested in experimenting a bit with these models in TensorFlow or onnxruntime. I see that there is one specific model in the `tensorflow` directory, but I am not sure which one it is.