Fix parity checker in LLaMA scripts (microsoft#20301)
### Description
This PR fixes the parity checker in the LLaMA scripts by making the following changes (a usage sketch follows the list):
- Enable buffer sharing explicitly with `use_buffer_share` instead of inferring it from `use_gqa`
- Get the max sequence length from the model's config instead of hard-coding it per model
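
For illustration, a minimal sketch of what the two changes look like for a caller of the updated helpers; the dummy inputs, import path, and model name are illustrative, and `convert_inputs_for_ort` is the helper modified in `llama_inputs.py` below:

```python
import torch
from transformers import AutoConfig

# Helper updated in this PR; the import path assumes the script is run from
# onnxruntime/python/tools/transformers/models/llama.
from llama_inputs import convert_inputs_for_ort

# Max sequence length now comes from the model's config instead of a
# hard-coded per-model value (e.g. 4096 for Llama-2, 16384 for CodeLlama).
config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf", cache_dir="./model_cache")
max_sequence_length = config.max_position_embeddings

# Illustrative inputs; the real ones come from get_merged_sample_with_past_kv_inputs.
pt_inputs = {
    "input_ids": torch.randint(0, config.vocab_size, (2, 8), dtype=torch.int64),
    "attention_mask": torch.ones(2, 8, dtype=torch.int64),
}

# Buffer sharing is opted into explicitly with use_buffer_share instead of
# being implied by use_gqa, since GroupQueryAttention can run without it.
ort_inputs = convert_inputs_for_ort(
    pt_inputs,
    use_buffer_share=True,  # pass False for GQA models that do not share buffers
    past_seq_len=0,
    max_seq_len=max_sequence_length,
    device="cuda",
    device_id=0,
)
```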


### Motivation and Context
This PR fixes an issue with running the parity checker on other large language models, where `GroupQueryAttention` can be used without buffer sharing enabled.
kunal-vaishnavi authored Apr 15, 2024
1 parent bf72f99 commit 6e4516c
Showing 5 changed files with 42 additions and 46 deletions.
4 changes: 2 additions & 2 deletions onnxruntime/python/tools/transformers/models/llama/README.md
@@ -273,13 +273,13 @@ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.llama_parity \
--cache_dir ./model_cache \
```

4. Merged ONNX model, FP16 CUDA with GroupQueryAttention
4. Merged ONNX model, FP16 CUDA with GroupQueryAttention + Buffer Sharing Enabled
```
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.llama_parity \
--model_name meta-llama/Llama-2-7b-hf \
--onnx_model_path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp32.onnx \
--merged \
--use_gqa \
--use_buffer_share \
--execution_provider cuda \
--precision fp16 \
--cache_dir ./model_cache \
28 changes: 11 additions & 17 deletions onnxruntime/python/tools/transformers/models/llama/benchmark.py
@@ -54,15 +54,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
init_inputs, iter_inputs = None, None

# For past_present_share_buffer:
# Set max_seq_len to 16384 for CodeLLaMA (finetuned variant of LLaMA-2)
# Set max_seq_len to 4096 for Hugging Face LLaMA-2 model since that is the default value
# Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported
temp_name = args.model_name.lower().replace("-", "").replace("_", "")
max_seq_len = (
2048
if args.benchmark_type == "ort-msft"
else 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048
)
# Set max_seq_len to config value for other models
max_seq_len = 2048 if args.benchmark_type == "ort-msft" else args.config.max_position_embeddings

if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
init_inputs = get_sample_inputs(
@@ -109,7 +103,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
past_seq_len=0,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_gqa=args.use_gqa,
use_buffer_share=args.use_buffer_share,
engine="pt",
return_dict=True,
)
@@ -121,7 +115,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
past_seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_gqa=args.use_gqa,
use_buffer_share=args.use_buffer_share,
engine="pt",
return_dict=True,
)
@@ -136,7 +130,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
past_seq_len=0,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_gqa=args.use_gqa,
use_buffer_share=args.use_buffer_share,
engine="ort",
return_dict=True,
world_size=args.world_size,
@@ -149,7 +143,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
past_seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_gqa=args.use_gqa,
use_buffer_share=args.use_buffer_share,
engine="ort",
return_dict=True,
world_size=args.world_size,
@@ -166,7 +160,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_gqa=args.use_gqa,
use_buffer_share=args.use_buffer_share,
split_kv=split_kv,
)
iter_inputs = get_msft_sample_inputs(
@@ -176,7 +170,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
seq_len=1,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_gqa=args.use_gqa,
use_buffer_share=args.use_buffer_share,
split_kv=split_kv,
)

@@ -457,7 +451,7 @@ def prepare_ort_inputs(inputs, kv_cache_ortvalues):
# Add IO bindings for non-CPU execution providers
if args.device != "cpu":
io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
model, inputs, args.device, int(args.rank), args.use_gqa, kv_cache_ortvalues
model, inputs, args.device, int(args.rank), args.use_buffer_share, kv_cache_ortvalues
)
setattr(args, "io_binding", io_binding) # noqa: B010
return io_binding, kv_cache_ortvalues
@@ -684,9 +678,9 @@ def main():
gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node))

use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu"
setattr(args, "use_gqa", use_buffer_share) # noqa: B010
setattr(args, "use_buffer_share", use_buffer_share) # noqa: B010
else:
setattr(args, "use_gqa", False) # noqa: B010
setattr(args, "use_buffer_share", False) # noqa: B010

# Measure prompt cost (init_inputs) and generated token cost (iter_inputs)
for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths):
4 changes: 1 addition & 3 deletions onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -1029,7 +1029,7 @@ def main():
os.path.join(args.output, filename),
"-ep",
args.execution_provider,
"-fp",
"--precision",
args.precision,
"--cache_dir",
args.cache_dir,
@@ -1042,8 +1042,6 @@ def main():
parity_cmd.append("--use_past_kv")
if "merged" in filename:
parity_cmd.append("--merged")
if args.use_gqa:
parity_cmd.append("--use_gqa")

try:
logger.info(f"check parity with cmd: {parity_cmd}")
23 changes: 14 additions & 9 deletions onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
@@ -127,7 +127,7 @@ def get_merged_sample_with_past_kv_inputs(
past_seq_len: int,
max_seq_len: int,
use_fp16: bool = False,
use_gqa: bool = False,
use_buffer_share: bool = False,
engine: str = "pt",
return_dict: bool = False,
world_size: int = 1,
@@ -162,7 +162,7 @@ def get_merged_sample_with_past_kv_inputs(
assert isinstance(past_kv, dict)
inputs.update(past_kv)

if use_gqa:
if use_buffer_share:
inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len)

else:
@@ -180,7 +180,7 @@ def get_msft_sample_inputs(
seq_len: int,
max_seq_len: int,
use_fp16: bool,
use_gqa: bool,
use_buffer_share: bool,
split_kv: bool,
):
np_dtype = np.float16 if use_fp16 else np.float32
@@ -218,7 +218,7 @@ def get_msft_sample_inputs(
}
)

if use_gqa:
if use_buffer_share:
ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)

return ort_inputs
@@ -252,7 +252,7 @@ def flatten_past_kv_inputs(past_key_values: list[tuple[torch.Tensor, torch.Tenso
# Format PyTorch inputs to ONNX Runtime inputs
def convert_inputs_for_ort(
pt_inputs: dict,
use_gqa: bool = False,
use_buffer_share: bool = False,
past_seq_len: int = 0,
max_seq_len: int = 2048,
device: str = "",
@@ -268,7 +268,7 @@ def convert_inputs_for_ort(
ort_inputs[k] = v.detach().cpu().numpy()

# Reshape KV caches if using past-present-share-buffer
if use_gqa and device != "" and device != "cpu" and device_id > -1:
if use_buffer_share and device != "" and device != "cpu" and device_id > -1:
ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)

return ort_inputs
@@ -311,7 +311,12 @@ def verify_ort_inputs(model: InferenceSession, ort_inputs: dict):
# Add IO bindings for execution providers using OrtValue
# Use when you need to run inference once or twice to save memory
def add_io_bindings_as_ortvalues(
model: InferenceSession, ort_inputs: dict, device: str, device_id: int, use_gqa: bool, kv_cache_ortvalues: dict
model: InferenceSession,
ort_inputs: dict,
device: str,
device_id: int,
use_buffer_share: bool,
kv_cache_ortvalues: dict,
):
io_binding = model.io_binding()

@@ -324,7 +329,7 @@ def add_io_bindings_as_ortvalues(
continue

# Bind OrtValue inputs to device
if use_gqa and ("cache" in k or "past_key_values" in k):
if use_buffer_share and ("cache" in k or "past_key_values" in k):
if k not in kv_cache_ortvalues:
v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id)
io_binding.bind_ortvalue_input(k, v_device)
@@ -338,7 +343,7 @@ def add_io_bindings_as_ortvalues(

for output in model.get_outputs():
name = output.name
if use_gqa and ("out" in name or "present" in name):
if use_buffer_share and ("out" in name or "present" in name):
# Bind present KV cache outputs to past KV cache inputs in order to buffer share
input_name = name.replace("out", "cache").replace("present", "past_key_values")
io_binding.bind_ortvalue_output(name, kv_cache_ortvalues[input_name])
29 changes: 14 additions & 15 deletions onnxruntime/python/tools/transformers/models/llama/llama_parity.py
@@ -29,18 +29,17 @@
logger = logging.getLogger("")


def get_sequence_lengths(args: argparse.Namespace):
def get_sequence_lengths(args: argparse.Namespace, config: AutoConfig):
past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8)
temp_name = args.model_name.lower().replace("-", "").replace("_", "")
max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048
max_sequence_length = config.max_position_embeddings
return past_sequence_length, curr_sequence_length, max_sequence_length


def get_inputs(args: argparse.Namespace, config: AutoConfig):
# Dummy values for parity
world_size = get_size()
batch_size = 2
past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args)
past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args, config)

if args.merged:
inputs = get_merged_sample_with_past_kv_inputs(
@@ -51,7 +50,7 @@ def get_inputs(args: argparse.Namespace, config: AutoConfig):
past_seq_len=past_sequence_length,
max_seq_len=max_sequence_length,
use_fp16=args.use_fp16,
use_gqa=args.use_gqa,
use_buffer_share=args.use_buffer_share,
return_dict=True,
world_size=world_size,
)
@@ -107,10 +106,10 @@ def verify_parity(
torch.cuda.empty_cache()

# Run inference with ORT
past_sequence_length, _, max_sequence_length = get_sequence_lengths(args)
past_sequence_length, _, max_sequence_length = get_sequence_lengths(args, config)
inputs = convert_inputs_for_ort(
inputs,
use_gqa=args.use_gqa,
use_buffer_share=args.use_buffer_share,
past_seq_len=past_sequence_length,
max_seq_len=max_sequence_length,
device=args.execution_provider,
@@ -130,11 +129,11 @@ def verify_parity(
if args.execution_provider != "cpu":
io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
ort_model,
inputs,
args.execution_provider,
int(args.rank),
args.use_gqa,
kv_cache_ortvalues,
ort_inputs=inputs,
device=args.execution_provider,
device_id=int(args.rank),
use_buffer_share=args.use_buffer_share,
kv_cache_ortvalues=kv_cache_ortvalues,
)

io_binding.synchronize_inputs()
@@ -217,11 +216,11 @@ def get_args(argv: list[str]):

parser.add_argument(
"-g",
"--use_gqa",
"--use_buffer_share",
action="store_true",
help="Use if model has GroupQueryAttention",
help="Use if model has GroupQueryAttention and you want to enable past-present buffer sharing",
)
parser.set_defaults(use_gqa=False)
parser.set_defaults(use_buffer_share=False)

parser.add_argument(
"--merged",
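For readers unfamiliar with past/present buffer sharing, a rough sketch of what the flag changes at the IO-binding level, following the same pattern as `add_io_bindings_as_ortvalues` above (the model path, tensor names, and shapes are illustrative):

```python
import numpy as np
from onnxruntime import InferenceSession, OrtValue

# Illustrative merged LLaMA decoder model exported with GroupQueryAttention.
session = InferenceSession(
    "rank_0_Llama-2-7b-hf_decoder_merged_model_fp16.onnx",
    providers=["CUDAExecutionProvider"],
)
io_binding = session.io_binding()

# With buffer sharing, each past KV cache is allocated once on the device at
# the full max sequence length (shape shown is illustrative for Llama-2-7B)...
past_key = OrtValue.ortvalue_from_numpy(
    np.zeros((1, 32, 4096, 128), dtype=np.float16), device_type="cuda", device_id=0
)
io_binding.bind_ortvalue_input("past_key_values.0.key", past_key)

# ...and the matching present output is bound to the same OrtValue, so the
# kernel updates the KV cache in place instead of allocating a new present
# tensor every step. Without --use_buffer_share, past inputs stay at the true
# past length and present outputs are bound as regular, separate tensors.
io_binding.bind_ortvalue_output("present.0.key", past_key)
```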
