Skip to content

Commit

Permalink
scripts/vsmlrt.py: fix parsing of trt version
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed Mar 28, 2024
1 parent 66146a2 commit 0c4d8ca
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions scripts/vsmlrt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "3.15.54"
__version__ = "3.15.55"

__all__ = [
"Backend", "BackendV2",
Expand Down Expand Up @@ -1456,7 +1456,7 @@ def trtexec(
) -> str:

# TensorRT runtime version, e.g. 8401 => 8.4.1
trt_version = int(core.trt.Version()["tensorrt_version"])
trt_version = parse_trt_version(int(core.trt.Version()["tensorrt_version"]))

if isinstance(opt_shapes, int):
opt_shapes = (opt_shapes, opt_shapes)
Expand Down Expand Up @@ -1517,7 +1517,7 @@ def trtexec(
]

if workspace is not None:
if trt_version >= 8400:
if trt_version >= (8, 4, 0):
args.append(f"--memPoolSize=workspace:{workspace}")
else:
args.append(f"--workspace{workspace}")
Expand All @@ -1542,9 +1542,9 @@ def trtexec(
disabled_tactic_sources.extend(["-CUBLAS", "-CUBLAS_LT"])
if not use_cudnn:
disabled_tactic_sources.append("-CUDNN")
if not use_edge_mask_convolutions and trt_version >= 8401:
if not use_edge_mask_convolutions and trt_version >= (8, 4, 1):
disabled_tactic_sources.append("-EDGE_MASK_CONVOLUTIONS")
if not use_jit_convolutions and trt_version >= 8500:
if not use_jit_convolutions and trt_version >= (8, 5, 0):
disabled_tactic_sources.append("-JIT_CONVOLUTIONS")
if disabled_tactic_sources:
args.append(f"--tacticSources={','.join(disabled_tactic_sources)}")
Expand All @@ -1563,8 +1563,8 @@ def trtexec(
if not tf32:
args.append("--noTF32")

if heuristic and trt_version >= 8500 and core.trt.DeviceProperties(device_id)["major"] >= 8:
if trt_version < 8600:
if heuristic and trt_version >= (8, 5, 0) and core.trt.DeviceProperties(device_id)["major"] >= 8:
if trt_version < (8, 6, 0):
args.append("--heuristic")
else:
builder_optimization_level = 2
Expand All @@ -1574,11 +1574,11 @@ def trtexec(
"--outputIOFormats=fp32:chw" if output_format == 0 else "--outputIOFormats=fp16:chw"
])

if faster_dynamic_shapes and not static_shape and 8500 <= trt_version < 8600:
if faster_dynamic_shapes and not static_shape and (8, 5, 0) <= trt_version < (8, 6, 0):
args.append("--preview=+fasterDynamicShapes0805")

if force_fp16:
if trt_version >= 8401:
if trt_version >= (8, 4, 1):
args.extend([
"--layerPrecisions=*:fp16",
"--layerOutputTypes=*:fp16",
Expand All @@ -1587,7 +1587,7 @@ def trtexec(
else:
raise ValueError('"force_fp16" is not available')

if trt_version >= 8600:
if trt_version >= (8, 6, 0):
args.append(f"--builderOptimizationLevel={builder_optimization_level}")

args.extend(custom_args)
Expand Down Expand Up @@ -2238,3 +2238,11 @@ def fmtc_resample(clip: vs.VideoNode, **kwargs) -> vs.VideoNode:
clip = core.resize.Point(clip, format=clip_org.format)

return clip


def parse_trt_version(version: int) -> typing.Tuple[int, int, int]:
    """Unpack TensorRT's packed integer version into ``(major, minor, patch)``.

    Two encodings are handled, selected by magnitude:

    * values below 10000 (releases before TensorRT 10) carry a single
      decimal digit for the minor part, e.g. ``8401 -> (8, 4, 1)``;
    * values of 10000 and above carry two decimal digits for the minor
      part, e.g. ``100001 -> (10, 0, 1)``.

    In both encodings the patch level is the lowest two decimal digits.
    The returned tuple can be compared directly against literals such as
    ``(8, 4, 1)``.
    """
    # The patch field has the same width in both layouts, so peel it off first.
    upper, patch = divmod(version, 100)

    # Only the width of the minor field differs between the two layouts.
    minor_base = 100 if version >= 10000 else 10
    major, minor = divmod(upper, minor_base)

    return major, minor, patch

0 comments on commit 0c4d8ca

Please sign in to comment.