Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Show number of parameters when onnx files are being analyzed #125

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 6 additions & 18 deletions src/turnkeyml/analyze/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,24 +766,12 @@ def forward_spy(*args, **kwargs):
model_info = tracer_args.models_found[model_hash]

if invocation_hash not in model_info.unique_invocations:
model_info.unique_invocations[invocation_hash] = (
status.UniqueInvocationInfo(
name=model_info.name,
script_name=model_info.script_name,
file=model_info.file,
line=model_info.line,
params=model_info.params,
depth=model_info.depth,
build_model=model_info.build_model,
model_type=model_info.model_type,
model_class=type(model_info.model),
invocation_hash=invocation_hash,
hash=model_info.hash,
is_target=invocation_hash in tracer_args.targets
or len(tracer_args.targets) == 0,
input_shapes=input_shapes,
parent_hash=parent_invocation_hash,
)
model_info.add_unique_invocation(
invocation_hash=invocation_hash,
is_target=invocation_hash in tracer_args.targets
or len(tracer_args.targets) == 0,
input_shapes=input_shapes,
parent_hash=parent_invocation_hash,
)
model_info.last_unique_invocation_executed = invocation_hash

Expand Down
29 changes: 29 additions & 0 deletions src/turnkeyml/analyze/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,35 @@ class ModelInfo(BasicInfo):
def __post_init__(self):
self.params = analyze_model.count_parameters(self.model, self.model_type)

def add_unique_invocation(
self,
invocation_hash: int,
is_target: bool,
input_shapes: Dict,
parent_hash: Union[str, None] = None,
executed: int = 0,
):
model_class = (
type(self.model) if self.model_type == build.ModelType.PYTORCH else None
)
self.unique_invocations[invocation_hash] = UniqueInvocationInfo(
name=self.name,
script_name=self.script_name,
file=self.file,
line=self.line,
params=self.params,
depth=self.depth,
build_model=self.build_model,
model_type=self.model_type,
model_class=model_class,
invocation_hash=invocation_hash,
hash=self.hash,
is_target=is_target,
input_shapes=input_shapes,
parent_hash=parent_hash,
executed=executed,
)


def update(
models_found: Dict[str, ModelInfo],
Expand Down
35 changes: 14 additions & 21 deletions src/turnkeyml/files_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
explore_invocation,
get_model_hash,
)
from turnkeyml.analyze.status import ModelInfo, UniqueInvocationInfo, Verbosity
from turnkeyml.analyze.status import ModelInfo, Verbosity
import turnkeyml.common.build as build
import turnkeyml.build.onnx_helpers as onnx_helpers

Expand Down Expand Up @@ -374,24 +374,6 @@ def benchmark_files(
onnx_inputs = onnx_helpers.dummy_inputs(file_path_absolute)
input_shapes = {key: value.shape for key, value in onnx_inputs.items()}

# Create the UniqueInvocationInfo
# - execute=1 is required or else the ONNX model will be
# skipped in later stages of evaluation
# - is_target=True is required or else traceback wont be printed for
# in the event of any errors
# - Most other values can be left as default
invocation_info = UniqueInvocationInfo(
name=onnx_name,
script_name=onnx_name,
file=file_path_absolute,
build_model=not build_only,
model_type=build.ModelType.ONNX_FILE,
executed=1,
input_shapes=input_shapes,
hash=onnx_hash,
is_target=True,
)

# Create the ModelInfo
model_info = ModelInfo(
model=file_path_absolute,
Expand All @@ -400,17 +382,28 @@ def benchmark_files(
file=file_path_absolute,
build_model=not build_only,
model_type=build.ModelType.ONNX_FILE,
unique_invocations={onnx_hash: invocation_info},
hash=onnx_hash,
)

# Add UniqueInvocationInfo
# - is_target=True is required or else traceback wont be printed for
# in the event of any errors
# - execute=1 is required or else the ONNX model will be
# skipped in later stages of evaluation
model_info.add_unique_invocation(
invocation_hash=onnx_hash,
is_target=True,
input_shapes=input_shapes,
executed=1,
)

# Begin evaluating the ONNX model
tracer_args.script_name = onnx_name
tracer_args.models_found[tracer_args.script_name] = model_info
explore_invocation(
model_inputs=onnx_inputs,
model_info=model_info,
invocation_info=invocation_info,
invocation_info=model_info.unique_invocations[onnx_hash],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: the code could be a bit simpler if it was ModelInfo.add_unique_invocation() -> UniqueInvocationInfo (ie, the function returned the new UniqueInvocationInfo instance).

That way you wouldn't have to look it up in the dictionary a few lines of code later.

tracer_args=tracer_args,
)
models_found = tracer_args.models_found
Expand Down
Loading