Skip to content

Commit

Permalink
Add CLI script exporting CLIP Toy model IREE test data
Browse files Browse the repository at this point in the history
This is required to have an easy way of exporting test data that will be
used in IREE to guard against regressions.
E.g.
```
python -m sharktank.models.clip.export_toy_text_model_iree_test_data \
  --output-path-prefix=clip_toy_text_model
```

Refactor some of the existing tests to reuse the new export logic.
  • Loading branch information
sogartar committed Dec 10, 2024
1 parent 36c0859 commit b71e906
Show file tree
Hide file tree
Showing 6 changed files with 259 additions and 33 deletions.
27 changes: 23 additions & 4 deletions sharktank/sharktank/models/clip/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
CLIPEncoderLayer as HfCLIPEncoderLayer,
CLIPEncoder as HfCLIPEncoder,
)
from os import PathLike
import torch

from ...types.theta import Theta, Dataset, torch_module_to_theta
from ...layers.configs import ClipTextConfig
from ...utils.typing import AnyPath
from .clip import ClipTextModel
from iree.turbine.aot import FxProgramsBuilder, export

Expand Down Expand Up @@ -50,9 +50,14 @@ def clip_text_model_to_dataset(model: ClipTextModel) -> Dataset:
return Dataset(properties=model.config.to_properties(), root_theta=model.theta)


def export_clip_text_model_iree_parameters(model: ClipTextModel, output_path: AnyPath):
dataset = clip_text_model_to_dataset(model)
dataset.save(output_path)


def export_clip_text_model_dataset_from_hugging_face(
model_or_name_or_path: Union[str, PathLike, transformers.CLIPTextModel],
output_path: Union[str, PathLike],
model_or_name_or_path: Union[AnyPath, transformers.CLIPTextModel],
output_path: AnyPath,
dtype: Optional[torch.dtype] = None,
):
if isinstance(model_or_name_or_path, transformers.CLIPTextModel):
Expand All @@ -67,7 +72,7 @@ def export_clip_text_model_dataset_from_hugging_face(


def export_clip_text_model_mlir(
model: Union[ClipTextModel, PathLike],
model: Union[ClipTextModel, AnyPath],
batch_sizes: list[int],
mlir_output_path: str,
):
Expand Down Expand Up @@ -99,3 +104,17 @@ def _(

output = export(fxb, import_symbolic_shape_expressions=True)
output.save_mlir(mlir_output_path)


def export_clip_text_model_to_iree(
model: ClipTextModel,
batch_sizes: list[int],
mlir_output_path: AnyPath,
parameters_output_path: AnyPath,
):
export_clip_text_model_iree_parameters(model, parameters_output_path)
export_clip_text_model_mlir(
model=parameters_output_path,
batch_sizes=batch_sizes,
mlir_output_path=mlir_output_path,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from argparse import ArgumentParser
from typing import Optional

from .testing import export_clip_toy_text_model_default_iree_test_data


def main(args: Optional[list[str]] = None):
parser = ArgumentParser(
description=(
"Export test data for toy-sized CLIP text model."
" This program MLIR, parameters sample input and expected output."
" Exports float32 and bfloat16 model variants."
" The expected output is always in float32 precision."
)
)
parser.add_argument(
"--output-path-prefix", type=str, default=f"clip_toy_text_model"
)
args = parser.parse_args(args=args)
export_clip_toy_text_model_default_iree_test_data(args.output_path_prefix)


if __name__ == "__main__":
main()
155 changes: 151 additions & 4 deletions sharktank/sharktank/models/clip/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,161 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ...layers.configs.llm_configs import ClipTextConfig
from ...types.theta import Theta
from .export import hugging_face_clip_text_model_to_theta
import functools
import torch
from os import PathLike
from typing import Union, Optional
from copy import copy
from iree.turbine.aot.params import ParameterArchiveBuilder

from ...layers.configs.llm_configs import ClipTextConfig
from .clip import ClipTextModel
from ...types.theta import Theta, Dataset
from ...types.tensors import dtype_to_serialized_short_name
from ...utils.typing import AnyPath
from ...utils.io import save_tensor_as_irpa
from .export import (
clip_text_model_to_dataset,
hugging_face_clip_text_model_to_theta,
export_clip_text_model_to_iree,
)
from ...transforms.dataset import set_float_dtype


def clip_toy_text_model_config(dtype: torch.dtype) -> ClipTextConfig:
num_attention_heads = 5
vocab_size = 11
return ClipTextConfig(
vocab_size=vocab_size,
hidden_size=13 * num_attention_heads,
intermediate_size=7,
projection_dim=3,
num_attention_heads=num_attention_heads,
max_position_embeddings=17,
layer_norm_eps=1e-4,
num_hidden_layers=2,
bos_token_id=vocab_size - 2,
eos_token_id=vocab_size - 1,
dtype=dtype,
)


def export_clip_toy_text_model_default_iree_test_data(output_path_prefix: str):
# We want to always export the same without interfering with RNG for the rest of
# the program.
rng_state = torch.get_rng_state()
torch.random.manual_seed(12345)

reference_dtype = torch.float32
target_dtypes = [torch.float32, torch.bfloat16]
target_iree_parameters_output_paths = []
target_mlir_output_paths = []
batch_size = 4
for dtype in target_dtypes:
prefix = f"{output_path_prefix}_{dtype_to_serialized_short_name(dtype)}"
target_iree_parameters_output_paths.append(f"{prefix}_parameters.irpa")
target_mlir_output_paths.append(f"{prefix}.mlir")
call_prefix = f"{output_path_prefix}_forward_bs{batch_size}"
input_ids_output_path = f"{call_prefix}_arg0_input_ids.irpa"
expected_last_hidden_state_output_path = (
f"{call_prefix}_expected_result0_last_hidden_state_"
f"{dtype_to_serialized_short_name(reference_dtype)}.irpa"
)
export_clip_toy_text_model_iree_test_data(
reference_dtype=reference_dtype,
target_dtypes=target_dtypes,
batch_size=batch_size,
input_ids_output_path=input_ids_output_path,
expected_last_hidden_state_output_path=expected_last_hidden_state_output_path,
target_iree_parameters_output_paths=target_iree_parameters_output_paths,
target_mlir_output_paths=target_mlir_output_paths,
)

torch.set_rng_state(rng_state)


def export_clip_toy_text_model_iree_test_data(
reference_dtype: torch.dtype,
target_dtypes: list[torch.dtype],
batch_size: int,
target_iree_parameters_output_paths: list[AnyPath],
target_mlir_output_paths: list[AnyPath],
input_ids_output_path: AnyPath,
expected_last_hidden_state_output_path: AnyPath,
):
reference_config = clip_toy_text_model_config(reference_dtype)
input_ids = make_random_input_token_sequences(
batch_size=batch_size, config=reference_config
)
reference_theta = make_clip_text_model_random_theta(reference_config)
reference_model = ClipTextModel(theta=reference_theta, config=reference_config)
for i, (
target_dtype,
target_iree_parameters_output_path,
target_mlir_output_path,
) in enumerate(
zip(
target_dtypes,
target_iree_parameters_output_paths,
target_mlir_output_paths,
strict=True,
)
):
export_clip_text_model_iree_test_data(
reference_model=reference_model,
target_dtype=target_dtype,
input_ids=input_ids,
target_iree_parameters_output_path=target_iree_parameters_output_path,
target_mlir_output_path=target_mlir_output_path,
input_ids_output_path=input_ids_output_path if i == 0 else None,
expected_last_hidden_state_output_path=expected_last_hidden_state_output_path
if i == 0
else None,
)


def export_clip_text_model_iree_test_data(
reference_model: ClipTextModel,
target_dtype: torch.dtype,
input_ids: torch.LongTensor,
target_mlir_output_path: AnyPath,
target_iree_parameters_output_path: AnyPath,
input_ids_output_path: Optional[AnyPath] = None,
expected_last_hidden_state_output_path: Optional[AnyPath] = None,
):
batch_size = input_ids.shape[0]
reference_dataset = clip_text_model_to_dataset(reference_model)
target_config = copy(reference_model.config)
target_config.dtype = target_dtype
target_dataset = Dataset(
root_theta=reference_dataset.root_theta.transform(
functools.partial(set_float_dtype, dtype=torch.bfloat16)
),
properties=target_config.to_properties(),
)
target_model = ClipTextModel(theta=target_dataset.root_theta, config=target_config)
export_clip_text_model_to_iree(
target_model,
batch_sizes=[batch_size],
mlir_output_path=target_mlir_output_path,
parameters_output_path=target_iree_parameters_output_path,
)

if input_ids_output_path is not None:
save_tensor_as_irpa(input_ids, input_ids_output_path)

if expected_last_hidden_state_output_path is None:
return

expected_last_hidden_state = reference_model(input_ids=input_ids)[
"last_hidden_state"
]
save_tensor_as_irpa(
expected_last_hidden_state, expected_last_hidden_state_output_path
)


def make_clip_text_model_random_theta(config: ClipTextConfig) -> Theta:
from transformers import CLIPTextConfig as HfCLIPTextConfig
from transformers import CLIPTextModel as HfCLIPTextModel

hf_config = config.to_hugging_face_clip_text_model_config()
Expand Down
26 changes: 23 additions & 3 deletions sharktank/sharktank/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from pathlib import Path
import torch

from iree.turbine.aot import (
ParameterArchiveBuilder,
)
from iree.turbine.aot import ParameterArchiveBuilder, ParameterArchive

from .typing import AnyPath


class ShardedArchiveBuilder(ParameterArchiveBuilder):
Expand Down Expand Up @@ -49,3 +50,22 @@ def path_for_rank(path: Path, rank: int):
/tmp/foobar.rank0.irpa
"""
return path.with_suffix(f".rank{rank}{path.suffix}")


def save_tensor_as_irpa(tensor: torch.Tensor, path: AnyPath):
"""Save a single tensor into an IRPA file."""
param_builder = ParameterArchiveBuilder()
param_builder.add_tensor("", tensor)
param_builder.save(path)


def load_irpa_as_tensor(tensor: torch.Tensor, path: AnyPath, **kwargs):
"""Load a tensor form an IRPA file that holds only one tensor."""
params = ParameterArchive(path, **kwargs)
items = params.items()
if len(items) != 1:
raise ValueError(
f'Too many items {len(items)} in IRPA file "{path}".'
" Only a single tensor was expected."
)
return items[0][1].as_tensor()
10 changes: 10 additions & 0 deletions sharktank/sharktank/utils/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Union
from os import PathLike

AnyPath = Union[str, PathLike]
44 changes: 22 additions & 22 deletions sharktank/tests/models/clip/clip_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from sharktank.models.clip.testing import (
make_random_input_token_sequences,
make_clip_text_model_random_theta,
export_clip_text_model_iree_test_data,
)
from sharktank.models.clip import (
ClipAttention,
Expand Down Expand Up @@ -96,6 +97,11 @@ def testSmokeExportLargeF32FromHuggingFace(self):
huggingface_repo_id, output_path
)

def testSmokeExportToyIreeTestData(self):
from sharktank.models.clip.export_toy_text_model_iree_test_data import main

main([f"--output-path-prefix={self.path_prefix}clip_toy_text_model"])

@with_clip_data
def testCompareLargeIreeF32AgainstTorchEagerF32(self):
self.runTestCompareIreeAgainstPretrainedTorchEager(
Expand Down Expand Up @@ -147,30 +153,24 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens(
f"{self.path_prefix}{file_artifact_prefix_name}_{target_dtype_name}"
)

target_config = copy(reference_model.config)
target_config.dtype = target_dtype
reference_dataset = clip_text_model_to_dataset(reference_model)
target_dataset = Dataset(
root_theta=reference_dataset.root_theta.transform(
functools.partial(set_float_dtype, dtype=target_config.dtype)
),
properties=target_config.to_properties(),
)

parameters_path = f"{target_model_path_prefix}.irpa"
if not self.caching or not os.path.exists(parameters_path):
target_dataset.save(parameters_path)

dataset = Dataset.load(parameters_path)
target_config = ClipTextConfig.from_properties(dataset.properties)
input_args = OrderedDict([("input_ids", input_ids)])
batch_size = input_ids.shape[0]

mlir_path = f"{target_model_path_prefix}.mlir"
if not self.caching or not os.path.exists(mlir_path):
export_clip_text_model_mlir(
parameters_path, batch_sizes=[batch_size], mlir_output_path=mlir_path

if (
not self.caching
or not os.path.exists(mlir_path)
or not os.path.exists(parameters_path)
):
export_clip_text_model_iree_test_data(
reference_model=reference_model,
target_dtype=target_dtype,
input_ids=input_ids,
target_mlir_output_path=mlir_path,
target_iree_parameters_output_path=parameters_path,
)

iree_module_path = f"{target_model_path_prefix}.vmfb"
if not self.caching or not os.path.exists(iree_module_path):
iree.compiler.compile_file(
Expand Down Expand Up @@ -211,11 +211,11 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens(
for i in range(len(expected_outputs))
]

actual_last_hidden_states = actual_outputs[0]
expected_last_hidden_states = expected_outputs[0]
actual_last_hidden_state = actual_outputs[0]
expected_last_hidden_state = expected_outputs[0]

assert_text_encoder_state_close(
actual_last_hidden_states, expected_last_hidden_states, atol
actual_last_hidden_state, expected_last_hidden_state, atol
)

def runTestCompareRandomModelIreeAgainstTorch(
Expand Down

0 comments on commit b71e906

Please sign in to comment.