diff --git a/src/turnkeyml/analyze/script.py b/src/turnkeyml/analyze/script.py
index fd1f8054..2de3b07b 100644
--- a/src/turnkeyml/analyze/script.py
+++ b/src/turnkeyml/analyze/script.py
@@ -766,24 +766,12 @@ def forward_spy(*args, **kwargs):
 
         model_info = tracer_args.models_found[model_hash]
         if invocation_hash not in model_info.unique_invocations:
-            model_info.unique_invocations[invocation_hash] = (
-                status.UniqueInvocationInfo(
-                    name=model_info.name,
-                    script_name=model_info.script_name,
-                    file=model_info.file,
-                    line=model_info.line,
-                    params=model_info.params,
-                    depth=model_info.depth,
-                    build_model=model_info.build_model,
-                    model_type=model_info.model_type,
-                    model_class=type(model_info.model),
-                    invocation_hash=invocation_hash,
-                    hash=model_info.hash,
-                    is_target=invocation_hash in tracer_args.targets
-                    or len(tracer_args.targets) == 0,
-                    input_shapes=input_shapes,
-                    parent_hash=parent_invocation_hash,
-                )
+            model_info.add_unique_invocation(
+                invocation_hash=invocation_hash,
+                is_target=invocation_hash in tracer_args.targets
+                or len(tracer_args.targets) == 0,
+                input_shapes=input_shapes,
+                parent_hash=parent_invocation_hash,
             )
         model_info.last_unique_invocation_executed = invocation_hash
 
diff --git a/src/turnkeyml/analyze/status.py b/src/turnkeyml/analyze/status.py
index 2abc9888..c31cdb35 100644
--- a/src/turnkeyml/analyze/status.py
+++ b/src/turnkeyml/analyze/status.py
@@ -327,6 +327,35 @@ class ModelInfo(BasicInfo):
     def __post_init__(self):
         self.params = analyze_model.count_parameters(self.model, self.model_type)
 
+    def add_unique_invocation(
+        self,
+        invocation_hash: int,
+        is_target: bool,
+        input_shapes: Dict,
+        parent_hash: Union[str, None] = None,
+        executed: int = 0,
+    ):
+        model_class = (
+            type(self.model) if self.model_type == build.ModelType.PYTORCH else None
+        )
+        self.unique_invocations[invocation_hash] = UniqueInvocationInfo(
+            name=self.name,
+            script_name=self.script_name,
+            file=self.file,
+            line=self.line,
+            params=self.params,
+            depth=self.depth,
+            build_model=self.build_model,
+            model_type=self.model_type,
+            model_class=model_class,
+            invocation_hash=invocation_hash,
+            hash=self.hash,
+            is_target=is_target,
+            input_shapes=input_shapes,
+            parent_hash=parent_hash,
+            executed=executed,
+        )
+
 
 def update(
     models_found: Dict[str, ModelInfo],
diff --git a/src/turnkeyml/files_api.py b/src/turnkeyml/files_api.py
index 4ee56d91..211de65a 100644
--- a/src/turnkeyml/files_api.py
+++ b/src/turnkeyml/files_api.py
@@ -20,7 +20,7 @@
     explore_invocation,
     get_model_hash,
 )
-from turnkeyml.analyze.status import ModelInfo, UniqueInvocationInfo, Verbosity
+from turnkeyml.analyze.status import ModelInfo, Verbosity
 import turnkeyml.common.build as build
 import turnkeyml.build.onnx_helpers as onnx_helpers
 
@@ -374,24 +374,6 @@ def benchmark_files(
             onnx_inputs = onnx_helpers.dummy_inputs(file_path_absolute)
             input_shapes = {key: value.shape for key, value in onnx_inputs.items()}
 
-            # Create the UniqueInvocationInfo
-            #   - execute=1 is required or else the ONNX model will be
-            #       skipped in later stages of evaluation
-            #   - is_target=True is required or else traceback wont be printed for
-            #       in the event of any errors
-            #   - Most other values can be left as default
-            invocation_info = UniqueInvocationInfo(
-                name=onnx_name,
-                script_name=onnx_name,
-                file=file_path_absolute,
-                build_model=not build_only,
-                model_type=build.ModelType.ONNX_FILE,
-                executed=1,
-                input_shapes=input_shapes,
-                hash=onnx_hash,
-                is_target=True,
-            )
-
             # Create the ModelInfo
             model_info = ModelInfo(
                 model=file_path_absolute,
@@ -400,17 +382,28 @@ def benchmark_files(
                 file=file_path_absolute,
                 build_model=not build_only,
                 model_type=build.ModelType.ONNX_FILE,
-                unique_invocations={onnx_hash: invocation_info},
                 hash=onnx_hash,
            )
 
+            # Add UniqueInvocationInfo
+            #   - is_target=True is required or else traceback wont be printed for
+            #       in the event of any errors
+            #   - execute=1 is required or else the ONNX model will be
+            #       skipped in later stages of evaluation
+            model_info.add_unique_invocation(
+                invocation_hash=onnx_hash,
+                is_target=True,
+                input_shapes=input_shapes,
+                executed=1,
+            )
+
             # Begin evaluating the ONNX model
             tracer_args.script_name = onnx_name
             tracer_args.models_found[tracer_args.script_name] = model_info
             explore_invocation(
                 model_inputs=onnx_inputs,
                 model_info=model_info,
-                invocation_info=invocation_info,
+                invocation_info=model_info.unique_invocations[onnx_hash],
                 tracer_args=tracer_args,
             )
             models_found = tracer_args.models_found
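
For context, the refactor above replaces two near-identical `UniqueInvocationInfo(...)` constructions with a single `ModelInfo.add_unique_invocation` factory method, so callers only supply the per-invocation fields. The sketch below illustrates the pattern with simplified stand-in dataclasses; the field names mirror the diff, but these are not the real turnkeyml types (which carry additional fields such as `script_name`, `model_class`, and `depth`).

```python
from dataclasses import dataclass, field
from typing import Dict, Union


@dataclass
class UniqueInvocationInfo:
    # Simplified stand-in for turnkeyml's UniqueInvocationInfo
    name: str
    invocation_hash: str
    is_target: bool
    input_shapes: Dict
    parent_hash: Union[str, None] = None
    executed: int = 0


@dataclass
class ModelInfo:
    # Simplified stand-in for turnkeyml's ModelInfo
    name: str
    unique_invocations: Dict[str, UniqueInvocationInfo] = field(default_factory=dict)

    def add_unique_invocation(
        self,
        invocation_hash: str,
        is_target: bool,
        input_shapes: Dict,
        parent_hash: Union[str, None] = None,
        executed: int = 0,
    ):
        # Centralize construction so callers don't repeat ModelInfo's own fields
        self.unique_invocations[invocation_hash] = UniqueInvocationInfo(
            name=self.name,
            invocation_hash=invocation_hash,
            is_target=is_target,
            input_shapes=input_shapes,
            parent_hash=parent_hash,
            executed=executed,
        )


# Usage mirroring the ONNX-file call site in the diff (hypothetical values)
model_info = ModelInfo(name="my_model")
model_info.add_unique_invocation(
    invocation_hash="abc123",
    is_target=True,
    input_shapes={"x": (1, 3, 224, 224)},
    executed=1,
)
print(model_info.unique_invocations["abc123"])
```

The design choice is to let `ModelInfo` derive the invocation's shared fields from itself (and, in the real code, decide `model_class` based on `model_type`), which is why `files_api.py` can drop its direct `UniqueInvocationInfo` import and look the invocation back up via `model_info.unique_invocations[onnx_hash]`.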