fix: using te and fsdp leads to multiple device found error (#1453)
kshitij12345 authored Nov 19, 2024
1 parent a617503 commit f206afa
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions thunder/benchmarks/benchmark_litgpt.py
@@ -114,11 +114,13 @@ def _resursively_swap_linear_layers_for_te(module: torch.nn.Module) -> None:

             if isinstance(m, torch.nn.Linear):
                 has_bias = m.bias is not None
-                new_linear = te.Linear(m.in_features, m.out_features, bias=has_bias, device=device)
+                # Pass device as str (as there is a bug in TransformerEngine's handling of torch.device)
+                new_linear = te.Linear(m.in_features, m.out_features, bias=has_bias, device=str(device))
                 setattr(module, n, new_linear)

             if swap_layernorm and isinstance(m, torch.nn.LayerNorm):
-                new_layernorm = te.LayerNorm(m.normalized_shape[0], eps=m.eps, device=device)
+                # Pass device as str (as there is a bug in TransformerEngine's handling of torch.device)
+                new_layernorm = te.LayerNorm(m.normalized_shape[0], eps=m.eps, device=str(device))
                 setattr(module, n, new_layernorm)

     initial_params_cnt = parameters_cnt(model)
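
As an aside, a minimal standalone sketch of the workaround this hunk applies. It is not code from the repository; it assumes transformer_engine.pytorch is importable as te (as in the benchmark) and that a CUDA device is available:

    import torch
    import transformer_engine.pytorch as te  # assumed import alias, matching the benchmark

    device = torch.device("cuda", 0)

    # Per the comment added in this hunk, TransformerEngine mishandles torch.device
    # objects here, so the device is stringified ("cuda:0") before building TE modules.
    linear = te.Linear(4096, 4096, bias=True, device=str(device))
    layernorm = te.LayerNorm(4096, eps=1e-5, device=str(device))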
@@ -366,11 +368,6 @@ def __init__(
         self.model = self.init_model()
         print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")

-        if self.use_te_fp8_autocast:
-            is_wo_layernorm = self.low_precision_mode == "fp8-delayed-te-wo_layernorm"
-            swap_linear_layers_for_te(self.model, device, swap_layernorm=not is_wo_layernorm)
-            self.model.to(torch.bfloat16)
-
         # Setup the distributed algorithm choices
         if distributed_first := (self.compile in ("eager", "inductor") or "dynamo" in self.compile):
             self.model = self.setup_distributed(self.model)
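
For orientation, a hypothetical condensation of the constructor flow after this removal. The method names mirror the diff, but the stub bodies below are stand-ins, not the real benchmark_litgpt.py implementations:

    import torch

    class BenchmarkSketch:
        # Hypothetical stand-in for the benchmark class; only the call ordering matters here.

        def init_model(self) -> torch.nn.Module:
            # In the real file this now also performs the torchao / TransformerEngine
            # layer swapping before the bfloat16 cast (see the next hunk).
            return torch.nn.Linear(8, 8)

        def setup_distributed(self, model: torch.nn.Module) -> torch.nn.Module:
            # Stand-in for the FSDP / DDP wrapping done by the real method.
            return model

        def __init__(self):
            self.model = self.init_model()
            # The standalone TE-swap block removed above used to run here; swapping
            # now happens inside init_model(), before any distributed wrapping.
            self.model = self.setup_distributed(self.model)

    bench = BenchmarkSketch()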
@@ -407,8 +404,14 @@ def init_model(self):
         init_device = torch.device("meta") if self.distributed_mode in FSDP_MODES else self.device
         with init_device:
             model = GPT(self.config)
-        model.to(dtype=torch.bfloat16)
+
+        # Handle fp8 related Linear layer swapping (for torchao or TransformerEngine)
         model = self._torchao_fp8_handler.convert_model_to_fp8(model)
+        if self.use_te_fp8_autocast:
+            is_wo_layernorm = self.low_precision_mode == "fp8-delayed-te-wo_layernorm"
+            swap_linear_layers_for_te(model, init_device, swap_layernorm=not is_wo_layernorm)
+
+        model.to(dtype=torch.bfloat16)
         return model

     def setup_distributed(self, model):
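Finally, a standalone sketch of the ordering this last hunk establishes, with a toy module in place of GPT and the torchao/TransformerEngine calls skipped: instantiate on the init device, swap layers while the parameters still live there, and only then cast to bfloat16:

    import torch

    def init_model_sketch(use_meta_init: bool) -> torch.nn.Module:
        # Mirrors init_model(): under FSDP the model is first materialized on the
        # meta device; a toy MLP stands in for GPT here.
        init_device = torch.device("meta") if use_meta_init else torch.device("cpu")
        with init_device:
            model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.LayerNorm(16))

        # The torchao / TransformerEngine layer swapping would run at this point,
        # while every parameter is still on init_device, so the swapped modules are
        # created on the same device as the rest of the model.

        # Only afterwards is the whole model cast to bfloat16, matching where
        # model.to(dtype=torch.bfloat16) now sits in the diff.
        model.to(dtype=torch.bfloat16)
        return model

    model = init_model_sketch(use_meta_init=True)
    print(next(model.parameters()).device, next(model.parameters()).dtype)  # meta torch.bfloat16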
