Skip to content

Commit

Permalink
Remove unnecessary `.transpose` in FP8 linear backward
Browse files — browse the repository at this point in the history
  • Loading branch information
xrsrke committed Nov 1, 2024
1 parent 39a4960 commit 1ddc44c
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/nanotron/fp8/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,14 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
# grad_weight = grad_weight.T.contiguous()
# orig_shape = grad_weight.shape
# grad_weight = grad_weight.contiguous().t().contiguous().view(-1).contiguous().reshape(orig_shape)
grad_weight = grad_weight.T
orig_shape = grad_weight.shape
grad_weight = grad_weight.t().view(-1).reshape(orig_shape)

# grad_weight = grad_weight.T
# orig_shape = grad_weight.shape
# grad_weight = grad_weight.t().view(-1).reshape(orig_shape)

# NOTE: works
# grad_weight = grad_weight.reshape(grad_weight.T.shape)
grad_weight = grad_weight.reshape(grad_weight.shape[::-1])

# NOTE: if use gradient accumulation, then directly keep the high precision weights for later accumulate
if constants.CONFIG is not None and (
Expand Down

0 comments on commit 1ddc44c

Please sign in to comment.