From 6e4516cef9254820639693a92dd2a53b7613ca23 Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Sun, 14 Apr 2024 17:59:14 -0700 Subject: [PATCH] Fix parity checker in LLaMA scripts (#20301) ### Description This PR fixes the parity checker in the LLaMA scripts by adding the following. - Enable buffer sharing manually with `use_buffer_share` instead of `use_gqa` - Get max sequence length from model's config ### Motivation and Context This PR fixes an issue with running the parity checker on other large-language models where `GroupQueryAttention` can be used without buffer sharing enabled. --- .../tools/transformers/models/llama/README.md | 4 +-- .../transformers/models/llama/benchmark.py | 28 +++++++----------- .../models/llama/convert_to_onnx.py | 4 +-- .../transformers/models/llama/llama_inputs.py | 23 +++++++++------ .../transformers/models/llama/llama_parity.py | 29 +++++++++---------- 5 files changed, 42 insertions(+), 46 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index 2e8cd3e1ac7f9..04671e47c033c 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -273,13 +273,13 @@ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.llama_parity \ --cache_dir ./model_cache \ ``` -4. Merged ONNX model, FP16 CUDA with GroupQueryAttention +4. Merged ONNX model, FP16 CUDA with GroupQueryAttention + Buffer Sharing Enabled ``` CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.llama_parity \ --model_name meta-llama/Llama-2-7b-hf \ --onnx_model_path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp32.onnx \ --merged \ - --use_gqa \ + --use_buffer_share \ --execution_provider cuda \ --precision fp16 \ --cache_dir ./model_cache \ diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py index 6184298c471ac..b9d6b30baae8b 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py @@ -54,15 +54,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): init_inputs, iter_inputs = None, None # For past_present_share_buffer: - # Set max_seq_len to 16384 for CodeLLaMA (finetuned variant of LLaMA-2) - # Set max_seq_len to 4096 for Hugging Face LLaMA-2 model since that is the default value # Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported - temp_name = args.model_name.lower().replace("-", "").replace("_", "") - max_seq_len = ( - 2048 - if args.benchmark_type == "ort-msft" - else 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048 - ) + # Set max_seq_len to config value for other models + max_seq_len = 2048 if args.benchmark_type == "ort-msft" else args.config.max_position_embeddings if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: init_inputs = get_sample_inputs( @@ -109,7 +103,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): past_seq_len=0, max_seq_len=max_seq_len, use_fp16=args.use_fp16, - use_gqa=args.use_gqa, + use_buffer_share=args.use_buffer_share, engine="pt", return_dict=True, ) @@ -121,7 +115,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): past_seq_len=args.sequence_length, max_seq_len=max_seq_len, use_fp16=args.use_fp16, - use_gqa=args.use_gqa, + use_buffer_share=args.use_buffer_share, engine="pt", return_dict=True, ) @@ -136,7 +130,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): past_seq_len=0, max_seq_len=max_seq_len, use_fp16=args.use_fp16, - use_gqa=args.use_gqa, + use_buffer_share=args.use_buffer_share, engine="ort", return_dict=True, world_size=args.world_size, @@ -149,7 +143,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): past_seq_len=args.sequence_length, max_seq_len=max_seq_len, use_fp16=args.use_fp16, - use_gqa=args.use_gqa, + use_buffer_share=args.use_buffer_share, engine="ort", return_dict=True, world_size=args.world_size, @@ -166,7 +160,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): seq_len=args.sequence_length, max_seq_len=max_seq_len, use_fp16=args.use_fp16, - use_gqa=args.use_gqa, + use_buffer_share=args.use_buffer_share, split_kv=split_kv, ) iter_inputs = get_msft_sample_inputs( @@ -176,7 +170,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): seq_len=1, max_seq_len=max_seq_len, use_fp16=args.use_fp16, - use_gqa=args.use_gqa, + use_buffer_share=args.use_buffer_share, split_kv=split_kv, ) @@ -457,7 +451,7 @@ def prepare_ort_inputs(inputs, kv_cache_ortvalues): # Add IO bindings for non-CPU execution providers if args.device != "cpu": io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues( - model, inputs, args.device, int(args.rank), args.use_gqa, kv_cache_ortvalues + model, inputs, args.device, int(args.rank), args.use_buffer_share, kv_cache_ortvalues ) setattr(args, "io_binding", io_binding) # noqa: B010 return io_binding, kv_cache_ortvalues @@ -684,9 +678,9 @@ def main(): gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node)) use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu" - setattr(args, "use_gqa", use_buffer_share) # noqa: B010 + setattr(args, "use_buffer_share", use_buffer_share) # noqa: B010 else: - setattr(args, "use_gqa", False) # noqa: B010 + setattr(args, "use_buffer_share", False) # noqa: B010 # Measure prompt cost (init_inputs) and generated token cost (iter_inputs) for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths): diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index b649f7ab65049..9990c1d006c1c 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -1029,7 +1029,7 @@ def main(): os.path.join(args.output, filename), "-ep", args.execution_provider, - "-fp", + "--precision", args.precision, "--cache_dir", args.cache_dir, @@ -1042,8 +1042,6 @@ def main(): parity_cmd.append("--use_past_kv") if "merged" in filename: parity_cmd.append("--merged") - if args.use_gqa: - parity_cmd.append("--use_gqa") try: logger.info(f"check parity with cmd: {parity_cmd}") diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py index 5aed55c12f38f..7b3caf0b7017e 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py @@ -127,7 +127,7 @@ def get_merged_sample_with_past_kv_inputs( past_seq_len: int, max_seq_len: int, use_fp16: bool = False, - use_gqa: bool = False, + use_buffer_share: bool = False, engine: str = "pt", return_dict: bool = False, world_size: int = 1, @@ -162,7 +162,7 @@ def get_merged_sample_with_past_kv_inputs( assert isinstance(past_kv, dict) inputs.update(past_kv) - if use_gqa: + if use_buffer_share: inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len) else: @@ -180,7 +180,7 @@ def get_msft_sample_inputs( seq_len: int, max_seq_len: int, use_fp16: bool, - use_gqa: bool, + use_buffer_share: bool, split_kv: bool, ): np_dtype = np.float16 if use_fp16 else np.float32 @@ -218,7 +218,7 @@ def get_msft_sample_inputs( } ) - if use_gqa: + if use_buffer_share: ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len) return ort_inputs @@ -252,7 +252,7 @@ def flatten_past_kv_inputs(past_key_values: list[tuple[torch.Tensor, torch.Tenso # Format PyTorch inputs to ONNX Runtime inputs def convert_inputs_for_ort( pt_inputs: dict, - use_gqa: bool = False, + use_buffer_share: bool = False, past_seq_len: int = 0, max_seq_len: int = 2048, device: str = "", @@ -268,7 +268,7 @@ def convert_inputs_for_ort( ort_inputs[k] = v.detach().cpu().numpy() # Reshape KV caches if using past-present-share-buffer - if use_gqa and device != "" and device != "cpu" and device_id > -1: + if use_buffer_share and device != "" and device != "cpu" and device_id > -1: ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len) return ort_inputs @@ -311,7 +311,12 @@ def verify_ort_inputs(model: InferenceSession, ort_inputs: dict): # Add IO bindings for execution providers using OrtValue # Use when you need to run inference once or twice to save memory def add_io_bindings_as_ortvalues( - model: InferenceSession, ort_inputs: dict, device: str, device_id: int, use_gqa: bool, kv_cache_ortvalues: dict + model: InferenceSession, + ort_inputs: dict, + device: str, + device_id: int, + use_buffer_share: bool, + kv_cache_ortvalues: dict, ): io_binding = model.io_binding() @@ -324,7 +329,7 @@ def add_io_bindings_as_ortvalues( continue # Bind OrtValue inputs to device - if use_gqa and ("cache" in k or "past_key_values" in k): + if use_buffer_share and ("cache" in k or "past_key_values" in k): if k not in kv_cache_ortvalues: v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id) io_binding.bind_ortvalue_input(k, v_device) @@ -338,7 +343,7 @@ def add_io_bindings_as_ortvalues( for output in model.get_outputs(): name = output.name - if use_gqa and ("out" in name or "present" in name): + if use_buffer_share and ("out" in name or "present" in name): # Bind present KV cache outputs to past KV cache inputs in order to buffer share input_name = name.replace("out", "cache").replace("present", "past_key_values") io_binding.bind_ortvalue_output(name, kv_cache_ortvalues[input_name]) diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index 7b186eec2f5a9..c3e754e3df5a5 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -29,10 +29,9 @@ logger = logging.getLogger("") -def get_sequence_lengths(args: argparse.Namespace): +def get_sequence_lengths(args: argparse.Namespace, config: AutoConfig): past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8) - temp_name = args.model_name.lower().replace("-", "").replace("_", "") - max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048 + max_sequence_length = config.max_position_embeddings return past_sequence_length, curr_sequence_length, max_sequence_length @@ -40,7 +39,7 @@ def get_inputs(args: argparse.Namespace, config: AutoConfig): # Dummy values for parity world_size = get_size() batch_size = 2 - past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args) + past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args, config) if args.merged: inputs = get_merged_sample_with_past_kv_inputs( @@ -51,7 +50,7 @@ def get_inputs(args: argparse.Namespace, config: AutoConfig): past_seq_len=past_sequence_length, max_seq_len=max_sequence_length, use_fp16=args.use_fp16, - use_gqa=args.use_gqa, + use_buffer_share=args.use_buffer_share, return_dict=True, world_size=world_size, ) @@ -107,10 +106,10 @@ def verify_parity( torch.cuda.empty_cache() # Run inference with ORT - past_sequence_length, _, max_sequence_length = get_sequence_lengths(args) + past_sequence_length, _, max_sequence_length = get_sequence_lengths(args, config) inputs = convert_inputs_for_ort( inputs, - use_gqa=args.use_gqa, + use_buffer_share=args.use_buffer_share, past_seq_len=past_sequence_length, max_seq_len=max_sequence_length, device=args.execution_provider, @@ -130,11 +129,11 @@ def verify_parity( if args.execution_provider != "cpu": io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues( ort_model, - inputs, - args.execution_provider, - int(args.rank), - args.use_gqa, - kv_cache_ortvalues, + ort_inputs=inputs, + device=args.execution_provider, + device_id=int(args.rank), + use_buffer_share=args.use_buffer_share, + kv_cache_ortvalues=kv_cache_ortvalues, ) io_binding.synchronize_inputs() @@ -217,11 +216,11 @@ def get_args(argv: list[str]): parser.add_argument( "-g", - "--use_gqa", + "--use_buffer_share", action="store_true", - help="Use if model has GroupQueryAttention", + help="Use if model has GroupQueryAttention and you want to enable past-present buffer sharing", ) - parser.set_defaults(use_gqa=False) + parser.set_defaults(use_buffer_share=False) parser.add_argument( "--merged",