Commit
remove transpose in kernel
xrsrke committed Nov 4, 2024
1 parent edb1e87 commit e93cf55
Showing 2 changed files with 1 addition and 24 deletions.
12 changes: 1 addition & 11 deletions src/nanotron/fp8/kernel.py
@@ -9,22 +9,18 @@
@torch.no_grad()
def fp8_matmul_kernel(
mat_a: FP8Tensor,
transpose_a: bool,
mat_b: FP8Tensor,
transpose_b: bool,
output,
use_split_accumulator: bool,
accumulate: bool,
accum_qtype: torch.dtype,
# TODO(xrsrke): remove this flag
is_backward: bool = False,
recipe=None,
) -> torch.Tensor:
# from nanotron.fp8.constants import _empty_tensor, workspace

assert (
mat_a.device != "cpu" and mat_b.device != "cpu"
), "The tensors must be on a CUDA device in order to use the FP8 kernel!!"
), "The tensors must be on a CUDA device in order to use FP8!!"
# assert isinstance(accum_qtype, DTypes)
assert isinstance(accum_qtype, torch.dtype)

@@ -44,9 +40,7 @@ def fp8_matmul_kernel(
raise ValueError(f"Unsupported accumulation dtype: {accum_qtype}")

_empty_tensor = torch.Tensor()

workspace = torch.empty(33_554_432, dtype=torch.int8, device=device)
# accumulate = False

# NOTE: currently TE don't support adding bias in FP8
# along with matmul, it only takes an empty bias
@@ -62,10 +56,6 @@
TE_CONFIG_TRANSPOSE_B = False
SCALE = AMAX = _empty_tensor

# if is_backward is False:
# mat_a = tex.fp8_transpose(mat_a, mat_a_fp8_meta.te_dtype) if transpose_a is False else mat_a
# mat_b = tex.fp8_transpose(mat_b, mat_b_fp8_meta.te_dtype) if transpose_b is True else mat_b

tex.te_gemm(
mat_a,
mat_a_fp8_meta.inverse_scale,
13 changes: 0 additions & 13 deletions src/nanotron/fp8/linear.py
@@ -123,14 +123,11 @@ def forward(
output = fp8_matmul_kernel(
# NOTE: that works
mat_a=weight,
transpose_a=True,
mat_b=fp8_input,
transpose_b=False,
output=accum_output,
use_split_accumulator=recipe.split_accumulator.output,
accumulate=recipe.accumulate.output,
accum_qtype=recipe.accum_dtype,
recipe=recipe,
)
return output, phony

@@ -167,11 +164,9 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
fp8_input, fp8_weight = ctx.saved_tensors
recipe = ctx.recipe
recipe = cast(FP8LinearRecipe, recipe)
# accum_qtype = ctx.accum_qtype

fp8_input = cast(FP8Tensor, fp8_input)
fp8_weight = cast(FP8Tensor, fp8_weight)
# grad_output = grad_output.contiguous()

ctx.metadatas = cast(FP8LinearMeta, ctx.metadatas)
if ctx.metadatas.input_grad is None:
@@ -187,7 +182,6 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[

if ctx.is_input_require_grad:
transposed_fp8_weight = fp8_weight.transpose_fp8()

grad_input_temp = torch.empty(
fp8_grad_output.shape[0],
transposed_fp8_weight.shape[0],
@@ -196,15 +190,11 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
)
grad_input = fp8_matmul_kernel(
mat_a=transposed_fp8_weight,
transpose_a=True,
mat_b=fp8_grad_output,
transpose_b=False,
output=grad_input_temp,
use_split_accumulator=recipe.split_accumulator.input_grad,
accum_qtype=recipe.accum_dtype,
accumulate=recipe.accumulate.input_grad,
# is_backward=True,
recipe=recipe,
)
grad_input.__debug_is_from_fp8 = True
else:
@@ -222,14 +212,11 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
)
grad_weight = fp8_matmul_kernel(
mat_a=transposed_fp8_input,
transpose_a=True,
mat_b=transposed_fp8_grad_output,
transpose_b=False,
output=grad_weight_temp,
use_split_accumulator=recipe.split_accumulator.weight_grad,
accumulate=recipe.accumulate.weight_grad,
accum_qtype=recipe.accum_dtype,
recipe=recipe,
)

if ctx.is_input_require_grad:
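For orientation, the following is a minimal sketch (not part of this commit) of the post-commit call pattern for the dgrad GEMM, inferred from the linear.py hunks above: the caller now materializes the transposed FP8 weight via FP8Tensor.transpose_fp8() instead of passing transpose_a/transpose_b flags into fp8_matmul_kernel. The import path, the helper name compute_grad_input, and the dtype/device arguments of the output buffer are assumptions, not code from the repository.

# Hypothetical sketch of the post-commit call pattern; names and arguments not
# visible in the diff above are assumptions.

import torch

from nanotron.fp8.kernel import fp8_matmul_kernel  # assumed import path

def compute_grad_input(fp8_weight, fp8_grad_output, recipe):
    # Pre-transpose the FP8 weight in the caller; the kernel no longer takes
    # transpose_a / transpose_b flags after this commit.
    transposed_fp8_weight = fp8_weight.transpose_fp8()

    # Accumulation buffer in the recipe's accumulation dtype (dtype/device
    # arguments assumed; they are collapsed in the hunk above).
    grad_input_temp = torch.empty(
        fp8_grad_output.shape[0],
        transposed_fp8_weight.shape[0],
        dtype=recipe.accum_dtype,
        device=fp8_grad_output.device,
    )

    # Same keyword arguments that remain at the grad_input call site.
    return fp8_matmul_kernel(
        mat_a=transposed_fp8_weight,
        mat_b=fp8_grad_output,
        output=grad_input_temp,
        use_split_accumulator=recipe.split_accumulator.input_grad,
        accumulate=recipe.accumulate.input_grad,
        accum_qtype=recipe.accum_dtype,
    )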
