From 39a496095b7d580ec68376a0fd6db733d49927a9 Mon Sep 17 00:00:00 2001
From: Phuc Nguyen
Date: Fri, 1 Nov 2024 13:50:39 +0000
Subject: [PATCH] remove unnecessary .contiguous() in fp8 backward

---
 src/nanotron/fp8/linear.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/nanotron/fp8/linear.py b/src/nanotron/fp8/linear.py
index 011dbe3e..5037d697 100644
--- a/src/nanotron/fp8/linear.py
+++ b/src/nanotron/fp8/linear.py
@@ -248,9 +248,12 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
         assert grad_weight.dtype == recipe.accum_dtype
 
         # TODO(xrsrke): maintain a persistence metadata across training
-        grad_weight = grad_weight.T.contiguous()
+        # grad_weight = grad_weight.T.contiguous()
+        # orig_shape = grad_weight.shape
+        # grad_weight = grad_weight.contiguous().t().contiguous().view(-1).contiguous().reshape(orig_shape)
+        grad_weight = grad_weight.T
         orig_shape = grad_weight.shape
-        grad_weight = grad_weight.contiguous().t().contiguous().view(-1).contiguous().reshape(orig_shape)
+        grad_weight = grad_weight.t().view(-1).reshape(orig_shape)
 
         # NOTE: if use gradient accumulation, then directly keep the high precision weights for later accumulate
         if constants.CONFIG is not None and (
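
Note: a minimal standalone sketch (not part of the patch; toy shape and names are illustrative) of why the dropped .contiguous() calls are safe, assuming the accumulated fp32 weight gradient is already contiguous. Transposing a contiguous tensor and transposing it back lands on the original contiguous layout, so .view(-1) succeeds without intermediate copies and both chains produce identical values:

import torch

# Toy stand-in for the accumulated weight gradient; assumed contiguous.
grad_weight = torch.randn(4, 8)

# Old path: each .contiguous() on a non-contiguous view allocates and copies.
t_old = grad_weight.T.contiguous()
orig_shape = t_old.shape
old = t_old.contiguous().t().contiguous().view(-1).contiguous().reshape(orig_shape)

# New path from the patch: plain views until the final reshape.
t_new = grad_weight.T
orig_shape = t_new.shape
new = t_new.t().view(-1).reshape(orig_shape)

torch.testing.assert_close(old, new)  # same values, fewer copies

If the incoming gradient were not contiguous, .t().view(-1) would raise, so the simplification relies on that layout assumption.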