diff --git a/src/nanotron/fp8/linear.py b/src/nanotron/fp8/linear.py
index 011dbe3e..5037d697 100644
--- a/src/nanotron/fp8/linear.py
+++ b/src/nanotron/fp8/linear.py
@@ -248,9 +248,12 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
         assert grad_weight.dtype == recipe.accum_dtype
 
         # TODO(xrsrke): maintain a persistence metadata across training
-        grad_weight = grad_weight.T.contiguous()
+        # grad_weight = grad_weight.T.contiguous()
+        # orig_shape = grad_weight.shape
+        # grad_weight = grad_weight.contiguous().t().contiguous().view(-1).contiguous().reshape(orig_shape)
+        grad_weight = grad_weight.T
         orig_shape = grad_weight.shape
-        grad_weight = grad_weight.contiguous().t().contiguous().view(-1).contiguous().reshape(orig_shape)
+        grad_weight = grad_weight.t().view(-1).reshape(orig_shape)
 
         # NOTE: if use gradient accumulation, then directly keep the high precision weights for later accumulate
         if constants.CONFIG is not None and (
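The diff drops the explicit `.contiguous()` copies around the transpose of `grad_weight`. A minimal standalone sketch (not part of nanotron; `grad_weight` here is a hypothetical stand-in gradient, assumed contiguous as it typically is coming out of a matmul) showing that the shorter view-only chain yields the same values: `.T.t()` recovers the original row-major layout, so `.view(-1)` is already legal and no intermediate copies are needed.

```python
import torch

# Stand-in for a contiguous high-precision weight gradient.
grad_weight = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# Old path: transpose plus repeated explicit copies via .contiguous().
old = grad_weight.T.contiguous()
old_shape = old.shape
old = old.contiguous().t().contiguous().view(-1).contiguous().reshape(old_shape)

# New path: plain views only, no explicit copies.
new = grad_weight.T
new_shape = new.shape
new = new.t().view(-1).reshape(new_shape)

assert torch.equal(old, new)                     # same values either way
assert new.data_ptr() == grad_weight.data_ptr()  # new path stays a view of the original storage
```

Under this assumption the change is value-preserving while avoiding extra allocations in the backward pass; whether the resulting (non-copied) memory layout matters downstream depends on what consumes `grad_weight` afterwards.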