From ce239ac7284ab4a1e27589610261271981e98290 Mon Sep 17 00:00:00 2001 From: WolframRhodium Date: Tue, 23 Apr 2024 11:31:14 +0800 Subject: [PATCH] scripts/vsmlrt.py: add support for SwinIR models https://github.com/AmusementClub/vs-mlrt/issues/54 --- scripts/vsmlrt.py | 125 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 124 insertions(+), 1 deletion(-) diff --git a/scripts/vsmlrt.py b/scripts/vsmlrt.py index 4e87d4d..fd48236 100644 --- a/scripts/vsmlrt.py +++ b/scripts/vsmlrt.py @@ -1,4 +1,4 @@ -__version__ = "3.20.9" +__version__ = "3.20.10" __all__ = [ "Backend", "BackendV2", @@ -10,6 +10,7 @@ "RIFE", "RIFEModel", "RIFEMerge", "SAFA", "SAFAModel", "SAFAAdaptiveMode", "SCUNet", "SCUNetModel", + "SwinIR", "SwinIRModel", "inference" ] @@ -1498,6 +1499,128 @@ def SCUNet( return clip +@enum.unique +class SwinIRModel(enum.IntEnum): + lightweightSR_DIV2K_s64w8_SwinIR_S_x2 = 0 + lightweightSR_DIV2K_s64w8_SwinIR_S_x3 = 1 + lightweightSR_DIV2K_s64w8_SwinIR_S_x4 = 2 + realSR_BSRGAN_DFOWMFC_s64w8_SwinIR_L_x4_GAN = 3 + # unused + realSR_BSRGAN_DFOWMFC_s64w8_SwinIR_L_x4_PSNR = 5 + classicalSR_DF2K_s64w8_SwinIR_M_x2 = 6 + classicalSR_DF2K_s64w8_SwinIR_M_x3 = 7 + classicalSR_DF2K_s64w8_SwinIR_M_x4 = 8 + classicalSR_DF2K_s64w8_SwinIR_M_x8 = 9 + realSR_BSRGAN_DFO_s64w8_SwinIR_M_x2_GAN = 10 + realSR_BSRGAN_DFO_s64w8_SwinIR_M_x2_PSNR = 11 + realSR_BSRGAN_DFO_s64w8_SwinIR_M_x4_GAN = 12 + realSR_BSRGAN_DFO_s64w8_SwinIR_M_x4_PSNR = 13 + grayDN_DFWB_s128w8_SwinIR_M_noise15 = 14 + grayDN_DFWB_s128w8_SwinIR_M_noise25 = 15 + grayDN_DFWB_s128w8_SwinIR_M_noise50 = 16 + colorDN_DFWB_s128w8_SwinIR_M_noise15 = 17 + colorDN_DFWB_s128w8_SwinIR_M_noise25 = 18 + colorDN_DFWB_s128w8_SwinIR_M_noise50 = 19 + CAR_DFWB_s126w7_SwinIR_M_jpeg10 = 20 + CAR_DFWB_s126w7_SwinIR_M_jpeg20 = 21 + CAR_DFWB_s126w7_SwinIR_M_jpeg30 = 22 + CAR_DFWB_s126w7_SwinIR_M_jpeg40 = 23 + colorCAR_DFWB_s126w7_SwinIR_M_jpeg10 = 24 + colorCAR_DFWB_s126w7_SwinIR_M_jpeg20 = 25 + colorCAR_DFWB_s126w7_SwinIR_M_jpeg30 = 26 + colorCAR_DFWB_s126w7_SwinIR_M_jpeg40 = 27 + + +def SwinIR( + clip: vs.VideoNode, + tiles: typing.Optional[typing.Union[int, typing.Tuple[int, int]]] = None, + tilesize: typing.Optional[typing.Union[int, typing.Tuple[int, int]]] = None, + overlap: typing.Optional[typing.Union[int, typing.Tuple[int, int]]] = None, + model: SwinIRModel = SwinIRModel.lightweightSR_DIV2K_s64w8_SwinIR_S_x2, + backend: backendT = Backend.OV_CPU() +) -> vs.VideoNode: + """ SwinIR: Image Restoration Using Swin Transformer """ + + func_name = "vsmlrt.SwinIR" + + if not isinstance(clip, vs.VideoNode): + raise TypeError(f'{func_name}: "clip" must be a clip!') + + if clip.format.sample_type != vs.FLOAT or clip.format.bits_per_sample not in [16, 32]: + raise ValueError(f"{func_name}: only constant format 16/32 bit float input supported") + + if not isinstance(model, int) or model not in SwinIRModel.__members__.values(): + raise ValueError(f'{func_name}: invalid "model"') + + if model in range(14, 17) or model in range(20, 24): + if clip.format.color_family != vs.GRAY: + raise ValueError(f'{func_name}: "clip" must be of GRAY color family') + elif clip.format.color_family != vs.RGB: + raise ValueError(f'{func_name}: "clip" must be of RGB color family') + + if overlap is None: + overlap_w = overlap_h = 16 + elif isinstance(overlap, int): + overlap_w = overlap_h = overlap + else: + overlap_w, overlap_h = overlap + + multiple = 1 + + (tile_w, tile_h), (overlap_w, overlap_h) = calc_tilesize( + tiles=tiles, tilesize=tilesize, + width=clip.width, height=clip.height, + multiple=multiple, + overlap_w=overlap_w, overlap_h=overlap_h + ) + + if tile_w % multiple != 0 or tile_h % multiple != 0: + raise ValueError( + f'{func_name}: tile size must be divisible by {multiple} ({tile_w}, {tile_h})' + ) + + backend = init_backend( + backend=backend, + trt_opt_shapes=(tile_w, tile_h) + ) + + if model < 4: + model_name = tuple(SwinIRModel.__members__)[model] + else: + model_name = tuple(SwinIRModel.__members__)[model - 1] + + model_name = model_name.replace("SwinIR_", "SwinIR-") + + if model in range(3): + model_name = f"002_{model_name}" + elif model in (3, 5): + model_name = f"003_{model_name}" + elif model in range(6, 10): + model_name = f"001_{model_name}" + elif model in range(10, 14): + model_name = f"003_{model_name}" + elif model in range(14, 17): + model_name = f"004_{model_name}" + elif model in range(17, 20): + model_name = f"005_{model_name}" + elif model in range(20, 28): + model_name = f"006_{model_name}" + + network_path = os.path.join( + models_path, + "swinir", + f"{model_name}.onnx" + ) + + clip = inference_with_fallback( + clips=[clip], network_path=network_path, + overlap=(overlap_w, overlap_h), tilesize=(tile_w, tile_h), + backend=backend + ) + + return clip + + def get_engine_path( network_path: str, min_shapes: typing.Tuple[int, int],