Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add split_assignments to SplitModel pass #1610

Merged
merged 7 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions olive/passes/onnx/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ class SplitModel(Pass):
@classmethod
def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
    """Return the configuration parameters accepted by this pass.

    The mapping declares an optional ``split_assignments`` string parameter
    (default ``None``) and then merges in whatever
    ``get_external_data_config()`` returns.
    """
    split_assignments_param = PassConfigParam(
        type_=str,
        default_value=None,
        description=(
            "Set split assignments in the format of name1=0;name2=1 etc."
            " Overwrite the one from CaptureSplitInfo pass."
        ),
    )
    pass_config = {"split_assignments": split_assignments_param}
    # Merged after split_assignments so key order and override precedence
    # stay the same as a single dict literal with a trailing ** expansion.
    pass_config.update(get_external_data_config())
    return pass_config

Expand All @@ -36,18 +44,20 @@ def _run_for_config(
) -> CompositeModelHandler:
model_proto = model.load_model()

split_assignments = None
for metadata_prop in model_proto.metadata_props:
if metadata_prop.key == "split_assignments":
split_assignments = {
key: int(value)
for key, value in (assignment.split("=") for assignment in metadata_prop.value.split(";"))
}
break
split_assignments = config["split_assignments"]
if split_assignments is None:
for metadata_prop in model_proto.metadata_props:
if metadata_prop.key == "split_assignments":
split_assignments = metadata_prop.value
break
# TODO(jambayk): Should we allow split assignments in the model attributes too?
if not split_assignments:
raise ValueError("No split assignments found in the model metadata")

split_assignments = {
key: int(value) for key, value in (assignment.split("=") for assignment in split_assignments.split(";"))
}

# TODO(jambayk): Make this more generic, for now only assume transformers layers are split
# so depth of namespace is same for all split assignments
num_splits = len(np.unique(list(split_assignments.values())))
Expand Down
57 changes: 57 additions & 0 deletions test/unit_test/passes/onnx/test_split_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import pytest
import torch

from olive.hardware import AcceleratorSpec
from olive.model import CompositeModelHandler, HfModelHandler, ONNXModelHandler
Expand Down Expand Up @@ -141,3 +142,59 @@

# all non model outputs must be used between the splits
assert (seen_outputs - used_outputs) == model_outputs


class CustomModel(torch.nn.Module):
    """Tiny model with one input layer, two parallel branches and one output layer.

    The submodule attribute names (``before_layer``, ``layers``, ``after_layer``)
    are referenced by the split-assignment strings and tensor names used in the
    test parametrization in this file, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Linear(2 -> 4) applied first to the raw input.
        self.before_layer = torch.nn.Linear(2, 4)
        # Two independent Linear(4 -> 4) branches that both consume the same tensor.
        branches = [torch.nn.Linear(4, 4) for _ in range(2)]
        self.layers = torch.nn.ModuleList(branches)
        # Linear(4 -> 2) applied to the sum of the two branch outputs.
        self.after_layer = torch.nn.Linear(4, 2)

    def forward(self, x):
        """Run ``x`` through ``before_layer``, sum both branches, then apply ``after_layer``."""
        hidden = self.before_layer(x)
        first_branch, second_branch = self.layers[0], self.layers[1]
        # Operand order (branch 0 + branch 1) is kept as in the original expression.
        merged = first_branch(hidden) + second_branch(hidden)
        return self.after_layer(merged)


@pytest.mark.parametrize(
    ("split_assignments", "split_mid_io"),
    [
        (
            # split vertically
            "layers.0=0;layers.1=1",
            ["/before_layer/Gemm_output_0", "/layers.0/Gemm_output_0"],
        ),
        (
            # split horizontally
            "before_layer=0;layers.0=1;layers.1=1",
            ["/before_layer/Gemm_output_0"],
        ),
    ],
)
def test_split_model_split_assignments(split_assignments, split_mid_io, tmp_path):
    """Check that SplitModel honours ``split_assignments`` supplied through the pass config.

    A CustomModel is exported to ONNX, split by the pass, and the resulting two
    components are checked for their names and for the tensors crossing the
    split boundary (``split_mid_io``).
    """
    split_pass = create_pass_from_dict(
        SplitModel,
        {"split_assignments": split_assignments},
        disable_search=True,
    )

    # Export the toy model to ONNX so the pass has a model file to split.
    dummy_input = torch.randn(1, 2)
    input_model_path = tmp_path / "input_model.onnx"
    torch.onnx.export(
        CustomModel(),
        dummy_input,
        input_model_path,
        input_names=["input"],
        output_names=["output"],
    )
    input_model = ONNXModelHandler(input_model_path)

    result = split_pass.run(input_model, str(tmp_path))

    assert len(result.model_component_names) == 2
    assert result.model_component_names[0] == "split_0"
    assert result.model_component_names[1] == "split_1"

    # Expected (input_names, output_names) per component position: the first
    # split feeds the intermediate tensors into the second one.
    expected_io = {
        0: (["input"], split_mid_io),
        1: (split_mid_io, ["output"]),
    }
    for position, component in enumerate(result.model_components):
        if position not in expected_io:
            continue
        expected_inputs, expected_outputs = expected_io[position]
        assert component.io_config["input_names"] == expected_inputs
        assert component.io_config["output_names"] == expected_outputs
Loading