diff --git a/src/nanotron/fp8/linear.py b/src/nanotron/fp8/linear.py
index 5037d697..2e06f1b6 100644
--- a/src/nanotron/fp8/linear.py
+++ b/src/nanotron/fp8/linear.py
@@ -251,9 +251,14 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
         # grad_weight = grad_weight.T.contiguous()
         # orig_shape = grad_weight.shape
         # grad_weight = grad_weight.contiguous().t().contiguous().view(-1).contiguous().reshape(orig_shape)
-        grad_weight = grad_weight.T
-        orig_shape = grad_weight.shape
-        grad_weight = grad_weight.t().view(-1).reshape(orig_shape)
+
+        # grad_weight = grad_weight.T
+        # orig_shape = grad_weight.shape
+        # grad_weight = grad_weight.t().view(-1).reshape(orig_shape)
+
+        # NOTE: works
+        # grad_weight = grad_weight.reshape(grad_weight.T.shape)
+        grad_weight = grad_weight.reshape(grad_weight.shape[::-1])
 
         # NOTE: if use gradient accumulation, then directly keep the high precision weights for later accumulate
         if constants.CONFIG is not None and (
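For context on why the added one-liner behaves the same as the three-line chain it replaces: for a contiguous 2D tensor, `reshape(grad_weight.shape[::-1])` only reinterprets the row-major buffer with the reversed shape, which is exactly what the removed transpose / transpose-back / flatten / reshape sequence computed, and it is not a true transpose of the data. The following is a minimal standalone sketch (not part of the patch, using a toy tensor rather than the real `grad_weight`) that checks this equivalence:

```python
import torch

# Toy stand-in for grad_weight; any contiguous 2D tensor works.
grad_weight = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# Removed approach: transpose, remember the shape, transpose back, flatten, reshape.
transposed = grad_weight.T
orig_shape = transposed.shape  # (3, 2)
old_result = transposed.t().view(-1).reshape(orig_shape)

# New approach from the patch: reinterpret the same row-major buffer with the reversed shape.
new_result = grad_weight.reshape(grad_weight.shape[::-1])

assert torch.equal(old_result, new_result)         # same metadata-only reinterpretation
assert not torch.equal(new_result, grad_weight.T)  # not a true transpose of the elements
```

The commented-out alternative `grad_weight.reshape(grad_weight.T.shape)` gives the same result, since `.T.shape` equals `shape[::-1]` for a 2D tensor; the merged line just avoids materializing the transposed view to read its shape.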