Introduce CausalLMModel interface and add IREE numerics test for Llama 3.1 8B FP16 TP8 #375
@@ -13,16 +13,20 @@

from sharktank.layers import *
from sharktank.types import *
from sharktank.utils.math import ceildiv

# TODO: Should be using a base class with the protocol supported.
from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
from ..models.llama.sharding import shard_theta
from ..models.mixtral.mixtral import *
from ..models.grok.grok import *
from .. import ops


def main():
def dtype_from_str(s: str) -> torch.dtype:
    parts = s.split(".", maxsplit=1)
    assert len(parts) == 2 and parts[0] == "torch"
    return getattr(torch, parts[1])


def main(raw_args: list[str] | None = None):
    from ..utils import cli

    parser = cli.create_parser()
@@ -43,6 +47,12 @@ def main():
        type=lambda arg: [int(bs) for bs in arg.split(",")],
        default="4",
    )
    parser.add_argument(
        "--block-seq-stride",
        help="Block sequence stride for a paged KV cache.",
        type=int,
        default=LlamaModelConfig.default_block_seq_stride,
    )
    parser.add_argument(
        "--verbose",
        help="Include verbose logging",
@@ -59,8 +69,18 @@ def main():
        default="decomposed",
        choices=["decomposed", "torch"],
    )
    parser.add_argument(
        "--attention-dtype",
Review comment: Can you provide a justification for adding these in the first place? They should be inferred from the data types of the functions.
Reply: There is one test that uses fp32.
        type=str,
        default=str(LlamaModelConfig.default_attention_dtype),
    )
    parser.add_argument(
        "--activation-dtype",
        type=str,
        default=str(LlamaModelConfig.default_activation_dtype),
    )
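Regarding the review thread above: a minimal sketch of what dtype inference could look like, assuming the loaded parameters are available as a plain name-to-torch.Tensor mapping. The helper name and that mapping are illustrative assumptions, not existing sharktank API.

import torch

def infer_model_dtype(weights: dict[str, torch.Tensor]) -> torch.dtype:
    # Reuse the dtype of the first floating-point parameter found.
    for tensor in weights.values():
        if tensor.dtype.is_floating_point:
            return tensor.dtype
    # Fall back to the flag's current default when no float weights exist.
    return torch.float16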

    args = cli.parse(parser)
    args = cli.parse(parser, args=raw_args)
    dataset_type = cli.get_input_data_files(args)
    dataset_type = "irpa" if "irpa" in dataset_type else "gguf"
    dataset = cli.get_input_dataset(args)
@@ -73,11 +93,14 @@ def main():
    )
    llama_config = LlamaModelConfig(
        hp,
        block_seq_stride=args.block_seq_stride,
        tensor_parallelism_size=tensor_parallelism_size,
        use_hf=False,
        static_tables=False,  # Rely on the compiler for hoisting tables.
        kv_cache_type="direct" if args.bs == [1] else "paged",
        attention_kernel=args.attention_kernel,
        attention_dtype=dtype_from_str(args.attention_dtype),
        activation_dtype=dtype_from_str(args.activation_dtype),
    )

    if llama_config.hp.expert_count:
@@ -110,7 +133,7 @@ def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]):

    fxb = FxProgramsBuilder(model)

    def setup_cache(model, shard_count):
    def setup_cache(model):
        if model.config.kv_cache_type == "paged":
            cache_state = model.cache.allocate(
                page_count=hp.context_length // llama_config.block_seq_stride
@@ -152,16 +175,21 @@ def repack_cache(cache, shard_dim):
        return [SplitPrimitiveTensor(ts=c, shard_dim=shard_dim) for c in cache]

    def generate_batch_prefill(bs: int):
        tokens = torch.empty(bs, 64, dtype=torch.int64)
        seq_lens = torch.empty(bs, dtype=torch.int64)
        seq_block_ids = torch.empty(bs, 4, dtype=torch.int64)
        block_dim = torch.export.Dim(
            "block", max=(hp.context_length - 1) // llama_config.block_seq_stride
        )
        # torch.export.Dim would make min at least 2
        block_dim_min = 2
        block_dim_max = ceildiv(hp.context_length, llama_config.block_seq_stride) - 1
        block_dim = torch.export.Dim("block", min=block_dim_min, max=block_dim_max)
        sl_dim = llama_config.block_seq_stride * block_dim
        seq_block_ids = torch.empty(bs, block_dim_min, dtype=torch.int64)
        tokens = torch.empty(
            bs,
            seq_block_ids.shape[1] * llama_config.block_seq_stride,
            dtype=torch.int64,
        )
        seq_lens = torch.empty(bs, dtype=torch.int64)

        cache, cache_shard_dim, cache_dynamic_shapes, arg_affinities = setup_cache(
            model, llama_config.tensor_parallelism_size
            model
        )

        # We need to offset the indices for the cache
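As context for the "torch.export.Dim would make min at least 2" comment in the hunk above, here is a standalone sketch of that constraint: a dynamic dimension is declared with min >= 2 and the example input is sized at that minimum, since size-0/1 example dims get specialized. The Double module and the sizes are made up for illustration.

import torch

class Double(torch.nn.Module):
    def forward(self, x):
        return x * 2

# Dynamic dim with min >= 2; the example input uses that minimum size.
block_dim = torch.export.Dim("block", min=2, max=511)
example = torch.empty(4, 2, dtype=torch.int64)  # batch of 4, block dim at its min of 2
exported = torch.export.export(Double(), (example,), dynamic_shapes=({1: block_dim},))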
@@ -174,7 +202,7 @@ def generate_batch_prefill(bs: int):
            "tokens": {1: sl_dim},
            "seq_lens": {},
            "seq_block_ids": {1: block_dim},
            "cs": cache_dynamic_shapes,
            "cache_state": cache_dynamic_shapes,
        }

        print(f"Exporting prefill_bs{bs}")
@@ -186,55 +214,42 @@ def generate_batch_prefill(bs: int):
            strict=args.strict,
            arg_device=arg_affinities,
        )
        def _(model, tokens, seq_lens, seq_block_ids, cs):
        def _(model, tokens, seq_lens, seq_block_ids, cache_state):
            if (
                model.config.tensor_parallelism_size == 1
                and model.config.kv_cache_type == "direct"
            ):
                cache_tensors = torch.unbind(cs)
            else:
                cache_tensors = cs

            sl = tokens.shape[1]
            input_mask = model.input_mask(seq_lens, sl)
            attention_mask = model.attention_mask(input_mask)

            if llama_config.tensor_parallelism_size != 1:
                shard_count = llama_config.tensor_parallelism_size

                tokens = ops.replicate(tokens, count=shard_count)
                attention_mask = ops.replicate(attention_mask, count=shard_count)
                seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)

                cache_tensors = repack_cache(cs, cache_shard_dim)
                cache_state = torch.unbind(cache_state)
            if model.config.tensor_parallelism_size != 1:
                cache_state = repack_cache(cache_state, cache_shard_dim)

            logits = model.prefill(
                tokens,
                attention_mask=attention_mask,
            return model.prefill_from_seq_lens(
                tokens=tokens,
                seq_lens=seq_lens,
                seq_block_ids=seq_block_ids,
                cache_state=cache_tensors,
                cache_state=cache_state,
            )

            if llama_config.tensor_parallelism_size != 1:
                logits = ops.unshard(logits)

            return logits
    def generate_batch_decode(bs: int):
        tokens = torch.ones(bs, 1, dtype=torch.int64)
        seq_lens = torch.ones(bs, dtype=torch.int64)
        start_positions = torch.ones(bs, dtype=torch.int64)
        seq_block_ids = torch.zeros(bs, 4, dtype=torch.int64)
        block_dim = torch.export.Dim(
            "block", max=(hp.context_length - 1) // llama_config.block_seq_stride
        # torch.export.Dim would make min at least 2
        block_dim_min = 2
        block_dim_max = ceildiv(hp.context_length, llama_config.block_seq_stride) - 1
        block_dim = torch.export.Dim("block", min=block_dim_min, max=block_dim_max)
        tokens = torch.empty(
            bs,
            1,
            dtype=torch.int64,
        )
        seq_lens = torch.empty(bs, dtype=torch.int64)
        start_positions = torch.ones(bs, dtype=torch.int64)
        seq_block_ids = torch.empty(bs, block_dim_min, dtype=torch.int64)

        (
            cache_state,
            cache_shard_dim,
            cache_dynamic_shapes,
            arg_affinities,
        ) = setup_cache(model, llama_config.tensor_parallelism_size)
        ) = setup_cache(model)

        # We need to offset the indices for the cache
        arg_affinities = {key + 4: arg_affinities[key] for key in arg_affinities}
@@ -274,34 +289,21 @@ def _(
            seq_block_ids,
            cache_state,
        ):
            input_mask = model.input_mask(
                seq_lens, seq_block_ids.shape[1] * model.cache.block_seq_stride
            )
            attention_mask = model.decode_attention_mask(input_mask)

            if llama_config.tensor_parallelism_size != 1:
                shard_count = llama_config.tensor_parallelism_size

                tokens = ops.replicate(tokens, count=shard_count)
                attention_mask = ops.replicate(attention_mask, count=shard_count)
                start_positions = ops.replicate(start_positions, count=shard_count)
                seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)

            if (
                model.config.tensor_parallelism_size == 1
                and model.config.kv_cache_type == "direct"
            ):
                cache_state = torch.unbind(cache_state)
            if model.config.tensor_parallelism_size != 1:
                cache_state = repack_cache(cache_state, cache_shard_dim)

            logits = model.decode(
                tokens,
                attention_mask=attention_mask,
            return model.decode_from_seq_lens(
                tokens=tokens,
                seq_lens=seq_lens,
                start_positions=start_positions,
                seq_block_ids=seq_block_ids,
                cache_state=cache_state,
            )

            if llama_config.tensor_parallelism_size != 1:
                logits = ops.unshard(logits)

            return logits

    bsizes = []
    for bs in args.bs:
        generate_batch_prefill(bs)
Review comment: Use a map rather than string manipulations. It is as simple as
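The snippet that originally followed this comment is not preserved here. Below is a hedged sketch of the kind of table-based lookup the reviewer appears to suggest; the map contents and error handling are illustrative, not the reviewer's actual code.

import torch

# Map accepted flag values directly to torch dtypes instead of splitting strings.
DTYPE_MAP = {
    "torch.float16": torch.float16,
    "torch.float32": torch.float32,
    "torch.bfloat16": torch.bfloat16,
}

def dtype_from_str(s: str) -> torch.dtype:
    try:
        return DTYPE_MAP[s]
    except KeyError:
        raise ValueError(f"Unsupported dtype string: {s}")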