scripts/vsmlrt.py: add support for artcnn models
WolframRhodium committed May 9, 2024
1 parent 5570852 commit 6af7646
Showing 1 changed file with 101 additions and 1 deletion.
102 changes: 101 additions & 1 deletion scripts/vsmlrt.py
@@ -1,4 +1,4 @@
__version__ = "3.20.12"
__version__ = "3.20.13"

__all__ = [
"Backend", "BackendV2",
@@ -11,6 +11,7 @@
"SAFA", "SAFAModel", "SAFAAdaptiveMode",
"SCUNet", "SCUNetModel",
"SwinIR", "SwinIRModel",
"ArtCNN", "ArtCNNModel",
"inference"
]

@@ -1622,6 +1623,105 @@ def SwinIR(
    return clip


@enum.unique
class ArtCNNModel(enum.IntEnum):
    ArtCNN_C4F32 = 0
    ArtCNN_C4F32_DS = 1
    ArtCNN_C16F64 = 2
    ArtCNN_C16F64_DS = 3
    ArtCNN_C4F32_Chroma = 4
    ArtCNN_C16F64_Chroma = 5
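# The *_Chroma variants (members 4 and 5) operate on full 4:4:4 YUV input;
# the other models expect a single GRAY plane, as enforced in ArtCNN below.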


def ArtCNN(
    clip: vs.VideoNode,
    tiles: typing.Optional[typing.Union[int, typing.Tuple[int, int]]] = None,
    tilesize: typing.Optional[typing.Union[int, typing.Tuple[int, int]]] = None,
    overlap: typing.Optional[typing.Union[int, typing.Tuple[int, int]]] = None,
    model: ArtCNNModel = ArtCNNModel.ArtCNN_C16F64,
    backend: backendT = Backend.OV_CPU()
) -> vs.VideoNode:
    """ ArtCNN (https://github.com/Artoriuz/ArtCNN) """

    func_name = "vsmlrt.ArtCNN"

    if not isinstance(clip, vs.VideoNode):
        raise TypeError(f'{func_name}: "clip" must be a clip!')

    if clip.format.sample_type != vs.FLOAT or clip.format.bits_per_sample not in [16, 32]:
        raise ValueError(f"{func_name}: only constant format 16/32 bit float input supported")

    if not isinstance(model, int) or model not in ArtCNNModel.__members__.values():
        raise ValueError(f'{func_name}: invalid "model"')

    if model in range(4, 6):
        if clip.format.color_family != vs.YUV:
            raise ValueError(f'{func_name}: "clip" must be of YUV color family')
        if clip.format.subsampling_h != 0 or clip.format.subsampling_w != 0:
            raise ValueError(
                f'{func_name}: "clip" must be without subsampling! '
                'Bilinear upsampling is recommended.'
            )
    elif clip.format.color_family != vs.GRAY:
        raise ValueError(f'{func_name}: "clip" must be of GRAY color family')

    if overlap is None:
        overlap_w = overlap_h = 8
    elif isinstance(overlap, int):
        overlap_w = overlap_h = overlap
    else:
        overlap_w, overlap_h = overlap

    multiple = 1
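    # ArtCNN imposes no tile-size alignment, hence multiple = 1; the
    # divisibility check after calc_tilesize is therefore trivially satisfied.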

    (tile_w, tile_h), (overlap_w, overlap_h) = calc_tilesize(
        tiles=tiles, tilesize=tilesize,
        width=clip.width, height=clip.height,
        multiple=multiple,
        overlap_w=overlap_w, overlap_h=overlap_h
    )

    if tile_w % multiple != 0 or tile_h % multiple != 0:
        raise ValueError(
            f'{func_name}: tile size must be divisible by {multiple} ({tile_w}, {tile_h})'
        )

    backend = init_backend(
        backend=backend,
        trt_opt_shapes=(tile_w, tile_h)
    )

    model_name = tuple(ArtCNNModel.__members__)[model]

    network_path = os.path.join(
        models_path,
        "ArtCNN",
        f"{model_name}.onnx"
    )

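    # VapourSynth stores float chroma centred on 0 (range [-0.5, 0.5]); the
    # chroma models apparently expect [0, 1], so +0.5 is added before inference
    # and subtracted again afterwards. The empty expression copies the luma
    # plane unchanged, and the last expression is reused for the remaining
    # plane. akarin.Expr handles the 16-bit float case, presumably because
    # std.Expr lacks half-precision support.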
    if model in range(4, 6):
        if clip.format.bits_per_sample == 16:
            clip = core.akarin.Expr(clip, ["", "x 0.5 +"])
        else:
            clip = core.std.Expr(clip, ["", "x 0.5 +"])

    clip = inference_with_fallback(
        clips=[clip], network_path=network_path,
        overlap=(overlap_w, overlap_h), tilesize=(tile_w, tile_h),
        backend=backend
    )

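    # For the chroma models, reinterpret the three-plane network output as YUV
    # and undo the +0.5 chroma offset applied above.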
    if model in range(4, 6):
        clip = core.std.ShufflePlanes(clip, [0, 1, 2], vs.YUV)

        if clip.format.bits_per_sample == 16:
            clip = core.akarin.Expr(clip, ["", "x 0.5 -"])
        else:
            clip = core.std.Expr(clip, ["", "x 0.5 -"])

    return clip


def get_engine_path(
    network_path: str,
    min_shapes: typing.Tuple[int, int],
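For context, a minimal usage sketch of the new function (not part of the commit). It assumes vsmlrt 3.20.13 is installed with the ArtCNN ONNX files under its models directory and that the L-SMASH-Works source plugin is available; the input file name is hypothetical.

import vapoursynth as vs
from vapoursynth import core

from vsmlrt import ArtCNN, ArtCNNModel, Backend

src = core.lsmas.LWLibavSource("input.mkv")  # hypothetical input

# Luma models expect a single GRAY plane in 16/32-bit float.
gray = core.std.ShufflePlanes(src, planes=0, colorfamily=vs.GRAY)
gray = core.resize.Point(gray, format=vs.GRAYS)
luma = ArtCNN(gray, model=ArtCNNModel.ArtCNN_C16F64, backend=Backend.OV_CPU())

# Chroma models expect float YUV without subsampling; the new code's error
# message recommends bilinear upsampling.
yuv444 = core.resize.Bilinear(src, format=vs.YUV444PS)
chroma = ArtCNN(yuv444, model=ArtCNNModel.ArtCNN_C16F64_Chroma)

luma.set_output()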
