Skip to content

Commit

Permalink
Added experimental mm padding for cache behavior
Browse files Browse the repository at this point in the history
Padding to cache lines can significantly improve performance by
separating out cache lines between iterations.
  • Loading branch information
rsuderman committed Jan 14, 2025
1 parent 04d383b commit 3a134ca
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 7 deletions.
21 changes: 21 additions & 0 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ def main():
help="Enables strictness during export",
action="store_true",
)
parser.add_argument(
"--experimental-mm-cache-size",
help="Experimental padding for matmuls to cache padding using cache size",
type=int,
default=None,
)
parser.add_argument(
"--experimental-mm-cache-sets",
help="Experimental padding for matmuls to cache padding using using number of cache sets",
type=int,
default=None,
)

cli.add_quantization_options(parser)
cli.add_model_options(parser)
Expand All @@ -75,6 +87,13 @@ def main():
else 1
)

if (args.experimental_mm_cache_size == None) != (
args.experimental_mm_cache_sets == None
):
raise NotImplementedError(
f"Both values need to be set for experimental cache padding"
)

llama_config = LlamaModelConfig(
hp,
tensor_parallelism_size=tensor_parallelism_size,
Expand All @@ -83,6 +102,8 @@ def main():
kv_cache_type="paged",
attention_kernel=args.attention_kernel,
block_seq_stride=args.block_seq_stride,
experimental_mm_cache_size=args.experimental_mm_cache_size,
experimental_mm_cache_sets=args.experimental_mm_cache_sets,
)
llama_config.fake_quant = args.fake_quant

Expand Down
5 changes: 5 additions & 0 deletions sharktank/sharktank/layers/configs/llm_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,11 @@ class LlamaModelConfig:
# the program and not.
static_tables: bool = True

# Experimental matmul padding configuration designed to allow matmuls to pad
# to cache line configurations:
experimental_mm_cache_size: Optional[int] = None
experimental_mm_cache_sets: Optional[int] = None


@dataclass
class T5Config:
Expand Down
29 changes: 26 additions & 3 deletions sharktank/sharktank/layers/ffn_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,38 @@ def __init__(
theta: Theta,
is_gated: bool = True,
activation_fn: Callable[[AnyTensor], AnyTensor] = F.silu,
experimental_mm_cache_size: Optional[int] = None,
experimental_mm_cache_sets: Optional[int] = None,
):
super().__init__(theta)

self.is_gated = is_gated
self.activation_fn = activation_fn
if self.is_gated:
self.add_module("ffn_gate", LinearLayer(theta("ffn_gate")))
self.add_module("ffn_up", LinearLayer(theta("ffn_up")))
self.add_module("ffn_down", LinearLayer(theta("ffn_down")))
self.add_module(
"ffn_gate",
LinearLayer(
theta("ffn_gate"),
experimental_mm_cache_size=experimental_mm_cache_size,
experimental_mm_cache_sets=experimental_mm_cache_sets,
),
)
self.add_module(
"ffn_up",
LinearLayer(
theta("ffn_up"),
experimental_mm_cache_size=experimental_mm_cache_size,
experimental_mm_cache_sets=experimental_mm_cache_sets,
),
)
self.add_module(
"ffn_down",
LinearLayer(
theta("ffn_down"),
experimental_mm_cache_size=experimental_mm_cache_size,
experimental_mm_cache_sets=experimental_mm_cache_sets,
),
)

def forward(
self,
Expand Down
22 changes: 22 additions & 0 deletions sharktank/sharktank/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DynamicScaledQuantizer,
QuantizedTensor,
QuantizerTensor,
SplitPrimitiveTensor,
StaticScaledQuantizer,
TensorScaledLayout,
PlanarQuantizedTensor,
Expand Down Expand Up @@ -43,12 +44,16 @@ def __init__(
weight_name: str = "weight",
bias_name: str = "bias",
fake_quant: bool = False,
experimental_mm_cache_size: Optional[int] = None,
experimental_mm_cache_sets: Optional[int] = None,
):
super().__init__(theta)
self._simulate_native_quant = True
self.weight = self.theta_tensor(weight_name)
self.bias = None
self.fake_quant = fake_quant
self.experimental_mm_cache_size = experimental_mm_cache_size
self.experimental_mm_cache_sets = experimental_mm_cache_sets
if bias_name in self.theta.keys:
self.bias = self.theta_tensor(bias_name)

Expand Down Expand Up @@ -77,6 +82,23 @@ def forward(self, x):
elif qdq_input is not None:
x = qdq_input.quantize(x).unpack().dequant()

# TODO - This should be removed once the compiler supports it:
if (
self.experimental_mm_cache_sets is not None
and self.experimental_mm_cache_size is not None
):
contract = x.shape[-1]
if isinstance(x, SplitPrimitiveTensor):
contract = contract // x.shard_count
element_size = torch.finfo(x.dtype).bits
cache_line_size = self.experimental_mm_cache_size * 8 // element_size
cache_size = self.experimental_mm_cache_sets * cache_line_size
if contract % cache_size == 0:
x = ops.pad(x, [0, cache_line_size], constant=0, per_shard=True)
weight = ops.pad(
weight, [0, cache_line_size], constant=0, per_shard=True
)

y = ops.linear(x, weight, bias)

# Unconditionally dequantize.
Expand Down
34 changes: 30 additions & 4 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def __init__(
attention_scale: Optional[float] = None,
softcap: Optional[float] = None,
fake_quant: Optional[bool] = True,
experimental_mm_cache_size: Optional[int] = None,
experimental_mm_cache_sets: Optional[int] = None,
):
super().__init__(theta)

Expand All @@ -58,16 +60,40 @@ def __init__(
"attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon)
)
self.add_module(
"attn_q", LinearLayer(theta("attn_q"), fake_quant=self.fake_quant)
"attn_q",
LinearLayer(
theta("attn_q"),
fake_quant=self.fake_quant,
experimental_mm_cache_size=experimental_mm_cache_size,
experimental_mm_cache_sets=experimental_mm_cache_sets,
),
)
self.add_module(
"attn_k", LinearLayer(theta("attn_k"), fake_quant=self.fake_quant)
"attn_k",
LinearLayer(
theta("attn_k"),
fake_quant=self.fake_quant,
experimental_mm_cache_size=experimental_mm_cache_size,
experimental_mm_cache_sets=experimental_mm_cache_sets,
),
)
self.add_module(
"attn_v", LinearLayer(theta("attn_v"), fake_quant=self.fake_quant)
"attn_v",
LinearLayer(
theta("attn_v"),
fake_quant=self.fake_quant,
experimental_mm_cache_size=experimental_mm_cache_size,
experimental_mm_cache_sets=experimental_mm_cache_sets,
),
)
self.add_module(
"attn_output", LinearLayer(theta("attn_output"), fake_quant=self.fake_quant)
"attn_output",
LinearLayer(
theta("attn_output"),
fake_quant=self.fake_quant,
experimental_mm_cache_size=experimental_mm_cache_size,
experimental_mm_cache_sets=experimental_mm_cache_sets,
),
)
self.cache_quantizer = None
if "kv_cache" in theta.keys:
Expand Down
10 changes: 10 additions & 0 deletions sharktank/sharktank/models/llama/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
self.activation_dtype = config.activation_dtype
self.use_hf = config.use_hf
self.attention_kernel = config.attention_kernel
self.experimental_mm_cache_sets = config.experimental_mm_cache_sets
self.experimental_mm_cache_size = config.experimental_mm_cache_size

self.add_module(
"token_embedding",
Expand Down Expand Up @@ -115,6 +117,8 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
rms_epsilon=hp.attention_layer_norm_rms_epsilon,
attention_kernel=self.attention_kernel,
fake_quant=self.fake_quant,
experimental_mm_cache_sets=self.experimental_mm_cache_sets,
experimental_mm_cache_size=self.experimental_mm_cache_size,
)
for n in range(hp.block_count)
]
Expand Down Expand Up @@ -287,6 +291,8 @@ def __init__(
rms_epsilon: float,
attention_kernel: str = "decomposed",
fake_quant: bool = True,
experimental_mm_cache_size: Optional[int] = None,
experimental_mm_cache_sets: Optional[int] = None,
):
super().__init__(theta)
self.add_module(
Expand All @@ -301,12 +307,16 @@ def __init__(
rms_epsilon=rms_epsilon,
attention_kernel=attention_kernel,
fake_quant=fake_quant,
experimental_mm_cache_size=experimental_mm_cache_size,
experimental_mm_cache_sets=experimental_mm_cache_sets,
),
)
self.add_module(
"ffn",
FFN(
theta=theta,
experimental_mm_cache_size=experimental_mm_cache_size,
experimental_mm_cache_sets=experimental_mm_cache_sets,
),
)
self.add_module(
Expand Down
5 changes: 5 additions & 0 deletions sharktank/sharktank/ops/default_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,11 @@ def module_register_buffer_default(
return module.register_buffer(name, unbox_tensor(tensor))


@pad.override(Tensor)
def pad_default(x, pad, constant, **kwargs) -> Tensor:
return torch.nn.functional.pad(unbox_tensor(x), pad=pad, value=constant)


@repeat.override(Tensor)
def repeat_default(input: Union[Tensor, PrimitiveTensor], *sizes: List[int]) -> Tensor:
return unbox_tensor(input).repeat(*sizes)
Expand Down
23 changes: 23 additions & 0 deletions sharktank/sharktank/ops/sharded_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,29 @@ def module_register_buffer_sharded(
setattr(module, name, tensor)


@pad.override(ReplicatedTensor)
def pad_default(x, pad, constant, **kwargs) -> Tensor:
shards = [
torch.nn.functional.pad(unbox_tensor(shard), pad=pad, value=constant)
for shard in x.shards
]
shard_dim = x.shard_dim
return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim)


@pad.override(ShardedTensor)
def pad_default(x, pad, constant, per_shard) -> Tensor:
shard_dim = x.shard_dim
if per_shard and (pad[shard_dim * 2] != 0 or pad[shard_dim * 2 + 1] != 0):
raise ValueError(f"Shard dimension {shard_dim} cannot be non-zero")

padded_shards = [
torch.nn.functional.pad(unbox_tensor(shard), pad=pad, value=constant)
for shard in x.shards
]
return SplitPrimitiveTensor(ts=padded_shards, shard_dim=shard_dim)


@permute.override(SplitPrimitiveTensor)
def permute_split(tensor: SplitPrimitiveTensor, dims: List[int]):
permuted_shards = [permute(shard, dims) for shard in tensor.shards]
Expand Down
28 changes: 28 additions & 0 deletions sharktank/sharktank/ops/signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"matmul",
"mean",
"module_register_buffer",
"pad",
"permute",
"rms_norm",
"repeat",
Expand Down Expand Up @@ -737,6 +738,33 @@ def _mean_trampoline(
d.fail(tensors)


@overridable
def pad(
x: AnyTensor,
pad: List[int],
constant: float = 0.0,
) -> AnyTensor:
"""See torch.pad"""
raise NotImplementedError


@pad.trampoline
def _pad_trampoline(
d: SignatureDispatcher,
x: AnyTensor,
pad: List[int],
constant: float = 0.0,
**kwargs,
) -> AnyTensor:
tensors = (x,)
for override in d.find_overrides(tensors):
result = override(x, pad, constant, **kwargs)
if result is not NotImplemented:
return override, result
else:
d.fail(tensors)


@overridable
def module_register_buffer(
module: torch.nn.Module, name: str, tensor: AnyTensor
Expand Down

0 comments on commit 3a134ca

Please sign in to comment.