-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add CLI script exporting CLIP Toy model IREE test data #672
base: main
Are you sure you want to change the base?
Changes from 1 commit
b71e906
6e37ae9
2dfb5de
8067db9
728cbe2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from argparse import ArgumentParser | ||
from typing import Optional | ||
|
||
from .testing import export_clip_toy_text_model_default_iree_test_data | ||
|
||
|
||
def main(args: Optional[list[str]] = None):
    """Command-line entry point that exports toy CLIP text model test data.

    Parses the command line and forwards the chosen output path prefix to
    export_clip_toy_text_model_default_iree_test_data.

    Args:
        args: Command-line arguments to parse. When None, argparse reads
            them from sys.argv.
    """
    parser = ArgumentParser(
        description=(
            "Export test data for toy-sized CLIP text model."
            " Exports program MLIR, parameters, sample input and expected output."
            " Exports float32 and bfloat16 model variants."
            " The expected output is always in float32 precision."
        )
    )
    parser.add_argument(
        "--output-path-prefix",
        type=str,
        default="clip_toy_text_model",
        help="Prefix prepended to the path of every exported file.",
    )
    # Keep the raw argument list and the parsed namespace under separate names.
    parsed_args = parser.parse_args(args=args)
    export_clip_toy_text_model_default_iree_test_data(parsed_args.output_path_prefix)


# Run the exporter only when this file is executed as a script.
if __name__ == "__main__":
    main()
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,14 +4,161 @@ | |
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from ...layers.configs.llm_configs import ClipTextConfig | ||
from ...types.theta import Theta | ||
from .export import hugging_face_clip_text_model_to_theta | ||
import functools | ||
import torch | ||
from os import PathLike | ||
from typing import Union, Optional | ||
from copy import copy | ||
from iree.turbine.aot.params import ParameterArchiveBuilder | ||
|
||
from ...layers.configs.llm_configs import ClipTextConfig | ||
from .clip import ClipTextModel | ||
from ...types.theta import Theta, Dataset | ||
from ...types.tensors import dtype_to_serialized_short_name | ||
from ...utils.typing import AnyPath | ||
from ...utils.io import save_tensor_as_irpa | ||
from .export import ( | ||
clip_text_model_to_dataset, | ||
hugging_face_clip_text_model_to_theta, | ||
export_clip_text_model_to_iree, | ||
) | ||
from ...transforms.dataset import set_float_dtype | ||
|
||
|
||
def clip_toy_text_model_config(dtype: torch.dtype) -> ClipTextConfig:
    """Return the configuration of the toy-sized CLIP text model.

    Every size is a small hard-coded constant. The hidden size is 13 per
    attention head, and the last two ids of the vocabulary serve as the
    BOS and EOS tokens.

    Args:
        dtype: Value stored in the ``dtype`` field of the configuration.
    """
    head_count = 5
    token_count = 11
    bos_id = token_count - 2
    eos_id = token_count - 1
    return ClipTextConfig(
        dtype=dtype,
        vocab_size=token_count,
        bos_token_id=bos_id,
        eos_token_id=eos_id,
        num_attention_heads=head_count,
        hidden_size=13 * head_count,
        intermediate_size=7,
        projection_dim=3,
        num_hidden_layers=2,
        max_position_embeddings=17,
        layer_norm_eps=1e-4,
    )
|
||
|
||
def export_clip_toy_text_model_default_iree_test_data(output_path_prefix: str):
    """Export the default toy CLIP text model test data with a fixed seed.

    The reference model is float32 and the exported target variants are
    float32 and bfloat16, with a batch size of 4. All output paths are
    derived from ``output_path_prefix``:

    * ``<prefix>_<dtype>_parameters.irpa`` and ``<prefix>_<dtype>.mlir`` for
      each target dtype,
    * ``<prefix>_forward_bs4_arg0_input_ids.irpa`` for the sample input,
    * ``<prefix>_forward_bs4_expected_result0_last_hidden_state_<dtype>.irpa``
      for the expected output in the reference dtype.

    Args:
        output_path_prefix: String prefix of every path that is written.
    """
    # We want to always export the same without interfering with RNG for the rest of
    # the program.
    rng_state = torch.get_rng_state()
    try:
        torch.random.manual_seed(12345)

        reference_dtype = torch.float32
        target_dtypes = [torch.float32, torch.bfloat16]
        target_iree_parameters_output_paths = []
        target_mlir_output_paths = []
        batch_size = 4
        for dtype in target_dtypes:
            prefix = f"{output_path_prefix}_{dtype_to_serialized_short_name(dtype)}"
            target_iree_parameters_output_paths.append(f"{prefix}_parameters.irpa")
            target_mlir_output_paths.append(f"{prefix}.mlir")
        call_prefix = f"{output_path_prefix}_forward_bs{batch_size}"
        input_ids_output_path = f"{call_prefix}_arg0_input_ids.irpa"
        expected_last_hidden_state_output_path = (
            f"{call_prefix}_expected_result0_last_hidden_state_"
            f"{dtype_to_serialized_short_name(reference_dtype)}.irpa"
        )
        export_clip_toy_text_model_iree_test_data(
            reference_dtype=reference_dtype,
            target_dtypes=target_dtypes,
            batch_size=batch_size,
            input_ids_output_path=input_ids_output_path,
            expected_last_hidden_state_output_path=expected_last_hidden_state_output_path,
            target_iree_parameters_output_paths=target_iree_parameters_output_paths,
            target_mlir_output_paths=target_mlir_output_paths,
        )
    finally:
        # Restore the caller's RNG state even when the export fails, so a
        # failed export does not leave the process seeded with 12345.
        torch.set_rng_state(rng_state)
|
||
|
||
def export_clip_toy_text_model_iree_test_data(
    reference_dtype: torch.dtype,
    target_dtypes: list[torch.dtype],
    batch_size: int,
    target_iree_parameters_output_paths: list[AnyPath],
    target_mlir_output_paths: list[AnyPath],
    input_ids_output_path: AnyPath,
    expected_last_hidden_state_output_path: AnyPath,
):
    """Export a randomly initialized toy CLIP text model for several dtypes.

    One reference model and one batch of random input token sequences are
    created from the toy configuration in ``reference_dtype``. Each target
    dtype is then exported from that reference model to its own parameters
    and MLIR paths.

    Args:
        reference_dtype: Dtype of the reference model configuration.
        target_dtypes: Dtypes of the variants to export.
        batch_size: Number of random input token sequences to generate.
        target_iree_parameters_output_paths: Parameters output path for each
            entry of ``target_dtypes``.
        target_mlir_output_paths: MLIR output path for each entry of
            ``target_dtypes``.
        input_ids_output_path: Where the sample input ids are saved.
        expected_last_hidden_state_output_path: Where the expected last
            hidden state is saved.

    Raises:
        ValueError: From ``zip(..., strict=True)`` when the three per-target
            lists differ in length.
    """
    config = clip_toy_text_model_config(reference_dtype)
    input_ids = make_random_input_token_sequences(
        batch_size=batch_size, config=config
    )
    theta = make_clip_text_model_random_theta(config)
    model = ClipTextModel(theta=theta, config=config)

    per_target = zip(
        target_dtypes,
        target_iree_parameters_output_paths,
        target_mlir_output_paths,
        strict=True,
    )
    for index, (dtype, parameters_path, mlir_path) in enumerate(per_target):
        # The input and the expected output are the same for every target,
        # so they are written only while exporting the first one.
        if index == 0:
            ids_path = input_ids_output_path
            expected_path = expected_last_hidden_state_output_path
        else:
            ids_path = None
            expected_path = None
        export_clip_text_model_iree_test_data(
            reference_model=model,
            target_dtype=dtype,
            input_ids=input_ids,
            target_iree_parameters_output_path=parameters_path,
            target_mlir_output_path=mlir_path,
            input_ids_output_path=ids_path,
            expected_last_hidden_state_output_path=expected_path,
        )
|
||
|
||
def export_clip_text_model_iree_test_data(
    reference_model: ClipTextModel,
    target_dtype: torch.dtype,
    input_ids: torch.LongTensor,
    target_mlir_output_path: AnyPath,
    target_iree_parameters_output_path: AnyPath,
    input_ids_output_path: Optional[AnyPath] = None,
    expected_last_hidden_state_output_path: Optional[AnyPath] = None,
):
    """Export a ``target_dtype`` variant of a reference CLIP text model.

    The reference model's parameters are converted with ``set_float_dtype``
    to ``target_dtype``, and the resulting model is exported to IREE as MLIR
    plus a parameters file. Optionally the input ids and the reference
    model's ``last_hidden_state`` for them are saved as well.

    Args:
        reference_model: Model whose parameters and config are the source of
            the exported variant. It also produces the expected output.
        target_dtype: Dtype of the exported variant.
        input_ids: Sample input. Its first dimension is used as the exported
            batch size.
        target_mlir_output_path: Where the exported MLIR is written.
        target_iree_parameters_output_path: Where the exported parameters
            are written.
        input_ids_output_path: If given, ``input_ids`` is saved there.
        expected_last_hidden_state_output_path: If given, the reference
            model is run on ``input_ids`` and its ``last_hidden_state`` is
            saved there.
    """
    batch_size = input_ids.shape[0]
    reference_dataset = clip_text_model_to_dataset(reference_model)
    # Shallow-copy the config so the reference model keeps its own dtype.
    target_config = copy(reference_model.config)
    target_config.dtype = target_dtype
    target_dataset = Dataset(
        root_theta=reference_dataset.root_theta.transform(
            # Convert to the requested dtype. This was hard-coded to
            # torch.bfloat16, which also exported the float32 variant with
            # bfloat16 parameters.
            functools.partial(set_float_dtype, dtype=target_dtype)
        ),
        properties=target_config.to_properties(),
    )
    target_model = ClipTextModel(theta=target_dataset.root_theta, config=target_config)
    export_clip_text_model_to_iree(
        target_model,
        batch_sizes=[batch_size],
        mlir_output_path=target_mlir_output_path,
        parameters_output_path=target_iree_parameters_output_path,
    )

    if input_ids_output_path is not None:
        save_tensor_as_irpa(input_ids, input_ids_output_path)

    if expected_last_hidden_state_output_path is None:
        return

    # The expected output always comes from the reference model, not from
    # the converted target model.
    expected_last_hidden_state = reference_model(input_ids=input_ids)[
        "last_hidden_state"
    ]
    save_tensor_as_irpa(
        expected_last_hidden_state, expected_last_hidden_state_output_path
    )
|
||
|
||
def make_clip_text_model_random_theta(config: ClipTextConfig) -> Theta: | ||
from transformers import CLIPTextConfig as HfCLIPTextConfig | ||
from transformers import CLIPTextModel as HfCLIPTextModel | ||
|
||
hf_config = config.to_hugging_face_clip_text_model_config() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from typing import Union | ||
from os import PathLike | ||
|
||
# Type alias for values that are either a plain string or an os.PathLike
# object.
AnyPath = Union[str, PathLike]
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -59,6 +59,7 @@ | |
from sharktank.models.clip.testing import ( | ||
make_random_input_token_sequences, | ||
make_clip_text_model_random_theta, | ||
export_clip_text_model_iree_test_data, | ||
) | ||
from sharktank.models.clip import ( | ||
ClipAttention, | ||
|
@@ -96,6 +97,11 @@ def testSmokeExportLargeF32FromHuggingFace(self): | |
huggingface_repo_id, output_path | ||
) | ||
|
||
def testSmokeExportToyIreeTestData(self):
    """Smoke test: run the toy CLIP test data export CLI entry point."""
    from sharktank.models.clip.export_toy_text_model_iree_test_data import (
        main as export_main,
    )

    # Exported files are placed under this test's path prefix.
    output_path_prefix = f"{self.path_prefix}clip_toy_text_model"
    export_main([f"--output-path-prefix={output_path_prefix}"])
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Use Python's path library (`pathlib`) to operate on this instead of string manipulation. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. I changed it to use a directory as a prefix. |
||
|
||
@with_clip_data | ||
def testCompareLargeIreeF32AgainstTorchEagerF32(self): | ||
self.runTestCompareIreeAgainstPretrainedTorchEager( | ||
|
@@ -147,30 +153,24 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens( | |
f"{self.path_prefix}{file_artifact_prefix_name}_{target_dtype_name}" | ||
) | ||
|
||
target_config = copy(reference_model.config) | ||
target_config.dtype = target_dtype | ||
reference_dataset = clip_text_model_to_dataset(reference_model) | ||
target_dataset = Dataset( | ||
root_theta=reference_dataset.root_theta.transform( | ||
functools.partial(set_float_dtype, dtype=target_config.dtype) | ||
), | ||
properties=target_config.to_properties(), | ||
) | ||
|
||
parameters_path = f"{target_model_path_prefix}.irpa" | ||
if not self.caching or not os.path.exists(parameters_path): | ||
target_dataset.save(parameters_path) | ||
|
||
dataset = Dataset.load(parameters_path) | ||
target_config = ClipTextConfig.from_properties(dataset.properties) | ||
input_args = OrderedDict([("input_ids", input_ids)]) | ||
batch_size = input_ids.shape[0] | ||
|
||
mlir_path = f"{target_model_path_prefix}.mlir" | ||
if not self.caching or not os.path.exists(mlir_path): | ||
export_clip_text_model_mlir( | ||
parameters_path, batch_sizes=[batch_size], mlir_output_path=mlir_path | ||
|
||
if ( | ||
not self.caching | ||
or not os.path.exists(mlir_path) | ||
or not os.path.exists(parameters_path) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This conditional feels weird / wrong. Is it checking for defaults? Also it feels particularly weird for a test. Reconsider why this is needed. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I removed the caching. It is useful to speedup iteration, but it does add clutter. |
||
): | ||
export_clip_text_model_iree_test_data( | ||
reference_model=reference_model, | ||
target_dtype=target_dtype, | ||
input_ids=input_ids, | ||
target_mlir_output_path=mlir_path, | ||
target_iree_parameters_output_path=parameters_path, | ||
) | ||
|
||
iree_module_path = f"{target_model_path_prefix}.vmfb" | ||
if not self.caching or not os.path.exists(iree_module_path): | ||
iree.compiler.compile_file( | ||
|
@@ -211,11 +211,11 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens( | |
for i in range(len(expected_outputs)) | ||
] | ||
|
||
actual_last_hidden_states = actual_outputs[0] | ||
expected_last_hidden_states = expected_outputs[0] | ||
actual_last_hidden_state = actual_outputs[0] | ||
expected_last_hidden_state = expected_outputs[0] | ||
|
||
assert_text_encoder_state_close( | ||
actual_last_hidden_states, expected_last_hidden_states, atol | ||
actual_last_hidden_state, expected_last_hidden_state, atol | ||
) | ||
|
||
def runTestCompareRandomModelIreeAgainstTorch( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Leave this as `PathLike` and remove `AnyPath`. Users should just make sure to convert the string before invoking. (Same as above)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed AnyPath.