Skip to content

Commit

Permalink
remove unnecessary transpose input in the fwd pass, and contiguous weig…
Browse files Browse the repository at this point in the history
…ht initialization
  • Loading branch information
xrsrke committed Nov 1, 2024
1 parent 1ddc44c commit c827594
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions src/nanotron/fp8/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def __init__(
# TODO(xrsrke): don't fixed dtype, take it from the FP8 recipe
# DTypes.FP8E4M3
weight_data = self.weight.data
orig_w_shape = weight_data.shape
weight_data = weight_data.contiguous().view(-1).contiguous().reshape(orig_w_shape)
# orig_w_shape = weight_data.shape
# weight_data = weight_data.contiguous().view(-1).contiguous().reshape(orig_w_shape)
quant_w = FP8Parameter(weight_data, dtype=recipe.weight.dtype, interval=recipe.weight.interval)
assert quant_w.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}"
self.weight = quant_w
Expand Down Expand Up @@ -111,8 +111,10 @@ def forward(

sync_amax_in_input = fp8_config.sync_amax_in_input

orig_input_shape = input.shape
input = input.contiguous().view(-1).contiguous().view(orig_input_shape)
# orig_input_shape = input.shape
# input = input.contiguous().view(-1).contiguous().view(orig_input_shape)

# input = input.contiguous()

if metadatas.input is None:
fp8_input = FP8Tensor(
Expand All @@ -127,12 +129,13 @@ def forward(
ctx.name = name
ctx.recipe = recipe

accum_output = output.contiguous()
# accum_output = output.contiguous()
accum_output = output
# accum_output = torch.zeros(output.shape, dtype=torch.float16, device="cuda")

assert fp8_input.data.is_contiguous()
assert weight.data.is_contiguous()
assert accum_output.is_contiguous()
# assert fp8_input.data.is_contiguous()
# assert weight.data.is_contiguous()
# assert accum_output.is_contiguous()

# dist.monitored_barrier(wait_all_ranks=True)

Expand Down Expand Up @@ -296,4 +299,5 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
# NOTE: sanity check
assert isinstance(fp8_weight.grad, FP8Tensor)

return grad_input.contiguous(), None, None, None, None, None, None
# return grad_input.contiguous(), None, None, None, None, None, None
return grad_input, None, None, None, None, None, None

0 comments on commit c827594

Please sign in to comment.