Improve FP6-LLM 2+4bit weight splitting + user API #279

Merged: 24 commits, May 26, 2024
4 changes: 2 additions & 2 deletions test/dtypes/test_float6_e3m2.py
@@ -12,7 +12,7 @@
_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


-class TestFp6(TestCase):
+class TestFloat6E3M2(TestCase):

@parametrize("device", _DEVICES)
@parametrize("dtype", _DTYPES)
@@ -120,7 +120,7 @@ def test_from_float6_e3m2_compile(self, device, no_bit_packing):
torch.testing.assert_close(actual, expected)


-instantiate_parametrized_tests(TestFp6)
+instantiate_parametrized_tests(TestFloat6E3M2)


if __name__ == "__main__":
99 changes: 99 additions & 0 deletions test/quantization/test_fp6_llm.py
@@ -0,0 +1,99 @@
import pytest
import torch
from torch import nn
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2
from torchao.quantization.fp6_llm import to_tc_float6_e3m2, from_tc_float6_e3m2, Fp6LlmLinear, convert_fp6_llm
from torchao.ops import prepack_fp6_weight


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


class TestFp6LlmLinear(TestCase):
@parametrize("device", _DEVICES)
def test_to_tc_float6_e3m2_correctness(self, device):
x = torch.randn(256, 64, device=device)

expected = prepack_fp6_weight(to_float6_e3m2(x.cpu()).view(torch.int32)).view(torch.uint8)
actual = to_tc_float6_e3m2(x)
torch.testing.assert_close(actual.view(-1).cpu(), expected.view(-1))

@parametrize("device", _DEVICES)
def test_to_tc_float6_e3m2_compile(self, device):
x = torch.randn(256, 64, device=device)

expected = to_tc_float6_e3m2(x)
actual = torch.compile(to_tc_float6_e3m2)(x)
torch.testing.assert_close(actual, expected)

@parametrize("device", _DEVICES)
def test_from_tc_float6_e3m2_correctness(self, device):
x = torch.randn(256, 64, device=device)
x = from_float6_e3m2(to_float6_e3m2(x)) # quantize and dequantize so that the values are exactly representable in FP6

actual = from_tc_float6_e3m2(to_tc_float6_e3m2(x), *x.shape)
torch.testing.assert_close(actual, x)

@parametrize("device", _DEVICES)
def test_from_tc_float6_e3m2_compile(self, device):
M, N = 256, 64
x = torch.randint(256, size=(M * N * 3 // 4,), dtype=torch.uint8, device=device)

expected = from_tc_float6_e3m2(x, M, N)
actual = torch.compile(from_tc_float6_e3m2)(x, M, N)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("leading_dims", [(4,), (2, 4)])
@parametrize("bias", [False, True])
def test_fp6_llm_linear_forward(self, bias, leading_dims):
OC, IC = 256, 64
device = "cuda"

linear = torch.nn.Linear(IC, OC, bias=bias, device=device)
fp6_linear = Fp6LlmLinear.from_float(linear)
assert (fp6_linear.bias is not None) == bias

x = torch.randn(*leading_dims, IC, device=device, dtype=torch.half)
fp6_linear(x)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("bias", [False, True])
def test_fp6_llm_linear_compile(self, bias):
N, OC, IC = 4, 256, 64
device = "cuda"

linear = torch.nn.Linear(IC, OC, bias=bias, device=device)
fp6_linear = Fp6LlmLinear.from_float(linear)

x = torch.randn(N, IC, device=device, dtype=torch.half)
expected = fp6_linear(x)
actual = torch.compile(fp6_linear)(x)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_convert_fp6_llm(self):
device = "cuda"
model = nn.Sequential(nn.Linear(64, 256, bias=False), nn.Linear(256, 256)).to(device)
convert_fp6_llm(model)

assert isinstance(model[0], Fp6LlmLinear)
assert model[0].bias is None
assert isinstance(model[1], Fp6LlmLinear)
assert model[1].bias is not None

x = torch.randn(4, 64, device=device)
model(x)


instantiate_parametrized_tests(TestFp6LlmLinear)


if __name__ == "__main__":
run_tests()
4 changes: 2 additions & 2 deletions torchao/csrc/cuda/fp6_llm/fp6_linear.cu
@@ -144,8 +144,8 @@ torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats,
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
int num_out_channels = _weights.size(0);
-    assert( num_in_channels%64 == 0 );
-    assert( (num_in_channels/16*3) == _weights.size(1) ); // Making sure the K dimension is matched.
+    TORCH_CHECK(num_in_channels%64 == 0, "Expected in_features to be a multiple of 64, but received ", num_in_channels);
+    TORCH_CHECK((num_in_channels/16*3) == _weights.size(1)); // Making sure the K dimension is matched.
//
int M = num_out_channels;
int K = num_in_channels;
160 changes: 160 additions & 0 deletions torchao/quantization/fp6_llm.py
@@ -0,0 +1,160 @@
from typing import Optional

import torch
from torch import nn, Tensor
from torchao.dtypes.float6_e3m2 import FLOAT6_E3M2_MAX, to_float6_e3m2, from_float6_e3m2
from torchao.ops import fp16act_fp6weight_linear


def _pack_2bit(x: Tensor) -> Tensor:
return (x[..., ::4] << 6) | (x[..., 1::4] << 4) | (x[..., 2::4] << 2) | x[..., 3::4]


def _unpack_2bit(x: Tensor) -> Tensor:
return torch.stack([x >> 6, (x >> 4) & 0b11, (x >> 2) & 0b11, x & 0b11], dim=-1).flatten(-2)


def _pack_4bit(x: Tensor) -> Tensor:
return (x[..., ::2] << 4) | x[..., 1::2]


def _unpack_4bit(x: Tensor) -> Tensor:
return torch.stack([x >> 4, x & 0b1111], dim=-1).flatten(-2)
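
# Illustrative note (editor's sketch, not part of the diff): each byte stores four
# 2-bit values or two 4-bit values, with the first element in the most significant
# bits. For example:
#   vals = torch.tensor([0b11, 0b00, 0b01, 0b10], dtype=torch.uint8)
#   _pack_2bit(vals)            # tensor([0b11000110])
#   _unpack_2bit(_pack_2bit(vals))  # recovers [0b11, 0b00, 0b01, 0b10]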


# this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing
# https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h
def _to_tc_float6_e3m2_original(tensor: Tensor) -> Tensor:
assert tensor.ndim == 2
M, N = tensor.shape
assert (M % 64 == 0) and (N % 64 == 0)

tensor_fp6 = to_float6_e3m2(tensor, no_bit_packing=True)

# Pass 1 from original code
tensor_fp6 = tensor_fp6.view(M // 64, 4, 2, 8, N // 16, 2, 8)
tensor_fp6 = tensor_fp6.permute(0, 4, 1, 5, 2, 3, 6)
tensor_fp6 = tensor_fp6.reshape(-1, 32, 2)
tensor_fp6 = tensor_fp6.permute(1, 0, 2)
tensor_fp6 = tensor_fp6.flatten()

tensor_2bit = _pack_2bit((tensor_fp6 >> 4) & 0b11)
tensor_4bit = _pack_4bit(tensor_fp6 & 0b1111)

# Pass 2 from original code
tensor_2bit = tensor_2bit.view(32, -1, 4).permute(1, 0, 2).flip(2)
tensor_4bit = tensor_4bit.view(32, -1, 4).permute(1, 0, 2).flip(2)

# Pass 3 from original code
# BitInterleaving_2bit
# the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32
# while we still unpack/pack the values from/to uint8
tensor_2bit = _unpack_2bit(tensor_2bit).view(-1, 16)
tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]]
tensor_2bit = tensor_2bit[:, [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14]]
tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]]
tensor_2bit = _pack_2bit(tensor_2bit).view(-1)

# BitInterleaving_4bit
# the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32
# while we still unpack/pack the values from/to uint8
tensor_4bit = _unpack_4bit(tensor_4bit).view(-1, 8)
tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]]
tensor_4bit = tensor_4bit[:, [1, 5, 3, 7, 0, 4, 2, 6]]
tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]]
tensor_4bit = _pack_4bit(tensor_4bit).view(-1)

return torch.cat([tensor_2bit, tensor_4bit], dim=0)


# more optimized version of _to_tc_float6_e3m2_original() by merging ops
# https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h
def to_tc_float6_e3m2(tensor: Tensor) -> Tensor:
assert tensor.ndim == 2
M, N = tensor.shape
assert (M % 64 == 0) and (N % 64 == 0)

tensor_fp6 = to_float6_e3m2(tensor, no_bit_packing=True)
tensor_fp6 = tensor_fp6.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8)
tensor_fp6 = tensor_fp6.flip(3)

tensor_2bit = (tensor_fp6 >> 4) & 0b11
tensor_2bit = tensor_2bit.permute(0, 5, 1, 4, 7, 3, 2, 6)
tensor_2bit = _pack_2bit(tensor_2bit.flatten())

tensor_4bit = tensor_fp6 & 0b1111
tensor_4bit = tensor_4bit.permute(0, 5, 1, 2, 4, 7, 3, 6)
tensor_4bit = _pack_4bit(tensor_4bit.flatten())

return torch.cat([tensor_2bit, tensor_4bit], dim=0)


def from_tc_float6_e3m2(tensor: Tensor, M: int, N: int, dtype: torch.dtype = torch.float32) -> Tensor:
assert tensor.ndim == 1
assert (M % 64 == 0) and (N % 64 == 0)
size_2bit = M * N // 4
size_4bit = M * N // 2
assert tensor.numel() == size_2bit + size_4bit

tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit])

tensor_2bit = _unpack_2bit(tensor_2bit)
tensor_2bit = tensor_2bit.view(M // 64, N // 16, 2, 8, 8, 2, 2, 2)
tensor_2bit = tensor_2bit.permute(0, 2, 6, 5, 3, 1, 7, 4)

tensor_4bit = _unpack_4bit(tensor_4bit)
tensor_4bit = tensor_4bit.view(M // 64, N // 16, 2, 2, 8, 8, 2, 2)
tensor_4bit = tensor_4bit.permute(0, 2, 3, 6, 4, 1, 7, 5)

tensor_fp6 = (tensor_2bit << 4) | tensor_4bit
tensor_fp6 = tensor_fp6.flip(3).reshape(M, N)
return from_float6_e3m2(tensor_fp6, no_bit_packing=True, dtype=dtype)
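
# Round-trip sketch (editor's note, mirrors test_from_tc_float6_e3m2_correctness):
# for an FP6-representable x, from_tc_float6_e3m2(to_tc_float6_e3m2(x), *x.shape)
# reproduces x, e.g.:
#   x = from_float6_e3m2(to_float6_e3m2(torch.randn(256, 64)))
#   torch.testing.assert_close(from_tc_float6_e3m2(to_tc_float6_e3m2(x), *x.shape), x)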


class Fp6LlmLinear(nn.Module):
"""FP6-LLM Linear layer as described in https://arxiv.org/pdf/2401.14112.
"""

def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None) -> None:
super().__init__()
self.register_buffer("weight", weight)
self.register_buffer("scales", scales)
self.register_buffer("bias", bias)
self.out_features = weight.shape[0]
self.in_features = weight.shape[1] * 16 // 3

def forward(self, x: Tensor) -> Tensor:
# TODO: splitK map
out = fp16act_fp6weight_linear(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=1)
if self.bias is not None:
out = out + self.bias
return out.view(*x.shape[:-1], self.out_features).to(x.dtype)

@classmethod
def from_float(cls, linear: nn.Linear):
assert (linear.in_features % 64 == 0) and (linear.out_features % 256 == 0)

fp32_weight = linear.weight.detach().float()
scales = fp32_weight.abs().amax(1) / FLOAT6_E3M2_MAX
scales[scales == 0.0] = 1.0 # avoid 0 scale

tc_fp6_weight = to_tc_float6_e3m2(fp32_weight / scales.view(-1, 1))
tc_fp6_weight = tc_fp6_weight.view(linear.out_features, -1).view(torch.int32)

bias = linear.bias.detach().half() if linear.bias is not None else None
return cls(tc_fp6_weight, scales.half(), bias)

def extra_repr(self) -> str:
return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'


def convert_fp6_llm(model: nn.Module, skip_fqn_list: Optional[list[str]] = None, cur_fqn: str = "") -> None:
for name, child in model.named_children():
new_fqn = name if cur_fqn == "" else f"{cur_fqn}.{name}"

if ((skip_fqn_list is None) or (new_fqn not in skip_fqn_list)) and (isinstance(child, nn.Linear)):
if (child.in_features % 64 == 0) and (child.out_features % 256 == 0):
new_child = Fp6LlmLinear.from_float(child)
setattr(model, name, new_child)
else:
convert_fp6_llm(child, skip_fqn_list, new_fqn)
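
For reference, a minimal usage sketch of the user API added in this PR, mirroring test_convert_fp6_llm above (illustrative only, not part of the diff; assumes a CUDA device):

import torch
from torch import nn
from torchao.quantization.fp6_llm import convert_fp6_llm, Fp6LlmLinear

# in_features must be a multiple of 64 and out_features a multiple of 256;
# other nn.Linear modules are left untouched by convert_fp6_llm
model = nn.Sequential(nn.Linear(64, 256, bias=False), nn.Linear(256, 256)).cuda()
convert_fp6_llm(model)  # swaps eligible nn.Linear modules for Fp6LlmLinear in place
assert isinstance(model[0], Fp6LlmLinear)

x = torch.randn(4, 64, device="cuda")
y = model(x)  # FP16-activation x FP6-weight linear under the hood, output cast back to x.dtype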