Add CLI script exporting CLIP Toy model IREE test data #672

Open · wants to merge 5 commits into base: main · Changes from 1 commit
27 changes: 23 additions & 4 deletions sharktank/sharktank/models/clip/export.py
@@ -11,11 +11,11 @@
CLIPEncoderLayer as HfCLIPEncoderLayer,
CLIPEncoder as HfCLIPEncoder,
)
-from os import PathLike
import torch

from ...types.theta import Theta, Dataset, torch_module_to_theta
from ...layers.configs import ClipTextConfig
+from ...utils.typing import AnyPath
from .clip import ClipTextModel
from iree.turbine.aot import FxProgramsBuilder, export

@@ -50,9 +50,14 @@ def clip_text_model_to_dataset(model: ClipTextModel) -> Dataset:
return Dataset(properties=model.config.to_properties(), root_theta=model.theta)


def export_clip_text_model_iree_parameters(model: ClipTextModel, output_path: AnyPath):
dataset = clip_text_model_to_dataset(model)
dataset.save(output_path)


def export_clip_text_model_dataset_from_hugging_face(
-    model_or_name_or_path: Union[str, PathLike, transformers.CLIPTextModel],
-    output_path: Union[str, PathLike],
+    model_or_name_or_path: Union[AnyPath, transformers.CLIPTextModel],
+    output_path: AnyPath,
dtype: Optional[torch.dtype] = None,
):
if isinstance(model_or_name_or_path, transformers.CLIPTextModel):
@@ -67,7 +72,7 @@ def export_clip_text_model_dataset_from_hugging_face(


def export_clip_text_model_mlir(
-    model: Union[ClipTextModel, PathLike],
+    model: Union[ClipTextModel, AnyPath],
    batch_sizes: list[int],
    mlir_output_path: str,
):

Contributor: Leave this as PathLike and remove the AnyPath. Users should just make sure to convert the string before invoking. (Same as above)

Contributor Author: I removed AnyPath.
@@ -99,3 +104,17 @@ def _(

output = export(fxb, import_symbolic_shape_expressions=True)
output.save_mlir(mlir_output_path)


def export_clip_text_model_to_iree(
model: ClipTextModel,
batch_sizes: list[int],
mlir_output_path: AnyPath,
parameters_output_path: AnyPath,
):
export_clip_text_model_iree_parameters(model, parameters_output_path)
export_clip_text_model_mlir(
model=parameters_output_path,
batch_sizes=batch_sizes,
mlir_output_path=mlir_output_path,
)
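For orientation, here is a minimal sketch of how the new helper composes with the toy-model helpers added in testing.py below. The import paths and output paths are illustrative, not taken from the PR:

import torch

from sharktank.models.clip import ClipTextModel
from sharktank.models.clip.export import export_clip_text_model_to_iree
from sharktank.models.clip.testing import (
    clip_toy_text_model_config,
    make_clip_text_model_random_theta,
)

# Build a toy eager model, then export IRPA parameters and MLIR in one call.
config = clip_toy_text_model_config(torch.float32)
model = ClipTextModel(theta=make_clip_text_model_random_theta(config), config=config)
export_clip_text_model_to_iree(
    model,
    batch_sizes=[4],
    mlir_output_path="/tmp/clip_toy_text_model.mlir",
    parameters_output_path="/tmp/clip_toy_text_model.irpa",
)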
30 changes: 30 additions & 0 deletions sharktank/sharktank/models/clip/export_toy_text_model_iree_test_data.py
@@ -0,0 +1,30 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from argparse import ArgumentParser
from typing import Optional

from .testing import export_clip_toy_text_model_default_iree_test_data


def main(args: Optional[list[str]] = None):
parser = ArgumentParser(
description=(
"Export test data for toy-sized CLIP text model."
" This program MLIR, parameters sample input and expected output."
" Exports float32 and bfloat16 model variants."
" The expected output is always in float32 precision."
)
)
parser.add_argument(
"--output-path-prefix", type=str, default=f"clip_toy_text_model"
)
args = parser.parse_args(args=args)
export_clip_toy_text_model_default_iree_test_data(args.output_path_prefix)


if __name__ == "__main__":
main()
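Assuming the module lands at the path the smoke test below imports, the script can be invoked programmatically or from the command line; the output prefix here is illustrative:

from sharktank.models.clip.export_toy_text_model_iree_test_data import main

# Equivalent to:
#   python -m sharktank.models.clip.export_toy_text_model_iree_test_data \
#       --output-path-prefix=/tmp/clip_toy_text_model
main(["--output-path-prefix=/tmp/clip_toy_text_model"])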
155 changes: 151 additions & 4 deletions sharktank/sharktank/models/clip/testing.py
@@ -4,14 +4,161 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

-from ...layers.configs.llm_configs import ClipTextConfig
-from ...types.theta import Theta
-from .export import hugging_face_clip_text_model_to_theta
import functools
import torch
from os import PathLike
from typing import Union, Optional
from copy import copy
from iree.turbine.aot.params import ParameterArchiveBuilder

from ...layers.configs.llm_configs import ClipTextConfig
from .clip import ClipTextModel
from ...types.theta import Theta, Dataset
from ...types.tensors import dtype_to_serialized_short_name
from ...utils.typing import AnyPath
from ...utils.io import save_tensor_as_irpa
from .export import (
clip_text_model_to_dataset,
hugging_face_clip_text_model_to_theta,
export_clip_text_model_to_iree,
)
from ...transforms.dataset import set_float_dtype


def clip_toy_text_model_config(dtype: torch.dtype) -> ClipTextConfig:
num_attention_heads = 5
vocab_size = 11
return ClipTextConfig(
vocab_size=vocab_size,
hidden_size=13 * num_attention_heads,
intermediate_size=7,
projection_dim=3,
num_attention_heads=num_attention_heads,
max_position_embeddings=17,
layer_norm_eps=1e-4,
num_hidden_layers=2,
bos_token_id=vocab_size - 2,
eos_token_id=vocab_size - 1,
dtype=dtype,
)
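The toy sizes are deliberately small and mutually distinct so that shape mix-ups fail loudly; hidden_size is kept a multiple of num_attention_heads (13 * 5 = 65) so the attention head dimension divides evenly. A rough sketch of the usual consumption pattern, using the helpers defined later in this file:

import torch

config = clip_toy_text_model_config(torch.float32)
assert config.hidden_size == 13 * config.num_attention_heads  # 65
theta = make_clip_text_model_random_theta(config)  # random Hugging Face weights
model = ClipTextModel(theta=theta, config=config)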


def export_clip_toy_text_model_default_iree_test_data(output_path_prefix: str):
    # We want to always export the same data, without interfering with the RNG
    # state of the rest of the program.
rng_state = torch.get_rng_state()
torch.random.manual_seed(12345)

reference_dtype = torch.float32
target_dtypes = [torch.float32, torch.bfloat16]
target_iree_parameters_output_paths = []
target_mlir_output_paths = []
batch_size = 4
for dtype in target_dtypes:
prefix = f"{output_path_prefix}_{dtype_to_serialized_short_name(dtype)}"
target_iree_parameters_output_paths.append(f"{prefix}_parameters.irpa")
target_mlir_output_paths.append(f"{prefix}.mlir")
call_prefix = f"{output_path_prefix}_forward_bs{batch_size}"
input_ids_output_path = f"{call_prefix}_arg0_input_ids.irpa"
expected_last_hidden_state_output_path = (
f"{call_prefix}_expected_result0_last_hidden_state_"
f"{dtype_to_serialized_short_name(reference_dtype)}.irpa"
)
export_clip_toy_text_model_iree_test_data(
reference_dtype=reference_dtype,
target_dtypes=target_dtypes,
batch_size=batch_size,
input_ids_output_path=input_ids_output_path,
expected_last_hidden_state_output_path=expected_last_hidden_state_output_path,
target_iree_parameters_output_paths=target_iree_parameters_output_paths,
target_mlir_output_paths=target_mlir_output_paths,
)

torch.set_rng_state(rng_state)
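For the default prefix, and assuming dtype_to_serialized_short_name maps torch.float32 and torch.bfloat16 to short names such as "f32" and "bf16" (an assumption; the mapping is defined elsewhere in sharktank), the exporter would write files along these lines:

# clip_toy_text_model_f32_parameters.irpa, clip_toy_text_model_f32.mlir
# clip_toy_text_model_bf16_parameters.irpa, clip_toy_text_model_bf16.mlir
# clip_toy_text_model_forward_bs4_arg0_input_ids.irpa
# clip_toy_text_model_forward_bs4_expected_result0_last_hidden_state_f32.irpa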


def export_clip_toy_text_model_iree_test_data(
reference_dtype: torch.dtype,
target_dtypes: list[torch.dtype],
batch_size: int,
target_iree_parameters_output_paths: list[AnyPath],
target_mlir_output_paths: list[AnyPath],
input_ids_output_path: AnyPath,
expected_last_hidden_state_output_path: AnyPath,
):
reference_config = clip_toy_text_model_config(reference_dtype)
input_ids = make_random_input_token_sequences(
batch_size=batch_size, config=reference_config
)
reference_theta = make_clip_text_model_random_theta(reference_config)
reference_model = ClipTextModel(theta=reference_theta, config=reference_config)
for i, (
target_dtype,
target_iree_parameters_output_path,
target_mlir_output_path,
) in enumerate(
zip(
target_dtypes,
target_iree_parameters_output_paths,
target_mlir_output_paths,
strict=True,
)
):
export_clip_text_model_iree_test_data(
reference_model=reference_model,
target_dtype=target_dtype,
input_ids=input_ids,
target_iree_parameters_output_path=target_iree_parameters_output_path,
target_mlir_output_path=target_mlir_output_path,
input_ids_output_path=input_ids_output_path if i == 0 else None,
            expected_last_hidden_state_output_path=expected_last_hidden_state_output_path
            if i == 0
            else None,
        )

Contributor (on the input_ids_output_path argument): Split this across lines or put it in an if conditional. The , between makes the behavior extremely ambiguous.

Contributor Author: Done.

Contributor (on the expected_last_hidden_state_output_path argument): Same comment as above. This looks like it should just be in an if condition rather than a one-liner.
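A sketch of one way to address both review comments; the local variable name is illustrative rather than taken from the follow-up commit:

        is_first_target = i == 0
        export_clip_text_model_iree_test_data(
            reference_model=reference_model,
            target_dtype=target_dtype,
            input_ids=input_ids,
            target_iree_parameters_output_path=target_iree_parameters_output_path,
            target_mlir_output_path=target_mlir_output_path,
            input_ids_output_path=input_ids_output_path if is_first_target else None,
            expected_last_hidden_state_output_path=(
                expected_last_hidden_state_output_path if is_first_target else None
            ),
        )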


def export_clip_text_model_iree_test_data(
reference_model: ClipTextModel,
target_dtype: torch.dtype,
input_ids: torch.LongTensor,
target_mlir_output_path: AnyPath,
target_iree_parameters_output_path: AnyPath,
input_ids_output_path: Optional[AnyPath] = None,
expected_last_hidden_state_output_path: Optional[AnyPath] = None,
):
batch_size = input_ids.shape[0]
reference_dataset = clip_text_model_to_dataset(reference_model)
target_config = copy(reference_model.config)
target_config.dtype = target_dtype
target_dataset = Dataset(
root_theta=reference_dataset.root_theta.transform(
            functools.partial(set_float_dtype, dtype=target_dtype)
),
properties=target_config.to_properties(),
)
target_model = ClipTextModel(theta=target_dataset.root_theta, config=target_config)
export_clip_text_model_to_iree(
target_model,
batch_sizes=[batch_size],
mlir_output_path=target_mlir_output_path,
parameters_output_path=target_iree_parameters_output_path,
)

if input_ids_output_path is not None:
save_tensor_as_irpa(input_ids, input_ids_output_path)

if expected_last_hidden_state_output_path is None:
return

expected_last_hidden_state = reference_model(input_ids=input_ids)[
"last_hidden_state"
]
save_tensor_as_irpa(
expected_last_hidden_state, expected_last_hidden_state_output_path
)


def make_clip_text_model_random_theta(config: ClipTextConfig) -> Theta:
from transformers import CLIPTextConfig as HfCLIPTextConfig
from transformers import CLIPTextModel as HfCLIPTextModel

hf_config = config.to_hugging_face_clip_text_model_config()
26 changes: 23 additions & 3 deletions sharktank/sharktank/utils/io.py
@@ -5,10 +5,11 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from pathlib import Path
import torch

-from iree.turbine.aot import (
-    ParameterArchiveBuilder,
-)
+from iree.turbine.aot import ParameterArchiveBuilder, ParameterArchive

from .typing import AnyPath


class ShardedArchiveBuilder(ParameterArchiveBuilder):
@@ -49,3 +50,22 @@ def path_for_rank(path: Path, rank: int):
/tmp/foobar.rank0.irpa
"""
return path.with_suffix(f".rank{rank}{path.suffix}")


def save_tensor_as_irpa(tensor: torch.Tensor, path: AnyPath):
"""Save a single tensor into an IRPA file."""
param_builder = ParameterArchiveBuilder()
param_builder.add_tensor("", tensor)
param_builder.save(path)


def load_irpa_as_tensor(path: AnyPath, **kwargs) -> torch.Tensor:
"""Load a tensor form an IRPA file that holds only one tensor."""
params = ParameterArchive(path, **kwargs)
items = params.items()
if len(items) != 1:
raise ValueError(
f'Too many items {len(items)} in IRPA file "{path}".'
" Only a single tensor was expected."
)
return items[0][1].as_tensor()
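A quick round-trip sketch of the two helpers (paths illustrative; assumes the single-tensor loader signature above):

import torch

from sharktank.utils.io import save_tensor_as_irpa, load_irpa_as_tensor

t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
save_tensor_as_irpa(t, "/tmp/example.irpa")
loaded = load_irpa_as_tensor("/tmp/example.irpa")
assert torch.equal(t, loaded)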
10 changes: 10 additions & 0 deletions sharktank/sharktank/utils/typing.py
@@ -0,0 +1,10 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Union
from os import PathLike

AnyPath = Union[str, PathLike]
44 changes: 22 additions & 22 deletions sharktank/tests/models/clip/clip_test.py
@@ -59,6 +59,7 @@
from sharktank.models.clip.testing import (
make_random_input_token_sequences,
make_clip_text_model_random_theta,
export_clip_text_model_iree_test_data,
)
from sharktank.models.clip import (
ClipAttention,
@@ -96,6 +97,11 @@ def testSmokeExportLargeF32FromHuggingFace(self):
huggingface_repo_id, output_path
)

def testSmokeExportToyIreeTestData(self):
from sharktank.models.clip.export_toy_text_model_iree_test_data import main

main([f"--output-path-prefix={self.path_prefix}clip_toy_text_model"])
Contributor: Use Python's path library to operate on this instead of string manipulations.

Contributor Author: Done. I changed it to use a directory as a prefix.


@with_clip_data
def testCompareLargeIreeF32AgainstTorchEagerF32(self):
self.runTestCompareIreeAgainstPretrainedTorchEager(
@@ -147,30 +153,24 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens(
f"{self.path_prefix}{file_artifact_prefix_name}_{target_dtype_name}"
)

-        target_config = copy(reference_model.config)
-        target_config.dtype = target_dtype
-        reference_dataset = clip_text_model_to_dataset(reference_model)
-        target_dataset = Dataset(
-            root_theta=reference_dataset.root_theta.transform(
-                functools.partial(set_float_dtype, dtype=target_config.dtype)
-            ),
-            properties=target_config.to_properties(),
-        )

        parameters_path = f"{target_model_path_prefix}.irpa"
-        if not self.caching or not os.path.exists(parameters_path):
-            target_dataset.save(parameters_path)

-        dataset = Dataset.load(parameters_path)
-        target_config = ClipTextConfig.from_properties(dataset.properties)
        input_args = OrderedDict([("input_ids", input_ids)])
        batch_size = input_ids.shape[0]

        mlir_path = f"{target_model_path_prefix}.mlir"
-        if not self.caching or not os.path.exists(mlir_path):
-            export_clip_text_model_mlir(
-                parameters_path, batch_sizes=[batch_size], mlir_output_path=mlir_path
+        if (
+            not self.caching
+            or not os.path.exists(mlir_path)
+            or not os.path.exists(parameters_path)
+        ):
+            export_clip_text_model_iree_test_data(
+                reference_model=reference_model,
+                target_dtype=target_dtype,
+                input_ids=input_ids,
+                target_mlir_output_path=mlir_path,
+                target_iree_parameters_output_path=parameters_path,
            )

Contributor (on the caching conditional): This conditional feels weird / wrong. Is it checking for defaults? Also it feels particularly weird for a test. Reconsider why this is needed.

Contributor Author: I removed the caching. It is useful to speed up iteration, but it does add clutter.

iree_module_path = f"{target_model_path_prefix}.vmfb"
if not self.caching or not os.path.exists(iree_module_path):
iree.compiler.compile_file(
@@ -211,11 +211,11 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens(
for i in range(len(expected_outputs))
]

-        actual_last_hidden_states = actual_outputs[0]
-        expected_last_hidden_states = expected_outputs[0]
+        actual_last_hidden_state = actual_outputs[0]
+        expected_last_hidden_state = expected_outputs[0]

assert_text_encoder_state_close(
-            actual_last_hidden_states, expected_last_hidden_states, atol
+            actual_last_hidden_state, expected_last_hidden_state, atol
)

def runTestCompareRandomModelIreeAgainstTorch(