Introduce CausalLMModel interface and add IREE numerics test for Llama 3.1 8B FP16 TP8 #375

Closed
140 changes: 71 additions & 69 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -13,16 +13,20 @@

from sharktank.layers import *
from sharktank.types import *
from sharktank.utils.math import ceildiv

# TODO: Should be using a base class with the protocol supported.
from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
from ..models.llama.sharding import shard_theta
from ..models.mixtral.mixtral import *
from ..models.grok.grok import *
from .. import ops


def main():
def dtype_from_str(s: str) -> torch.dtype:
Contributor:
Use a map rather than string manipulations. It is as simple as

{
  "fp8":torch.fp8,
  "f32":torch.f32,
....
}

parts = s.split(".", maxsplit=1)
assert len(parts) == 2 and parts[0] == "torch"
return getattr(torch, parts[1])
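
For reference, a minimal sketch of the map-based parsing the reviewer suggests above; the helper name and the exact set of supported dtype names are illustrative assumptions, not part of this PR:

import torch

_DTYPE_BY_NAME = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}

def dtype_from_name(name: str) -> torch.dtype:
    # Accept either "float16" or "torch.float16" (the argparse defaults use
    # str(torch.float16), which renders as "torch.float16").
    key = name.removeprefix("torch.")
    if key not in _DTYPE_BY_NAME:
        raise ValueError(f"Unsupported dtype name: {name}")
    return _DTYPE_BY_NAME[key]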


def main(raw_args: list[str] | None = None):
from ..utils import cli

parser = cli.create_parser()
@@ -43,6 +47,12 @@ def main():
type=lambda arg: [int(bs) for bs in arg.split(",")],
default="4",
)
parser.add_argument(
"--block-seq-stride",
help="Block sequence stride for a paged KV cache.",
type=int,
default=LlamaModelConfig.default_block_seq_stride,
)
parser.add_argument(
"--verbose",
help="Include verbose logging",
@@ -59,8 +69,18 @@ def main():
default="decomposed",
choices=["decomposed", "torch"],
)
parser.add_argument(
"--attention-dtype",
Contributor:
Can you provide a justification for adding these in the first place? They should be inferred from the data types of the functions

Contributor (author):
There is one test that uses fp32.

type=str,
default=str(LlamaModelConfig.default_attention_dtype),
)
parser.add_argument(
"--activation-dtype",
type=str,
default=str(LlamaModelConfig.default_activation_dtype),
)

args = cli.parse(parser)
args = cli.parse(parser, args=raw_args)
dataset_type = cli.get_input_data_files(args)
dataset_type = "irpa" if "irpa" in dataset_type else "gguf"
dataset = cli.get_input_dataset(args)
@@ -73,11 +93,14 @@ def main():
)
llama_config = LlamaModelConfig(
hp,
block_seq_stride=args.block_seq_stride,
tensor_parallelism_size=tensor_parallelism_size,
use_hf=False,
static_tables=False, # Rely on the compiler for hoisting tables.
kv_cache_type="direct" if args.bs == [1] else "paged",
attention_kernel=args.attention_kernel,
attention_dtype=dtype_from_str(args.attention_dtype),
activation_dtype=dtype_from_str(args.activation_dtype),
)

if llama_config.hp.expert_count:
@@ -110,7 +133,7 @@ def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]):

fxb = FxProgramsBuilder(model)

def setup_cache(model, shard_count):
def setup_cache(model):
if model.config.kv_cache_type == "paged":
cache_state = model.cache.allocate(
page_count=hp.context_length // llama_config.block_seq_stride
@@ -152,16 +175,21 @@ def repack_cache(cache, shard_dim):
return [SplitPrimitiveTensor(ts=c, shard_dim=shard_dim) for c in cache]

def generate_batch_prefill(bs: int):
tokens = torch.empty(bs, 64, dtype=torch.int64)
seq_lens = torch.empty(bs, dtype=torch.int64)
seq_block_ids = torch.empty(bs, 4, dtype=torch.int64)
block_dim = torch.export.Dim(
"block", max=(hp.context_length - 1) // llama_config.block_seq_stride
)
# torch.export.Dim would make min at least 2
block_dim_min = 2
block_dim_max = ceildiv(hp.context_length, llama_config.block_seq_stride) - 1
block_dim = torch.export.Dim("block", min=block_dim_min, max=block_dim_max)
sl_dim = llama_config.block_seq_stride * block_dim
seq_block_ids = torch.empty(bs, block_dim_min, dtype=torch.int64)
tokens = torch.empty(
bs,
seq_block_ids.shape[1] * llama_config.block_seq_stride,
dtype=torch.int64,
)
seq_lens = torch.empty(bs, dtype=torch.int64)

cache, cache_shard_dim, cache_dynamic_shapes, arg_affinities = setup_cache(
model, llama_config.tensor_parallelism_size
model
)

# We need to offset the indices for the cache
@@ -174,7 +202,7 @@ def generate_batch_prefill(bs: int):
"tokens": {1: sl_dim},
"seq_lens": {},
"seq_block_ids": {1: block_dim},
"cs": cache_dynamic_shapes,
"cache_state": cache_dynamic_shapes,
}

print(f"Exporting prefill_bs{bs}")
@@ -186,55 +214,42 @@ def generate_batch_prefill(bs: int):
strict=args.strict,
arg_device=arg_affinities,
)
def _(model, tokens, seq_lens, seq_block_ids, cs):
def _(model, tokens, seq_lens, seq_block_ids, cache_state):
if (
model.config.tensor_parallelism_size == 1
and model.config.kv_cache_type == "direct"
):
cache_tensors = torch.unbind(cs)
else:
cache_tensors = cs

sl = tokens.shape[1]
input_mask = model.input_mask(seq_lens, sl)
attention_mask = model.attention_mask(input_mask)

if llama_config.tensor_parallelism_size != 1:
shard_count = llama_config.tensor_parallelism_size

tokens = ops.replicate(tokens, count=shard_count)
attention_mask = ops.replicate(attention_mask, count=shard_count)
seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)

cache_tensors = repack_cache(cs, cache_shard_dim)
cache_state = torch.unbind(cache_state)
if model.config.tensor_parallelism_size != 1:
cache_state = repack_cache(cache_state, cache_shard_dim)

logits = model.prefill(
tokens,
attention_mask=attention_mask,
return model.prefill_from_seq_lens(
tokens=tokens,
seq_lens=seq_lens,
seq_block_ids=seq_block_ids,
cache_state=cache_tensors,
cache_state=cache_state,
)

if llama_config.tensor_parallelism_size != 1:
logits = ops.unshard(logits)

return logits
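
To make the dynamic-shape bookkeeping above concrete, here is a small standalone sketch of how the "block" dimension is sized (ceildiv is assumed to behave like ordinary ceiling division, matching sharktank.utils.math.ceildiv; the concrete context_length and block_seq_stride values are illustrative):

import torch

def ceildiv(a: int, b: int) -> int:
    # Ceiling division; assumed equivalent to sharktank.utils.math.ceildiv.
    return -(-a // b)

context_length = 2048   # stands in for hp.context_length
block_seq_stride = 16   # stands in for llama_config.block_seq_stride

# torch.export.Dim would make min at least 2, hence the explicit floor of 2.
block_dim_min = 2
block_dim_max = ceildiv(context_length, block_seq_stride) - 1  # 127 here
block_dim = torch.export.Dim("block", min=block_dim_min, max=block_dim_max)

# The tokens' sequence-length dimension is tied to the block dimension.
sl_dim = block_seq_stride * block_dim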

def generate_batch_decode(bs: int):
tokens = torch.ones(bs, 1, dtype=torch.int64)
seq_lens = torch.ones(bs, dtype=torch.int64)
start_positions = torch.ones(bs, dtype=torch.int64)
seq_block_ids = torch.zeros(bs, 4, dtype=torch.int64)
block_dim = torch.export.Dim(
"block", max=(hp.context_length - 1) // llama_config.block_seq_stride
# torch.export.Dim would make min at least 2
block_dim_min = 2
block_dim_max = ceildiv(hp.context_length, llama_config.block_seq_stride) - 1
block_dim = torch.export.Dim("block", min=block_dim_min, max=block_dim_max)
tokens = torch.empty(
bs,
1,
dtype=torch.int64,
)
seq_lens = torch.empty(bs, dtype=torch.int64)
start_positions = torch.ones(bs, dtype=torch.int64)
seq_block_ids = torch.empty(bs, block_dim_min, dtype=torch.int64)

(
cache_state,
cache_shard_dim,
cache_dynamic_shapes,
arg_affinities,
) = setup_cache(model, llama_config.tensor_parallelism_size)
) = setup_cache(model)

# We need to offset the indices for the cache
arg_affinities = {key + 4: arg_affinities[key] for key in arg_affinities}
@@ -274,34 +289,21 @@ def _(
seq_block_ids,
cache_state,
):
input_mask = model.input_mask(
seq_lens, seq_block_ids.shape[1] * model.cache.block_seq_stride
)
attention_mask = model.decode_attention_mask(input_mask)

if llama_config.tensor_parallelism_size != 1:
shard_count = llama_config.tensor_parallelism_size

tokens = ops.replicate(tokens, count=shard_count)
attention_mask = ops.replicate(attention_mask, count=shard_count)
start_positions = ops.replicate(start_positions, count=shard_count)
seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)

if (
model.config.tensor_parallelism_size == 1
and model.config.kv_cache_type == "direct"
):
cache_state = torch.unbind(cache_state)
if model.config.tensor_parallelism_size != 1:
cache_state = repack_cache(cache_state, cache_shard_dim)

logits = model.decode(
tokens,
attention_mask=attention_mask,
return model.decode_from_seq_lens(
tokens=tokens,
seq_lens=seq_lens,
start_positions=start_positions,
seq_block_ids=seq_block_ids,
cache_state=cache_state,
)

if llama_config.tensor_parallelism_size != 1:
logits = ops.unshard(logits)

return logits

bsizes = []
for bs in args.bs:
generate_batch_prefill(bs)
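With main() now accepting raw_args, tests can drive the exporter in-process instead of spawning a subprocess. A hedged sketch of such an invocation; only --block-seq-stride, --attention-dtype, and --activation-dtype come from this diff, while the dataset flag name and file path are assumptions:

from sharktank.examples.export_paged_llm_v1 import main

# Hypothetical invocation; "--irpa-file" and the path are placeholders for
# whatever dataset flag the CLI actually defines.
main(
    raw_args=[
        "--irpa-file=/path/to/llama3_1_8b_fp16.irpa",
        "--block-seq-stride=16",
        "--attention-dtype=torch.float16",
        "--activation-dtype=torch.float16",
    ]
)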
29 changes: 10 additions & 19 deletions sharktank/sharktank/examples/paged_llm_v1.py
@@ -32,7 +32,7 @@ class TorchGenerator:

def __init__(
self,
model: PagedLlamaModelV1,
model: CausalLMModel,
tokenizer: InferenceTokenizer,
page_cache_size: int = 128,
# Need to look at the model more for this.
@@ -162,17 +162,14 @@ def compute_prefill_logits(

def prefill(self):
model = self.parent.model
attention_mask = model.attention_mask(
model.input_mask(self.seq_lens, self.token_ids.shape[1])
)
seq_block_ids_tensor = self.pad_block_ids()
print(f":: Invoke prefill:")
trace_tensor("prefill.token_ids", self.token_ids)
trace_tensor("prefill.seq_lens", self.seq_lens)
trace_tensor("prefill.seq_block_ids", seq_block_ids_tensor)
trace_tensor("prefill.attention_mask", attention_mask)
logits = model.prefill(
self.logits = model.prefill_from_seq_lens(
self.token_ids,
attention_mask=attention_mask,
seq_lens=self.seq_lens,
seq_block_ids=seq_block_ids_tensor,
cache_state=self.cache_state,
)
@@ -181,7 +178,7 @@ def prefill(self):
# TODO: Normalize the output of extract_tokens_from_logits into
# tensor [bs, 1].
tokens = torch.tensor(
model.extract_tokens_from_logits(logits, self.seq_lens)
model.extract_tokens_from_logits(self.logits, self.seq_lens)
).unsqueeze(1)
print(f":: Prefill results:\n{tokens.tolist()}")
self.add_result_token(tokens)
@@ -194,28 +191,22 @@ def decode(self):
self.allocate_seq_block_ids()
# TODO: Allocate more blocks on overflow.
seq_block_ids_tensor = self.pad_block_ids()
decode_attention_mask = model.decode_attention_mask(
model.input_mask(
self.seq_lens,
seq_block_ids_tensor.shape[1] * self.parent.block_seq_stride,
)
)
trace_tensor("decode.token_ids", self.next_tokens)
trace_tensor("decode.seq_lens", self.seq_lens)
trace_tensor("decode.start_positions", start_positions)
trace_tensor("decode.seq_block_ids", seq_block_ids_tensor)
trace_tensor("decode.attention_mask", decode_attention_mask)
logits = model.decode(
self.logits = model.decode_from_seq_lens(
self.next_tokens,
attention_mask=decode_attention_mask,
seq_lens=self.seq_lens,
start_positions=start_positions,
seq_block_ids=seq_block_ids_tensor,
cache_state=self.cache_state,
)
trace_tensor("decode.logits", logits)
trace_tensor("decode.logits", self.logits)
# TODO: Normalize the output of extract_tokens_from_logits into
# tensor [bs, 1].
tokens = torch.tensor(
model.extract_tokens_from_logits(logits, [1] * self.bs),
model.extract_tokens_from_logits(self.logits, [1] * self.bs),
device=self.parent.model.device,
).unsqueeze(1)
self.add_result_token(tokens)
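Conceptually, prefill_from_seq_lens and decode_from_seq_lens fold the mask construction and tensor-parallel replication that callers previously did by hand into the model itself. Below is a rough reconstruction of the prefill side from the caller code removed in this diff; it is a sketch only, not the actual implementation in causal_llm.py, and it assumes ops refers to sharktank.ops:

def prefill_from_seq_lens(self, tokens, *, seq_lens, seq_block_ids, cache_state):
    # Derive the attention mask from the sequence lengths rather than
    # requiring callers to pass it in.
    input_mask = self.input_mask(seq_lens, tokens.shape[1])
    attention_mask = self.attention_mask(input_mask)

    if self.config.tensor_parallelism_size != 1:
        count = self.config.tensor_parallelism_size
        tokens = ops.replicate(tokens, count=count)
        attention_mask = ops.replicate(attention_mask, count=count)
        seq_block_ids = ops.replicate(seq_block_ids, count=count)

    logits = self.prefill(
        tokens,
        attention_mask=attention_mask,
        seq_block_ids=seq_block_ids,
        cache_state=cache_state,
    )
    if self.config.tensor_parallelism_size != 1:
        logits = ops.unshard(logits)
    return logits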
5 changes: 4 additions & 1 deletion sharktank/sharktank/layers/__init__.py
@@ -7,7 +7,10 @@
from .base import BaseLayer, ThetaLayer
from .conv import Conv2DLayer
from .kv_cache import BaseKVCache, DirectKVCache, PagedKVCache
from .causal_llm import BaseCausalLMModel
from .causal_llm import (
CausalLMModel,
BaseCausalLMModel,
)
from .linear import LinearLayer
from .norm import RMSNormLayer
from .rotary_embedding import RotaryEmbeddingLayer
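Taken together with the call sites above (prefill_from_seq_lens, decode_from_seq_lens, extract_tokens_from_logits), the newly exported CausalLMModel amounts roughly to the following contract. This is an inferred sketch for orientation, not the definition in causal_llm.py; argument names and types are guessed from how the exporter and paged_llm_v1 call the model:

import abc
import torch

class CausalLMModel(abc.ABC):
    """Causal LM contract inferred from the call sites in this PR."""

    @abc.abstractmethod
    def prefill_from_seq_lens(
        self,
        tokens: torch.Tensor,
        *,
        seq_lens: torch.Tensor,
        seq_block_ids: torch.Tensor,
        cache_state: list[torch.Tensor],
    ) -> torch.Tensor:
        ...

    @abc.abstractmethod
    def decode_from_seq_lens(
        self,
        tokens: torch.Tensor,
        *,
        seq_lens: torch.Tensor,
        start_positions: torch.Tensor,
        seq_block_ids: torch.Tensor,
        cache_state: list[torch.Tensor],
    ) -> torch.Tensor:
        ...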