From 67be84b571e1a912d7d391f3c51dfd08cf40206c Mon Sep 17 00:00:00 2001
From: Guang Yang
Date: Fri, 13 Sep 2024 17:18:24 -0700
Subject: [PATCH] =?UTF-8?q?Script=20to=20export=20=F0=9F=A4=97=20models=20?=
 =?UTF-8?q?(#4723)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
bypass-github-export-checks

[Done] ~~Require PR [Make StaticCache configurable at model construct time](https://github.com/huggingface/transformers/pull/32830) in order to export, lower and run the 🤗 model OOTB.~~

[Done] ~~Require https://github.com/huggingface/transformers/pull/33303 or https://github.com/huggingface/transformers/pull/33287 to be merged into 🤗 `transformers` to resolve the export issue introduced by https://github.com/huggingface/transformers/pull/32543~~

-----------

Now we can use the integration point from 🤗 `transformers` to lower compatible models to ExecuTorch OOTB.
- This PR creates a simple export script with an XNNPACK recipe.
- This PR also creates a secret `EXECUTORCH_HF_TOKEN` to allow downloading checkpoints in the CI.
- This PR connects the 🤗 "Export to ExecuTorch" e2e workflow to ExecuTorch CI.

### Instructions to run the demo:

1. Run export_hf_model.py to lower gemma-2b to ExecuTorch:
```
python -m extension.export_util.export_hf_model -hfm "google/gemma-2b" # The model is exported with static dims and a static KV cache
```
2. Run tokenizer.py to generate the tokenizer binary for the ExecuTorch runtime:
```
python -m extension.llm.tokenizer.tokenizer -t /tokenizer.model -o tokenizer.bin
```
3. Build the llama runner by following [step 4](https://github.com/pytorch/executorch/tree/main/examples/models/llama2#step-4-run-on-your-computer-to-validate) of the llama2 example guide.
4. Run the lowered model:
```
cmake-out/examples/models/llama2/llama_main --model_path=gemma.pte --tokenizer_path=tokenizer.bin --prompt="My name is"
```

OOTB output and perf:
```
I 00:00:00.003110 executorch:cpuinfo_utils.cpp:62] Reading file /sys/devices/soc0/image_version
I 00:00:00.003360 executorch:cpuinfo_utils.cpp:78] Failed to open midr file /sys/devices/soc0/image_version
I 00:00:00.003380 executorch:cpuinfo_utils.cpp:158] Number of efficient cores 4
I 00:00:00.003384 executorch:main.cpp:65] Resetting threadpool with num threads = 6
I 00:00:00.014716 executorch:runner.cpp:51] Creating LLaMa runner: model_path=gemma.pte, tokenizer_path=tokenizer_gemma.bin
I 00:00:03.065359 executorch:runner.cpp:66] Reading metadata from model
I 00:00:03.065391 executorch:metadata_util.h:43] get_n_bos: 1
I 00:00:03.065396 executorch:metadata_util.h:43] get_n_eos: 1
I 00:00:03.065399 executorch:metadata_util.h:43] get_max_seq_len: 123
I 00:00:03.065402 executorch:metadata_util.h:43] use_kv_cache: 1
I 00:00:03.065404 executorch:metadata_util.h:41] The model does not contain use_sdpa_with_kv_cache method, using default value 0
I 00:00:03.065405 executorch:metadata_util.h:43] use_sdpa_with_kv_cache: 0
I 00:00:03.065407 executorch:metadata_util.h:41] The model does not contain append_eos_to_prompt method, using default value 0
I 00:00:03.065409 executorch:metadata_util.h:43] append_eos_to_prompt: 0
I 00:00:03.065411 executorch:metadata_util.h:41] The model does not contain enable_dynamic_shape method, using default value 0
I 00:00:03.065412 executorch:metadata_util.h:43] enable_dynamic_shape: 0
I 00:00:03.130388 executorch:metadata_util.h:43] get_vocab_size: 256000
I 00:00:03.130405 executorch:metadata_util.h:43] get_bos_id: 2
I 00:00:03.130408 executorch:metadata_util.h:43] get_eos_id: 1
My name is Melle.
I am a 20 year old girl from Belgium. I am living in the southern part of Belgium. I am 165 cm tall and I weigh 45kg. I like to play sports like swimming, running and playing tennis. I am very interested in music and I like to listen to classical music. I like to sing and I can play the piano. I would like to go to the USA because I like to travel a lot. I am looking for a boy from the USA who is between 18 and 25 years old. I
PyTorchObserver {"prompt_tokens":4,"generated_tokens":118,"model_load_start_ms":1723685715497,"model_load_end_ms":1723685718612,"inference_start_ms":1723685718612,"inference_end_ms":1723685732965,"prompt_eval_end_ms":1723685719087,"first_token_ms":1723685719087,"aggregate_sampling_time_ms":182,"SCALING_FACTOR_UNITS_PER_SECOND":1000}
I 00:00:17.482472 executorch:stats.h:70] Prompt Tokens: 4    Generated Tokens: 118
I 00:00:17.482475 executorch:stats.h:76] Model Load Time: 3.115000 (seconds)
I 00:00:17.482481 executorch:stats.h:86] Total inference time: 14.353000 (seconds) Rate: 8.221278 (tokens/second)
I 00:00:17.482483 executorch:stats.h:94] Prompt evaluation: 0.475000 (seconds) Rate: 8.421053 (tokens/second)
I 00:00:17.482485 executorch:stats.h:105] Generated 118 tokens: 13.878000 (seconds) Rate: 8.502666 (tokens/second)
I 00:00:17.482486 executorch:stats.h:113] Time to first generated token: 0.475000 (seconds)
I 00:00:17.482488 executorch:stats.h:120] Sampling time over 122 tokens: 0.182000 (seconds)
```

Pull Request resolved: https://github.com/pytorch/executorch/pull/4723

Reviewed By: huydhn, kirklandsign

Differential Revision: D62543933

Pulled By: guangy10

fbshipit-source-id: 00401a39ba03d7383e4b284d25c8fc62a6695b34
---
 .github/workflows/trunk.yml              |  90 +++++++++++++++++++
 extension/export_util/export_hf_model.py | 110 +++++++++++++++++++++++
 2 files changed, 200 insertions(+)
 create mode 100644 extension/export_util/export_hf_model.py

diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml
index c1a0d175d0..9d50420e9f 100644
--- a/.github/workflows/trunk.yml
+++ b/.github/workflows/trunk.yml
@@ -351,3 +351,93 @@ jobs:
         PYTHON_EXECUTABLE=python ${CONDA_RUN} bash .ci/scripts/test_model.sh "${MODEL_NAME}" "${BUILD_TOOL}" "${BACKEND}"
         echo "::endgroup::"
       done
+
+  test-huggingface-transformers:
+    name: test-huggingface-transformers
+    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    secrets: inherit
+    strategy:
+      matrix:
+        hf_model_repo: [google/gemma-2b]
+      fail-fast: false
+    with:
+      secrets-env: EXECUTORCH_HF_TOKEN
+      runner: linux.12xlarge
+      docker-image: executorch-ubuntu-22.04-clang12
+      submodules: 'true'
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      timeout: 90
+      script: |
+        echo "::group::Set up ExecuTorch"
+        # The generic Linux job chooses to use base env, not the one setup by the image
+        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
+        conda activate "${CONDA_ENV}"
+        PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh cmake
+
+        echo "Installing libexecutorch.a, libextension_module.so, libportable_ops_lib.a"
+        rm -rf cmake-out
+        cmake \
+            -DCMAKE_INSTALL_PREFIX=cmake-out \
+            -DCMAKE_BUILD_TYPE=Release \
+            -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
+            -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
+            -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
+            -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
+            -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
+            -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
+            -DEXECUTORCH_BUILD_XNNPACK=ON \
+            -DPYTHON_EXECUTABLE=python \
+            -Bcmake-out .
+        cmake --build cmake-out -j9 --target install --config Release
+
+        echo "Build llama runner"
+        dir="examples/models/llama2"
+        cmake \
+            -DCMAKE_INSTALL_PREFIX=cmake-out \
+            -DCMAKE_BUILD_TYPE=Release \
+            -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
+            -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
+            -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
+            -DEXECUTORCH_BUILD_XNNPACK=ON \
+            -DPYTHON_EXECUTABLE=python \
+            -Bcmake-out/${dir} \
+            ${dir}
+        cmake --build cmake-out/${dir} -j9 --config Release
+        echo "::endgroup::"
+
+        echo "::group::Set up HuggingFace Dependencies"
+        pip install -U "huggingface_hub[cli]"
+        huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN
+        pip install accelerate sentencepiece
+        # TODO(guangyang): Switch to use released transformers library after all required patches are included
+        pip install "git+https://github.com/huggingface/transformers.git@6cc4dfe3f1e8d421c6d6351388e06e9b123cbfe1"
+        pip list
+        echo "::endgroup::"
+
+        echo "::group::Export to ExecuTorch"
+        TOKENIZER_FILE=tokenizer.model
+        TOKENIZER_BIN_FILE=tokenizer.bin
+        ET_MODEL_NAME=et_model
+        # Fetch the file using a Python one-liner
+        DOWNLOADED_TOKENIZER_FILE_PATH=$(python -c "
+        from huggingface_hub import hf_hub_download
+        # Download the file from the Hugging Face Hub
+        downloaded_path = hf_hub_download(
+            repo_id='${{ matrix.hf_model_repo }}',
+            filename='${TOKENIZER_FILE}'
+        )
+        print(downloaded_path)
+        ")
+        if [ -f "$DOWNLOADED_TOKENIZER_FILE_PATH" ]; then
+            echo "${TOKENIZER_FILE} downloaded successfully at: $DOWNLOADED_TOKENIZER_FILE_PATH"
+            python -m extension.llm.tokenizer.tokenizer -t $DOWNLOADED_TOKENIZER_FILE_PATH -o ./${TOKENIZER_BIN_FILE}
+            ls ./tokenizer.bin
+        else
+            echo "Failed to download ${TOKENIZER_FILE} from ${{ matrix.hf_model_repo }}."
+            exit 1
+        fi
+
+        python -m extension.export_util.export_hf_model -hfm=${{ matrix.hf_model_repo }} -o ${ET_MODEL_NAME}
+
+        cmake-out/examples/models/llama2/llama_main --model_path=${ET_MODEL_NAME}.pte --tokenizer_path=${TOKENIZER_BIN_FILE} --prompt="My name is"
+        echo "::endgroup::"
diff --git a/extension/export_util/export_hf_model.py b/extension/export_util/export_hf_model.py
new file mode 100644
index 0000000000..12ed202988
--- /dev/null
+++ b/extension/export_util/export_hf_model.py
@@ -0,0 +1,110 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
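+#
+# Export a Hugging Face causal LM (e.g. google/gemma-2b) to an ExecuTorch
+# .pte program with a static KV cache, delegated to XNNPACK. Example:
+#   python -m extension.export_util.export_hf_model -hfm "google/gemma-2b"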
+
+import argparse
+import os
+
+import torch
+import torch.export._trace
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
+from torch.nn.attention import SDPBackend
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation.configuration_utils import GenerationConfig
+from transformers.integrations.executorch import convert_and_export_with_cache
+from transformers.modeling_utils import PreTrainedModel
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-hfm",
+        "--hf_model_repo",
+        required=True,
+        default=None,
+        help="a valid huggingface model repo name",
+    )
+    parser.add_argument(
+        "-o",
+        "--output_name",
+        required=False,
+        default=None,
+        help="output name of the exported model",
+    )
+
+    args = parser.parse_args()
+
+    # Configs to HF model
+    device = "cpu"
+    dtype = torch.float32
+    batch_size = 1
+    max_length = 123
+    cache_implementation = "static"
+    attn_implementation = "sdpa"
+
+    # Load and configure a HF model
+    model = AutoModelForCausalLM.from_pretrained(
+        args.hf_model_repo,
+        attn_implementation=attn_implementation,
+        device_map=device,
+        torch_dtype=dtype,
+        generation_config=GenerationConfig(
+            use_cache=True,
+            cache_implementation=cache_implementation,
+            max_length=max_length,
+            cache_config={
+                "batch_size": batch_size,
+                "max_cache_len": max_length,
+            },
+        ),
+    )
+    print(f"{model.config}")
+    print(f"{model.generation_config}")
+
+    tokenizer = AutoTokenizer.from_pretrained(args.hf_model_repo)
+    input_ids = tokenizer([""], return_tensors="pt").to(device)["input_ids"]
+    cache_position = torch.tensor([0], dtype=torch.long)
+
+    def _get_constant_methods(model: PreTrainedModel):
+        return {
+            "get_dtype": 5 if model.config.torch_dtype == torch.float16 else 6,
+            "get_bos_id": model.config.bos_token_id,
+            "get_eos_id": model.config.eos_token_id,
+            "get_head_dim": model.config.hidden_size / model.config.num_attention_heads,
+            "get_max_batch_size": model.generation_config.cache_config.batch_size,
+            "get_max_seq_len": model.generation_config.cache_config.max_cache_len,
+            "get_n_bos": 1,
+            "get_n_eos": 1,
+            "get_n_kv_heads": model.config.num_key_value_heads,
+            "get_n_layers": model.config.num_hidden_layers,
+            "get_vocab_size": model.config.vocab_size,
+            "use_kv_cache": model.generation_config.use_cache,
+        }
+
+    with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+
+        exported_prog = convert_and_export_with_cache(model, input_ids, cache_position)
+        prog = (
+            to_edge(
+                exported_prog,
+                compile_config=EdgeCompileConfig(
+                    _check_ir_validity=False,
+                    _skip_dim_order=True,
+                ),
+                constant_methods=_get_constant_methods(model),
+            )
+            .to_backend(XnnpackPartitioner())
+            .to_executorch(ExecutorchBackendConfig(extract_delegate_segments=True))
+        )
+        out_name = args.output_name if args.output_name else model.config.model_type
+        filename = os.path.join("./", f"{out_name}.pte")
+        with open(filename, "wb") as f:
+            prog.write_to_file(f)
+        print(f"Saved exported program to {filename}")
+
+
+if __name__ == "__main__":
+    main()
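
For a quick smoke test of the exported `.pte` without building the C++ runner, the program can also be loaded from Python. The sketch below is illustrative only and is not part of this PR: it assumes the ExecuTorch Python bindings are installed with the XNNPACK backend enabled, and the exact import path and `forward()` signature of the pybindings may differ between ExecuTorch releases.

```
# Minimal sketch (not part of this PR): run one decode step on the exported program.
# Assumes ExecuTorch pybindings with XNNPACK support are installed; the import path
# and forward() signature shown here are assumptions and may differ across releases.
import torch

from executorch.extension.pybindings.portable_lib import _load_for_executorch
from transformers import AutoTokenizer

module = _load_for_executorch("gemma.pte")  # file produced by export_hf_model.py
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")

# The exported graph takes (input_ids, cache_position), mirroring the inputs
# passed to convert_and_export_with_cache() in export_hf_model.py.
input_ids = tokenizer("My name is", return_tensors="pt")["input_ids"][:, :1]
cache_position = torch.tensor([0], dtype=torch.long)

logits = module.forward((input_ids, cache_position))[0]
print(logits.shape)  # expected: [1, 1, vocab_size] for a single decode step
```

If this single-step call produces sensible logits, the same `.pte` and `tokenizer.bin` can then be fed to `llama_main` as in the demo instructions above.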