-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add CLI script exporting CLIP Toy model IREE test data
This is required to have an easy way of exporting test data that will be used in IREE to guard against regressions. E.g. ``` python -m sharktank.models.clip.export_toy_text_model_iree_test_data \ --output-path-prefix=clip_toy_text_model ``` Refactor some of the existing tests to reuse the new export logic.
- Loading branch information
Showing
6 changed files
with
259 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
30 changes: 30 additions & 0 deletions
30
sharktank/sharktank/models/clip/export_toy_text_model_iree_test_data.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from argparse import ArgumentParser | ||
from typing import Optional | ||
|
||
from .testing import export_clip_toy_text_model_default_iree_test_data | ||
|
||
|
||
def main(args: Optional[list[str]] = None): | ||
parser = ArgumentParser( | ||
description=( | ||
"Export test data for toy-sized CLIP text model." | ||
" This program MLIR, parameters sample input and expected output." | ||
" Exports float32 and bfloat16 model variants." | ||
" The expected output is always in float32 precision." | ||
) | ||
) | ||
parser.add_argument( | ||
"--output-path-prefix", type=str, default=f"clip_toy_text_model" | ||
) | ||
args = parser.parse_args(args=args) | ||
export_clip_toy_text_model_default_iree_test_data(args.output_path_prefix) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from typing import Union | ||
from os import PathLike | ||
|
||
AnyPath = Union[str, PathLike] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters