Add dtype, fix RMS norm for FP16 (#8641)
* Add dtype, fix RMS norm for FP16

* up

* up

* Update llama_transformer.py
metascroy authored Feb 26, 2025
1 parent 2be4e94 commit 5a594a7
Showing 4 changed files with 223 additions and 92 deletions.
81 changes: 38 additions & 43 deletions examples/apple/coreml/llama/export.py
@@ -3,7 +3,6 @@
# pyre-strict

import argparse
import json

import sys

@@ -20,10 +19,11 @@
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.extension.export_util.utils import export_to_edge, save_pte_program
from executorch.exir.program._program import to_edge_with_preserved_ops
from executorch.extension.export_util.utils import save_pte_program

sys.path.insert(0, ".")
from llama_transformer import InputManager, ModelArgs, Transformer
from llama_transformer import InputManager, load_model


class SplitLinearModule(torch.nn.Module):
@@ -141,42 +141,23 @@ def main() -> None:
default=8,
help="Maximum number of splits to divide linear layers",
)
parser.add_argument(
"--dtype",
type=str,
default="fp16",
)

export_args = parser.parse_args()
params_path = export_args.params
checkpoint_path = export_args.checkpoint

# Load model args
with open(params_path, "r") as f:
params = json.loads(f.read())

args = ModelArgs(
max_seq_len=export_args.max_seq_length,
generate_full_logits=False,
model = load_model(
export_args.checkpoint,
export_args.params,
max_seq_length=export_args.max_seq_length,
use_cache_list=export_args.use_cache_list,
**params,
)

with torch.device("meta"):
model = Transformer(args)

checkpoint = torch.load(
checkpoint_path, map_location="cpu", mmap=True, weights_only=True
)
if "model" in checkpoint:
checkpoint = checkpoint["model"]

missing, unexpected = model.load_state_dict(
checkpoint,
strict=False,
assign=True,
)
print("Missing keys: ", missing)
print("Unexpected keys: ", unexpected)

float_dtype = torch.float16 # dtype for model/inputs
model.eval()
model.to(float_dtype)
float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[
export_args.dtype
] # dtype for model/inputs

if export_args.embedding_quantize:
bitwidth, group_size = export_args.embedding_quantize.split(",")
@@ -197,7 +178,8 @@ def main() -> None:
model, export_args.target_split_size, export_args.max_splits
)

model = model.to(float_dtype)
model.eval()
model.to(float_dtype)

op_linear_quantizer_config = None
if export_args.coreml_quantize == "b4w":
@@ -217,7 +199,10 @@ def main() -> None:

compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16]
minimum_deployment_target=ct.target.iOS18,
compute_precision=ct.precision(ct.precision.FLOAT16.value),
compute_precision={
torch.float16: ct.precision.FLOAT16,
torch.float32: ct.precision.FLOAT32,
}[float_dtype],
compute_unit=ct.ComputeUnit.CPU_AND_NE,
model_type=CoreMLBackend.MODEL_TYPE.MODEL, # pyre-fixme[16]
op_linear_quantizer_config=op_linear_quantizer_config,
@@ -232,11 +217,11 @@
)

input_manager = InputManager(
n_layers=args.n_layers,
max_batch_size=args.max_batch_size,
n_kv_heads=args.n_kv_heads,
max_seq_length=args.max_seq_len,
head_dim=args.head_dim,
n_layers=model.params.n_layers,
max_batch_size=model.params.max_batch_size,
n_kv_heads=model.params.n_kv_heads,
max_seq_length=model.params.max_seq_len,
head_dim=model.params.head_dim,
use_cache_list=export_args.use_cache_list,
seq_length=export_args.seq_length,
dtype=float_dtype,
@@ -245,10 +230,20 @@ def main() -> None:
)
example_inputs = input_manager.get_inputs(tokens=[0])

edge_manager = export_to_edge(
ep = torch.export.export(
model,
example_inputs,
edge_compile_config=EdgeCompileConfig(
)
print("Exported program")
print(ep)

edge_manager = to_edge_with_preserved_ops(
ep,
preserve_ops=[
torch.ops.aten.scaled_dot_product_attention.default,
torch.ops.aten.linalg_vector_norm.default,
],
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
_skip_type_promotion=(float_dtype == torch.float16),
_skip_dim_order=True,
122 changes: 93 additions & 29 deletions examples/apple/coreml/llama/llama_transformer.py
@@ -13,8 +13,6 @@
import torch
import torch.nn.functional as F

from executorch.examples.models.llama.llama_transformer import RMSNorm

from executorch.examples.models.llama.rope import (
hf_apply_rotary_emb,
hf_precompute_freqs_cis,
@@ -25,29 +23,6 @@
from torch import nn


# These are just to prevent to_edge from decomposing SDPA
# A better method is to use the to_edge_transform_and_lower API for CoreML
# and not decompose SDPA
@torch.library.custom_op("coreml::sdpa", mutates_args=())
def sdpa(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
) -> torch.Tensor:
"""Same as F.scaled_dot_product_attention, but with custom op to avoid lowering during dialect conversion."""
return torch.ops.aten.scaled_dot_product_attention.default(
q, k, v, attn_mask=attn_mask
)


@torch.library.register_fake("coreml::sdpa")
def _(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
) -> torch.Tensor:
"""Fake implementation with the right output shape, which is required for torch.compile/export/fx tracing."""
expected_shape = list(q.shape)
expected_shape[-1] = v.shape[-1]
return q.new_empty(expected_shape)


def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
@@ -121,6 +96,63 @@ def __post_init__(self):
self.head_dim = self.dim // self.n_heads


class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
# CoreML ignores casts to FP32, so existing implementation of RMSNorm was not stable
# We instead use (x * sqrt(n)) / norm(x, dim=-1)
# Using torch.norm and preserving this op in CoreML improves stability
# Note, we ignore eps, but could add it by using torch.norm(torch.concat(x, sqrt(n*eps))) in the denominator
# In future, we want to add CoreML support for the functional RMSNorm op
# We have yet to do large scale evaluations on the numeric stability of this solution, but note that
# it appears better than what exists currently (removing FP32 casts and using FP16)
rms_norm_eps0 = (
x * torch.sqrt(torch.tensor(self.dim, dtype=x.dtype))
) / torch.linalg.vector_norm(x, dim=-1, keepdim=True)
return rms_norm_eps0

def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x)
return output * self.weight


class Rope(torch.nn.Module):
def __init__(self, params: ModelArgs):
super().__init__()
@@ -304,12 +336,11 @@ def forward(
k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)

output = torch.ops.coreml.sdpa(q, k, v, attn_mask)

output = torch.ops.aten.scaled_dot_product_attention.default(
q, k, v, attn_mask=attn_mask
)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

output = self.wo(output)

return output, new_k, new_v


@@ -413,6 +444,39 @@ def forward(
return logits, k_out, v_out


def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):
import json

with open(params_path, "r") as f:
params = json.loads(f.read())

args = ModelArgs(
max_seq_len=max_seq_length,
generate_full_logits=False,
use_cache_list=use_cache_list,
**params,
)

with torch.device("meta"):
model = Transformer(args)

checkpoint = torch.load(
checkpoint_path, map_location="cpu", mmap=True, weights_only=True
)
if "model" in checkpoint:
checkpoint = checkpoint["model"]

missing, unexpected = model.load_state_dict(
checkpoint,
strict=False,
assign=True,
)
print("Missing keys: ", missing)
print("Unexpected keys: ", unexpected)

return model


class InputManager:
def __init__(
self,
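
The RMSNorm comments above describe replacing the usual mean-square formulation with (x * sqrt(dim)) / ||x||, so the computation can be lowered through a preserved vector-norm op rather than relying on FP32 casts that CoreML ignores. A minimal standalone sketch (not part of this commit; eps is ignored, as in the new implementation) checking that the two formulations agree:

```
# Sketch only: compare the classic RMSNorm formulation with the norm-based one used above.
import torch

dim = 64
x = torch.randn(2, 8, dim, dtype=torch.float32)

# Classic RMSNorm (eps omitted): x * rsqrt(mean(x^2))
classic = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True))

# Norm-based form from RMSNorm._norm: (x * sqrt(dim)) / ||x||_2
norm_based = (
    x * torch.sqrt(torch.tensor(dim, dtype=x.dtype))
) / torch.linalg.vector_norm(x, dim=-1, keepdim=True)

print(torch.allclose(classic, norm_based, atol=1e-6))  # expected: True
```
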
8 changes: 7 additions & 1 deletion examples/apple/coreml/llama/readme.md
@@ -4,7 +4,7 @@ This directory contains ANE-friendly Llama models.

Export model with:
```
python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w
python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w --dtype fp16
```

(Note the script should be run from the executorch/examples/apple/coreml/llama directory.)
@@ -17,6 +17,12 @@ Run model with:
python run.py -m /path/to/model.pte -t /path/to/tokenizer.model --prompt "Once upon a time,"
```

The runner can also be used to run an eager model to compare with CoreML numerics (--use_eager). In this case, you must specify the following (an example invocation is shown below):
* --checkpoint
* --dtype
* --max_seq_length
* --seq_length
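
For example, an eager comparison run might look like the following sketch (the values are assumed to match the export command above):
```
python run.py -m /path/to/model.pte -t /path/to/tokenizer.model --prompt "Once upon a time," --use_eager --checkpoint /path/to/model.pth --dtype fp16 --max_seq_length 1024 --seq_length 64
```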

(Note the script should be run from the executorch/examples/apple/coreml/llama directory.)

