From 6af764661c6e078aef957de277fe7e46ee6643ea Mon Sep 17 00:00:00 2001 From: WolframRhodium Date: Thu, 9 May 2024 13:48:20 +0800 Subject: [PATCH] scripts/vsmlrt.py: add support for artcnn models --- scripts/vsmlrt.py | 102 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 101 insertions(+), 1 deletion(-) diff --git a/scripts/vsmlrt.py b/scripts/vsmlrt.py index b16a183..84501f1 100644 --- a/scripts/vsmlrt.py +++ b/scripts/vsmlrt.py @@ -1,4 +1,4 @@ -__version__ = "3.20.12" +__version__ = "3.20.13" __all__ = [ "Backend", "BackendV2", @@ -11,6 +11,7 @@ "SAFA", "SAFAModel", "SAFAAdaptiveMode", "SCUNet", "SCUNetModel", "SwinIR", "SwinIRModel", + "ArtCNN", "ArtCNNModel", "inference" ] @@ -1622,6 +1623,105 @@ def SwinIR( return clip +@enum.unique +class ArtCNNModel(enum.IntEnum): + ArtCNN_C4F32 = 0 + ArtCNN_C4F32_DS = 1 + ArtCNN_C16F64 = 2 + ArtCNN_C16F64_DS = 3 + ArtCNN_C4F32_Chroma = 4 + ArtCNN_C16F64_Chroma = 5 + + +def ArtCNN( + clip: vs.VideoNode, + tiles: typing.Optional[typing.Union[int, typing.Tuple[int, int]]] = None, + tilesize: typing.Optional[typing.Union[int, typing.Tuple[int, int]]] = None, + overlap: typing.Optional[typing.Union[int, typing.Tuple[int, int]]] = None, + model: ArtCNNModel = ArtCNNModel.ArtCNN_C16F64, + backend: backendT = Backend.OV_CPU() +) -> vs.VideoNode: + """ ArtCNN (https://github.com/Artoriuz/ArtCNN) """ + + func_name = "vsmlrt.ArtCNN" + + if not isinstance(clip, vs.VideoNode): + raise TypeError(f'{func_name}: "clip" must be a clip!') + + if clip.format.sample_type != vs.FLOAT or clip.format.bits_per_sample not in [16, 32]: + raise ValueError(f"{func_name}: only constant format 16/32 bit float input supported") + + if not isinstance(model, int) or model not in ArtCNNModel.__members__.values(): + raise ValueError(f'{func_name}: invalid "model"') + + if model in range(4, 6): + if clip.format.color_family != vs.YUV: + raise ValueError(f'{func_name}: "clip" must be of YUV color family') + if clip.format.subsampling_h != 0 or clip.format.subsampling_w != 0: + raise ValueError( + f'{func_name}: "clip" must be without subsampling!' + 'Bilinear upsampling is recommended.' + ) + elif clip.format.color_family != vs.GRAY: + raise ValueError(f'{func_name}: "clip" must be of GRAY color family') + + if overlap is None: + overlap_w = overlap_h = 8 + elif isinstance(overlap, int): + overlap_w = overlap_h = overlap + else: + overlap_w, overlap_h = overlap + + multiple = 1 + + (tile_w, tile_h), (overlap_w, overlap_h) = calc_tilesize( + tiles=tiles, tilesize=tilesize, + width=clip.width, height=clip.height, + multiple=multiple, + overlap_w=overlap_w, overlap_h=overlap_h + ) + + if tile_w % multiple != 0 or tile_h % multiple != 0: + raise ValueError( + f'{func_name}: tile size must be divisible by {multiple} ({tile_w}, {tile_h})' + ) + + backend = init_backend( + backend=backend, + trt_opt_shapes=(tile_w, tile_h) + ) + + model_name = tuple(ArtCNNModel.__members__)[model] + + network_path = os.path.join( + models_path, + "ArtCNN", + f"{model_name}.onnx" + ) + + if model in range(4, 6): + if clip.format.bits_per_sample == 16: + clip = core.akarin.Expr(clip, ["", "x 0.5 +"]) + else: + clip = core.std.Expr(clip, ["", "x 0.5 +"]) + + clip = inference_with_fallback( + clips=[clip], network_path=network_path, + overlap=(overlap_w, overlap_h), tilesize=(tile_w, tile_h), + backend=backend + ) + + if model in range(4, 6): + clip = core.std.ShufflePlanes(clip, [0, 1, 2], vs.YUV) + + if clip.format.bits_per_sample == 16: + clip = core.akarin.Expr(clip, ["", "x 0.5 -"]) + else: + clip = core.std.Expr(clip, ["", "x 0.5 -"]) + + return clip + + def get_engine_path( network_path: str, min_shapes: typing.Tuple[int, int],