Skip to content

Commit

Permalink
[tools/model_explorer_circle] Create a map for tensors and its source (
Browse files Browse the repository at this point in the history
…#14347)

It creates a mapping between tensors and its source node.
It will be used to represent connections between graph nodes.

ONE-DCO-1.0-Signed-off-by: Jonghwa Lee <[email protected]>
  • Loading branch information
batcheu authored Nov 22, 2024
1 parent c72c89a commit e9cbec3
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion tools/model_explorer_circle/src/circle_adapter/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
"""Model Explorer adapter for Circle models."""

from typing import Dict
from typing import Dict, Optional
from model_explorer import Adapter, AdapterMetadata, ModelExplorerGraphs, graph_builder
from circle_adapter import circle_schema_generated as circle_schema

Expand All @@ -34,6 +34,8 @@ def __init__(self):
v: k
for k, v in circle_schema.BuiltinOperator.__dict__.items()
}
# tensor_id -> node_id/output_id
self.map_tensor_to_source = {}

def load_model(self, model_path: str) -> None:
"""Load the model from the given path."""
Expand All @@ -46,6 +48,14 @@ def opcode_to_name(self, opcode: int) -> str:
"""Convert the opcode to its name."""
return self.dict_opcode_to_name[opcode]

def set_source_of(self, tensor_id: int, source_id: int, output_id: int) -> None:
"""Set the source of the tensor."""
self.map_tensor_to_source[tensor_id] = f'{source_id}/{output_id}'

def get_source_of(self, tensor_id: int) -> Optional[str]:
"""Get the source of the tensor."""
return self.map_tensor_to_source.get(tensor_id)

def build_graph(self, me_graph: graph_builder.Graph) -> None:
"""Build the graph using the model."""

Expand All @@ -58,6 +68,15 @@ def build_graph(self, me_graph: graph_builder.Graph) -> None:
namespace="GraphInputs")
me_graph.nodes.append(me_node)

# Map source and output tensors of GraphInputs
for i, tensor_id in enumerate(sub_graph.inputs):
self.set_source_of(tensor_id=tensor_id, source_id=input_id, output_id=i)

# Map source and output tensors of operators
for op_id, op in enumerate(sub_graph.operators):
for i, tensor_id in enumerate(op.outputs):
self.set_source_of(tensor_id=tensor_id, source_id=op_id, output_id=i)

# Create operator nodes
for idx, op in enumerate(sub_graph.operators):
name = self.opcode_to_name(
Expand Down

0 comments on commit e9cbec3

Please sign in to comment.