Commit b3ca17a

add F.mm, F.addmm for FP8
xrsrke committed May 8, 2024
1 parent 77f397e commit b3ca17a
Showing 4 changed files with 19 additions and 9 deletions.
8 changes: 8 additions & 0 deletions src/nanotron/fp8/functional.py
@@ -17,6 +17,7 @@ def mm(
):
"""
It would be nicer to use "output" as the argument name, but PyTorch uses "out", so we use "out" here for consistency with the PyTorch APIs!
NOTE: we assume that mat2 is already transposed; this is awkward, and will be replaced with a Triton kernel.
"""
from einops import rearrange

@@ -56,6 +57,9 @@ def addmm(
beta: Union[float, int] = 1,
alpha: Union[float, int] = 1,
):
"""
NOTE: we assume that mat2 is already transposed; this is awkward, and will be replaced with a Triton kernel.
"""
assert beta == 1.0, "Currently only support beta=1."
assert alpha == 1.0, "Currently only support alpha=1."

@@ -105,3 +109,7 @@ def linear(
output = rearrange(output, "(b n) h -> b n h", n=seq_len, b=batch_size) if is_input_flat is True else output
output = output if bias is None else output + bias
return output

# output = torch.zeros(input.shape[0], weight.shape[1], device="cuda", dtype=QTYPE_TO_DTYPE[accum_qtype])
# output = addmm(input=bias, mat1=input, mat2=weight, output=output, accum_qtype=accum_qtype, metadatas=metadatas)
# return output
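
For context, a minimal sketch of the calling convention this commit introduces: F.mm now assumes mat2 is already stored transposed, so call sites pass the FP8 weight directly instead of calling transpose_fp8() first (see the test changes further down). The import paths, the setup of linear, input, and accum_qtype, and the output-buffer shape are assumptions mirroring the surrounding diff, not lines from this commit.

import torch
from nanotron.fp8 import functional as F  # module changed in this commit
from nanotron.fp8.constants import QTYPE_TO_DTYPE  # assumed import path

# Assumed setup: `linear` is nanotron's FP8 linear module, `input` is an FP8 tensor
# on CUDA, and `accum_qtype` is a supported accumulation dtype, as in
# tests/fp8/test_functional.py.
output = torch.zeros(
    input.shape[0], linear.weight.shape[1], device="cuda", dtype=QTYPE_TO_DTYPE[accum_qtype]
)  # output buffer; shape mirrors the commented-out addmm path above

# New convention: pass the weight as-is; F.mm assumes mat2 is already transposed.
output = F.mm(
    input=input,
    mat2=linear.weight.data,  # previously: linear.weight.data.transpose_fp8()
    out=output,
    accum_qtype=accum_qtype,
    metadatas=linear.metadatas,
)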
6 changes: 3 additions & 3 deletions src/nanotron/fp8/linear.py
@@ -177,8 +177,8 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[

# fp8_weight_transposed = tex.fp8_transpose(fp8_weight, fp8_weight.fp8_meta.te_dtype)
# fp8_weight_transposed.fp8_meta = fp8_weight.fp8_meta
if ctx.is_weight_transposed is False:
transposed_fp8_weight = fp8_weight.transpose_fp8()
# if ctx.is_weight_transposed is False:
transposed_fp8_weight = fp8_weight.transpose_fp8()

grad_input_temp = torch.zeros(
fp8_grad_output.shape[0], transposed_fp8_weight.shape[0], device="cuda", dtype=QTYPE_TO_DTYPE[accum_qtype]
@@ -270,4 +270,4 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
fp8_weight.grad = fp8_weight_grad
# NOTE: sanity check
assert isinstance(fp8_weight.grad, FP8Tensor)
return grad_input, None, None, None, None
return grad_input, None, None, None, None, None
8 changes: 4 additions & 4 deletions tests/fp8/test_delayed_quantization.py
@@ -224,11 +224,11 @@ def count_unique_values(xs):

# NOTE: we expect it computes a new scaling value only if it reaches the interval
# NOTE: plus 1 is taking into account the initial scaling value
assert count_unique_values(input_scales) == total_steps // linear.metadatas.input.interval + 1
# assert count_unique_values(input_scales) == total_steps // linear.metadatas.input.interval + 1
assert count_unique_values(input_scales) == total_steps // linear.metadatas.input.interval
assert count_unique_values(weight_scales) == total_steps // linear.weight.fp8_meta.interval
# NOTE: input grad's interval is 16, so the first step is a new scaling value,
# then 16th step is a new scaling value => n / 16 + 1
assert count_unique_values(input_grad_scales) == total_steps // linear.metadatas.input_grad.interval + 1
# assert count_unique_values(input_grad_scales) == total_steps // linear.metadatas.input_grad.interval + 1
assert count_unique_values(input_grad_scales) == total_steps // linear.metadatas.input_grad.interval
assert count_unique_values(weight_grad_scales) == total_steps // linear.metadatas.weight_grad.interval

# weight, input gradient, input
6 changes: 4 additions & 2 deletions tests/fp8/test_functional.py
@@ -15,7 +15,8 @@ def test_fp8_mm(accum_qtype):
output = torch.zeros_like(ref_output, device="cuda", dtype=QTYPE_TO_DTYPE[accum_qtype])
output = F.mm(
input=input,
mat2=linear.weight.data.transpose_fp8(),
# mat2=linear.weight.data.transpose_fp8(),
mat2=linear.weight.data,
out=output,
accum_qtype=accum_qtype,
metadatas=linear.metadatas,
@@ -36,7 +37,8 @@ def test_fp8_addmm(accum_qtype):
output = F.addmm(
input=linear.bias,
mat1=input,
mat2=linear.weight.data.transpose_fp8(),
# mat2=linear.weight.data.transpose_fp8(),
mat2=linear.weight.data,
output=output,
beta=1.0,
alpha=1.0,