From 1ddc44c6ac8645ac051db34c30107b39821f8da8 Mon Sep 17 00:00:00 2001
From: Phuc Nguyen
Date: Fri, 1 Nov 2024 13:56:30 +0000
Subject: [PATCH] remove unnecessary .transpose in fp8 linear backward

---
 src/nanotron/fp8/linear.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/src/nanotron/fp8/linear.py b/src/nanotron/fp8/linear.py
index 5037d697..2e06f1b6 100644
--- a/src/nanotron/fp8/linear.py
+++ b/src/nanotron/fp8/linear.py
@@ -251,9 +251,14 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
         # grad_weight = grad_weight.T.contiguous()
         # orig_shape = grad_weight.shape
         # grad_weight = grad_weight.contiguous().t().contiguous().view(-1).contiguous().reshape(orig_shape)
-        grad_weight = grad_weight.T
-        orig_shape = grad_weight.shape
-        grad_weight = grad_weight.t().view(-1).reshape(orig_shape)
+
+        # grad_weight = grad_weight.T
+        # orig_shape = grad_weight.shape
+        # grad_weight = grad_weight.t().view(-1).reshape(orig_shape)
+
+        # NOTE: works
+        # grad_weight = grad_weight.reshape(grad_weight.T.shape)
+        grad_weight = grad_weight.reshape(grad_weight.shape[::-1])
 
         # NOTE: if use gradient accumulation, then directly keep the high precision weights for later accumulate
         if constants.CONFIG is not None and (
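
Note (illustration, not part of the patch): a minimal sketch of why the new one-liner matches the removed three-line path. The removed code did .T, then .t() (which cancel each other), then flattened the contiguous buffer and read it back with the reversed shape; the new line does that reinterpretation directly. The tensor name and shapes below are made up, and only plain PyTorch is used, not nanotron's FP8 kernels.

    import torch

    # Stand-in for grad_weight as produced in backward: contiguous,
    # with shape (in_features, out_features) = (3, 4).
    grad_weight = torch.randn(3, 4)

    # Removed path: .T followed by .t() is a no-op pair, so the net effect is
    # "flatten the contiguous buffer, then view it with the reversed shape".
    tmp = grad_weight.T                          # (4, 3), only a stride swap
    orig_shape = tmp.shape                       # torch.Size([4, 3])
    old = tmp.t().view(-1).reshape(orig_shape)   # .t() undoes .T, then flatten + reshape

    # New path: reinterpret the same buffer with the reversed shape directly,
    # skipping the two redundant transposes. No data is moved or copied.
    new = grad_weight.reshape(grad_weight.shape[::-1])

    assert old.shape == new.shape == (4, 3)
    assert torch.equal(old, new)                      # same result, fewer ops
    assert new.data_ptr() == grad_weight.data_ptr()   # reshape of a contiguous tensor is a view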