From c07682797ffaad8fc4de54419cd7e50cf0511c12 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Sat, 3 Aug 2024 00:04:31 -0400 Subject: [PATCH 01/54] fix the bug of deepspeed sequence parallel working with batch size larger than 1 --- deepspeed/sequence/layer.py | 95 ++++++++++++++++++++++--------------- 1 file changed, 58 insertions(+), 37 deletions(-) diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index f17cfa883cc6..e2c845284541 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -12,48 +12,66 @@ from deepspeed.accelerator import get_accelerator -def post_all2all(transpose, res_shape): +def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim): def post_func(input): - if transpose: - input = input.transpose(0, 2).contiguous() - input = input.reshape(res_shape) - return input + if batch_dim_idx == 0: + # b, s, n, h + if scatter_idx < 2: + output = input.permute(1, 2, 0, 3, 4).contiguous() + output = output.reshape(bs, seq_len // seq_world_size, seq_world_size * num_head, head_dim).contiguous() + else: + output = input.permute(1, 0, 2, 3, 4).contiguous() + output = output.reshape(bs, seq_world_size * seq_len, num_head // seq_world_size, head_dim).contiguous() + else: + # s, b, n, h + if scatter_idx < 2: + output = input.permute(1, 2, 0, 3, 4).contiguous() + output = output.reshape(seq_len // seq_world_size, bs, seq_world_size * num_head, head_dim).contiguous() + else: + output = input.reshape(seq_len * seq_world_size, bs, num_head // seq_world_size, head_dim).contiguous() + return output return post_func - -def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False, handle=None, type=None): +def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None): seq_world_size = dist.get_world_size(group) - inp_shape = list(input.shape) - inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size + if batch_dim_idx == 0: + # b, s, n, h + if scatter_idx < 2: + bs, global_seq_len, num_local_head, head_dim = input.shape + input_t = input.reshape([bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim]).contiguous() + input_t = input_t.permute(1, 0, 2, 3, 4).contiguous() + else: + bs, local_seq_len, num_total_head, head_dim = input.shape + assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!" + input_t = input.reshape([bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim]).contiguous() + input_t = input_t.permute(2, 0, 1, 3, 4).contiguous() + else: + # s, b, n, h + if scatter_idx < 2: + global_seq_len, bs, num_local_head, head_dim = input.shape + input_t = input.reshape([seq_world_size, global_seq_len // seq_world_size, bs, num_local_head, head_dim]).contiguous() + else: + local_seq_len, bs, num_total_head, head_dim = input.shape + assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!" 
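For intuition, the reshape/permute bookkeeping introduced above can be checked in isolation; the following is a minimal single-process sketch for the batch_dim_idx == 0 ([b, s, n, h]) case with scatter_idx < 2 (scatter the sequence, gather the heads). The sizes are made up, and the real data exchange is performed by dist.all_to_all_single over the sequence-parallel group; that step is skipped here because it does not change the tensor shape, so only the shape algebra is verified.

import torch

# hypothetical sizes for illustration only
seq_world_size, bs, global_seq_len, num_local_head, head_dim = 4, 2, 16, 8, 32
x = torch.randn(bs, global_seq_len, num_local_head, head_dim)      # [b, s, n/p, h]

# pre-all2all: expose a leading "rank" dimension so each rank receives one sequence shard
t = x.reshape(bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim)
t = t.permute(1, 0, 2, 3, 4).contiguous()                          # [p, b, s/p, n/p, h]

# dist.all_to_all_single would exchange data along dim 0 here without changing the shape

# post-all2all: fold the rank dimension back into the head dimension
out = t.permute(1, 2, 0, 3, 4).contiguous()
out = out.reshape(bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim)
assert out.shape == (2, 4, 32, 32)                                 # [b, s/p, n, h]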
+ input_t = input.reshape([local_seq_len, bs, seq_world_size, num_total_head // seq_world_size, head_dim]).contiguous() + input_t = input_t.permute(2, 0, 1, 3, 4).contiguous() + if scatter_idx < 2: - input_t = input.reshape( - [seq_world_size, inp_shape[scatter_idx]] + \ - inp_shape[scatter_idx + 1:] - ).contiguous() + post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, global_seq_len, num_local_head, head_dim) else: - # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! - input_t = input.reshape( - [-1, seq_world_size, inp_shape[scatter_idx]] + \ - inp_shape[scatter_idx + 1:] - ).transpose(0, 1).contiguous() + post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, local_seq_len, num_total_head, head_dim) output = torch.empty_like(input_t) work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op) - res_shape=( inp_shape[: gather_idx] + \ - [inp_shape[gather_idx] * seq_world_size,] + \ - inp_shape[gather_idx + 1:]) - transpose = True if scatter_idx < 2 else False - post_all2all_fun = post_all2all(transpose, res_shape) - if async_op: if type in ('dq', 'dk'): handle[type + '_work'] = work handle[type + '_grad'] = output handle[type + '_post_all2all_func'] = post_all2all_fun - return output.view(res_shape) + return output res = post_all2all_fun(output) return res @@ -67,6 +85,7 @@ def forward(ctx: Any, input: Tensor, scatter_idx: int, gather_idx: int, + batch_dim_idx: int, stream=None, handle=None, type=None, @@ -77,14 +96,15 @@ def forward(ctx: Any, ctx.stream = stream ctx.handle = handle ctx.type = type + ctx.batch_dim_idx = batch_dim_idx if ctx.handle is None: - res = single_all_to_all(input, scatter_idx, gather_idx, group, False) + res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False) else: # overlap communication path if not is_fwd and type == 'o': assert ctx.stream != None - res = single_all_to_all(input, scatter_idx, gather_idx, group, False) + res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False) get_accelerator().current_stream().wait_stream(ctx.stream) del ctx.stream.activation_buffer_list # The computation of d o_weight can overlap with the communication of d o_input @@ -92,15 +112,15 @@ def forward(ctx: Any, elif not is_fwd and type in ('q', 'k'): # Achieve communication overlap by pipelining the matrix computation and communication of dq, dk, and dv type = 'd' + type - res = single_all_to_all(input, scatter_idx, gather_idx, group, True, handle, type) + res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, True, handle, type) elif is_fwd and type in ('q', 'k'): # Achieve communication overlap by pipelining the matrix computation and communication of q, k, and v type = 'fwd_' + type - res = single_all_to_all(input, scatter_idx, gather_idx, group, False, handle, type) + res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False, handle, type) else: - res = single_all_to_all(input, scatter_idx, gather_idx, group, False) + res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False) return res @@ -108,8 +128,8 @@ def forward(ctx: Any, def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: return (None, - _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.stream, ctx.handle, - ctx.type, False), None, None, None, None, None, None) + _SeqAllToAll.apply(ctx.group, *grad_output, 
ctx.gather_idx, ctx.scatter_idx, ctx.batch_dim_idx, ctx.stream, ctx.handle, + ctx.type, False), None, None, None, None, None, None, None) class DistributedAttention(torch.nn.Module): @@ -148,13 +168,14 @@ def layer_sync(self, layer): if self.sp_overlap_comm and hasattr(layer, 'done_event'): self.dafult_stream.wait_event(layer.done_event) - def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor: + def forward(self, query: Tensor, key: Tensor, value: Tensor, batch_dim_idx: int, *args: Any, **kwargs) -> Tensor: """ forward Arguments: query (Tensor): query input to the layer key (Tensor): key input to the layer value (Tensor): value input to the layer + batch_dim_idx (int): indicating which dim is batch args: other args Returns: @@ -179,15 +200,15 @@ def pre_hook_fun(grad): return pre_hook_fun self.layer_sync(query) - query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx, None, + query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx, batch_dim_idx, None, self.overlap_handles, 'q') self.layer_sync(key) - key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, None, self.overlap_handles, + key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, batch_dim_idx, None, self.overlap_handles, 'k') if self.sp_overlap_comm: self.dafult_stream.wait_stream(self.sp_stream) - value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, None, + value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, batch_dim_idx, None, self.overlap_handles, 'v') if self.sp_overlap_comm: @@ -205,7 +226,7 @@ def pre_hook_fun(grad): context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs) - output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, self.sp_stream, + output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, batch_dim_idx, self.sp_stream, self.overlap_handles, 'o') #out e.g., [s/p::h] From ed34e89ff5b4d79e69f9d557822135c985143539 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Wed, 7 Aug 2024 14:51:20 -0400 Subject: [PATCH 02/54] apply yapf formatting --- deepspeed/__init__.py | 4 +- deepspeed/autotuning/autotuner.py | 12 +++--- deepspeed/elasticity/elastic_agent.py | 4 +- deepspeed/module_inject/replace_module.py | 7 ++-- deepspeed/runtime/config.py | 4 +- deepspeed/runtime/eigenvalue.py | 4 +- deepspeed/runtime/pipe/engine.py | 7 ++-- deepspeed/runtime/utils.py | 4 +- deepspeed/sequence/layer.py | 40 ++++++++++++-------- tests/unit/runtime/zero/test_zero_context.py | 6 +-- 10 files changed, 52 insertions(+), 40 deletions(-) diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index fe0043547860..b9c44e076252 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -152,8 +152,8 @@ def initialize(args=None, if hasattr(args, "deepscale_config") and args.deepscale_config is not None: logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************") if hasattr(args, "deepspeed_config"): - assert (args.deepspeed_config is - None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config" + assert (args.deepspeed_config + is None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config" args.deepspeed_config = args.deepscale_config args.deepscale_config = None diff --git a/deepspeed/autotuning/autotuner.py 
b/deepspeed/autotuning/autotuner.py index dfd195bc37eb..a72b3c951e97 100755 --- a/deepspeed/autotuning/autotuner.py +++ b/deepspeed/autotuning/autotuner.py @@ -248,8 +248,8 @@ def mp_size(self): return self.autotuning_config.mp_size def max_train_micro_batch_size_per_gpu(self): - if self.max_train_batch_size( - ) and self.max_train_batch_size() > 0: # if the user specifies a max_train_batch_size + if self.max_train_batch_size() and self.max_train_batch_size( + ) > 0: # if the user specifies a max_train_batch_size max_train_micro_batch_size = self.max_train_batch_size() * self.mp_size() // ( self.exp_num_gpus * self.exp_num_nodes) # gradient accumulation steps >=1 return min(self.autotuning_config.max_train_micro_batch_size_per_gpu, max_train_micro_batch_size) @@ -964,8 +964,8 @@ def get_min_max_micro_batch_size(self, stage, min_micro_batch_size, calculated_m low = mid + 1 self.update_records(tuning_space_name, exp, metric_val, 1) used_micro_batch_sizes.append(mid) - if prev_metric_val and ( - (metric_val - prev_metric_val) / prev_metric_val) < METRIC_PERCENT_DIFF_CONST: + if prev_metric_val and ((metric_val - prev_metric_val) / + prev_metric_val) < METRIC_PERCENT_DIFF_CONST: logger.info(f"performance plateaus at mbs = {low}") break prev_metric_val = metric_val @@ -1026,8 +1026,8 @@ def get_tuning_micro_batch_size_list(self, min_micro_batch_size, max_micro_batch # NUM_GPUS=$(( ${NUM_WORKERS} * ${NUM_GPUS_PER_WORKER} )) # DP_SIZE=$(( ${NUM_GPUS} / (${PP_SIZE} * ${MP_SIZE}) )) # GRAD_ACC_STEPS=$(( ${TARGET_GLOBAL_BATCH_SIZE} / (${BATCH_SIZE} * ${DP_SIZE}) )) - if self.max_train_batch_size( - ) and self.max_train_batch_size() > 0: # if the user specifies a max_train_batch_size + if self.max_train_batch_size() and self.max_train_batch_size( + ) > 0: # if the user specifies a max_train_batch_size max_train_batch_size_per_gpu = self.max_train_batch_size() * self.mp_size() // (self.exp_num_gpus * self.exp_num_nodes) else: diff --git a/deepspeed/elasticity/elastic_agent.py b/deepspeed/elasticity/elastic_agent.py index 039b999dfeca..0a8cf09db8a4 100644 --- a/deepspeed/elasticity/elastic_agent.py +++ b/deepspeed/elasticity/elastic_agent.py @@ -158,8 +158,8 @@ def _invoke_run(self, role: str = "default") -> RunResult: f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish.") self._exit_barrier() return run_result - elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED - } or len(participants) > len(rdzv_handler._state_holder.state.participants): + elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED} or len(participants) > len( + rdzv_handler._state_holder.state.participants): if self._remaining_restarts > 0: log.info(f"[{role}] Worker group {state.name}. 
" f"{self._remaining_restarts}/{spec.max_restarts} attempts left;" diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 64dc5479940c..29f8e1c22945 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -486,9 +486,10 @@ def conv2d_parallel_shard_weights(model, rank, world_size): if not dist.is_initialized() or dist.get_rank() == 0: print("Saving tp-sharded checkpoints") torch.save( - OrderedDict({k: v - for k, v in dict(replaced_module.state_dict()).items() - if transformer_name not in k}), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}') + OrderedDict({ + k: v + for k, v in dict(replaced_module.state_dict()).items() if transformer_name not in k + }), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}') dtype_reprs = { torch.float32: 'float32', diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index b49b4a8b6086..19f90b4d7c3b 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -1010,8 +1010,8 @@ def _do_error_check(self): self.gradient_accumulation_steps), "DeepSpeedConfig: {} is not defined".format(GRADIENT_ACCUMULATION_STEPS) if self.zero_enabled: - assert (self.zero_optimization_stage <= - ZeroStageEnum.max_stage), "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format( + assert (self.zero_optimization_stage + <= ZeroStageEnum.max_stage), "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format( ZeroStageEnum.max_stage) if self.fp16_master_weights_and_gradients: diff --git a/deepspeed/runtime/eigenvalue.py b/deepspeed/runtime/eigenvalue.py index df63854dd1ca..2361c21d756c 100755 --- a/deepspeed/runtime/eigenvalue.py +++ b/deepspeed/runtime/eigenvalue.py @@ -110,8 +110,8 @@ def compute_eigenvalue(self, module, device=None, scale=1.0): eigenvalue_current, eigenvalue_previous = 1., 0. 
while (i < self.max_iter) and abs(eigenvalue_current) > 0 and (abs( - (eigenvalue_current - eigenvalue_previous) / eigenvalue_current) >= - self.tol): # test convergence criteria + (eigenvalue_current - eigenvalue_previous) / eigenvalue_current) + >= self.tol): # test convergence criteria eigenvalue_previous = eigenvalue_current Hv = torch.autograd.grad(grads, params, grad_outputs=v, only_inputs=True, retain_graph=True) diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index c627846b743c..b62d087cfbba 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -632,9 +632,10 @@ def _aggregate_total_loss(self): self.dp_group_loss = losses[0].clone().detach() agg_loss = losses[1].clone().detach() if additional_losses is not None: - self.agg_additional_losses = OrderedDict( - {name: losses[2 + i].clone().detach() - for i, name in enumerate(additional_losses.keys())}) + self.agg_additional_losses = OrderedDict({ + name: losses[2 + i].clone().detach() + for i, name in enumerate(additional_losses.keys()) + }) return agg_loss def set_dataloader(self, loader): diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 2c01c3475a70..efa35219c01a 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -257,8 +257,8 @@ def has_overflow(self, params, has_moe_params=None): elif self.mpu is not None: if self.deepspeed is not None: using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') - if (using_pipeline and self.deepspeed.pipeline_enable_backward_allreduce is False) or ( - not using_pipeline and self.deepspeed.enable_backward_allreduce is False): + if (using_pipeline and self.deepspeed.pipeline_enable_backward_allreduce + is False) or (not using_pipeline and self.deepspeed.enable_backward_allreduce is False): dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_data_parallel_group()) dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_model_parallel_group()) elif self.deepspeed is not None and self.deepspeed.enable_backward_allreduce is False: diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index e2c845284541..e809fe1118b5 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -19,49 +19,59 @@ def post_func(input): # b, s, n, h if scatter_idx < 2: output = input.permute(1, 2, 0, 3, 4).contiguous() - output = output.reshape(bs, seq_len // seq_world_size, seq_world_size * num_head, head_dim).contiguous() + output = output.reshape(bs, seq_len // seq_world_size, seq_world_size * num_head, + head_dim).contiguous() else: output = input.permute(1, 0, 2, 3, 4).contiguous() - output = output.reshape(bs, seq_world_size * seq_len, num_head // seq_world_size, head_dim).contiguous() + output = output.reshape(bs, seq_world_size * seq_len, num_head // seq_world_size, + head_dim).contiguous() else: # s, b, n, h if scatter_idx < 2: output = input.permute(1, 2, 0, 3, 4).contiguous() - output = output.reshape(seq_len // seq_world_size, bs, seq_world_size * num_head, head_dim).contiguous() + output = output.reshape(seq_len // seq_world_size, bs, seq_world_size * num_head, + head_dim).contiguous() else: output = input.reshape(seq_len * seq_world_size, bs, num_head // seq_world_size, head_dim).contiguous() return output return post_func + def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None): seq_world_size = dist.get_world_size(group) if batch_dim_idx == 0: # b, 
s, n, h if scatter_idx < 2: bs, global_seq_len, num_local_head, head_dim = input.shape - input_t = input.reshape([bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim]).contiguous() + input_t = input.reshape([bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, + head_dim]).contiguous() input_t = input_t.permute(1, 0, 2, 3, 4).contiguous() else: bs, local_seq_len, num_total_head, head_dim = input.shape assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!" - input_t = input.reshape([bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim]).contiguous() + input_t = input.reshape([bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, + head_dim]).contiguous() input_t = input_t.permute(2, 0, 1, 3, 4).contiguous() else: # s, b, n, h if scatter_idx < 2: global_seq_len, bs, num_local_head, head_dim = input.shape - input_t = input.reshape([seq_world_size, global_seq_len // seq_world_size, bs, num_local_head, head_dim]).contiguous() + input_t = input.reshape([seq_world_size, global_seq_len // seq_world_size, bs, num_local_head, + head_dim]).contiguous() else: local_seq_len, bs, num_total_head, head_dim = input.shape assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!" - input_t = input.reshape([local_seq_len, bs, seq_world_size, num_total_head // seq_world_size, head_dim]).contiguous() + input_t = input.reshape([local_seq_len, bs, seq_world_size, num_total_head // seq_world_size, + head_dim]).contiguous() input_t = input_t.permute(2, 0, 1, 3, 4).contiguous() if scatter_idx < 2: - post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, global_seq_len, num_local_head, head_dim) + post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, global_seq_len, num_local_head, + head_dim) else: - post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, local_seq_len, num_total_head, head_dim) + post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, local_seq_len, num_total_head, + head_dim) output = torch.empty_like(input_t) work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op) @@ -128,8 +138,8 @@ def forward(ctx: Any, def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: return (None, - _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.batch_dim_idx, ctx.stream, ctx.handle, - ctx.type, False), None, None, None, None, None, None, None) + _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.batch_dim_idx, + ctx.stream, ctx.handle, ctx.type, False), None, None, None, None, None, None, None) class DistributedAttention(torch.nn.Module): @@ -203,8 +213,8 @@ def pre_hook_fun(grad): query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx, batch_dim_idx, None, self.overlap_handles, 'q') self.layer_sync(key) - key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, batch_dim_idx, None, self.overlap_handles, - 'k') + key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, batch_dim_idx, None, + self.overlap_handles, 'k') if self.sp_overlap_comm: self.dafult_stream.wait_stream(self.sp_stream) @@ -226,8 +236,8 @@ def pre_hook_fun(grad): context_layer = 
self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs) - output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, batch_dim_idx, self.sp_stream, - self.overlap_handles, 'o') + output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, batch_dim_idx, + self.sp_stream, self.overlap_handles, 'o') #out e.g., [s/p::h] return output diff --git a/tests/unit/runtime/zero/test_zero_context.py b/tests/unit/runtime/zero/test_zero_context.py index ec9e9e94aeaf..1d4fcd60022c 100644 --- a/tests/unit/runtime/zero/test_zero_context.py +++ b/tests/unit/runtime/zero/test_zero_context.py @@ -218,9 +218,9 @@ def test_throughput_calculation(self): engine.tput_timer.stop(global_step=global_step) duration = engine.tput_timer.end_time - engine.tput_timer.start_time # step elapsed time is reset after gradient accumulation steps - assert engine.tput_timer.step_elapsed_time == ( - 0 if engine.tput_timer.global_step_count != engine.tput_timer.start_step else current_duration + - duration) + assert engine.tput_timer.step_elapsed_time == (0 if engine.tput_timer.global_step_count + != engine.tput_timer.start_step else current_duration + + duration) assert engine.tput_timer.total_elapsed_time == total_duration + duration def test_ext_param_getattr(self): From 89b119e56d19031112fd4946fd994f71aac927dc Mon Sep 17 00:00:00 2001 From: Logan Adams Date: Wed, 7 Aug 2024 15:29:24 -0700 Subject: [PATCH 03/54] Formatting fixes --- deepspeed/__init__.py | 4 ++-- deepspeed/autotuning/autotuner.py | 12 ++++++------ deepspeed/elasticity/elastic_agent.py | 4 ++-- deepspeed/module_inject/replace_module.py | 7 +++---- deepspeed/runtime/config.py | 4 ++-- deepspeed/runtime/eigenvalue.py | 4 ++-- deepspeed/runtime/pipe/engine.py | 7 +++---- deepspeed/runtime/utils.py | 4 ++-- tests/unit/runtime/zero/test_zero_context.py | 6 +++--- 9 files changed, 25 insertions(+), 27 deletions(-) diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index b9c44e076252..fe0043547860 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -152,8 +152,8 @@ def initialize(args=None, if hasattr(args, "deepscale_config") and args.deepscale_config is not None: logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************") if hasattr(args, "deepspeed_config"): - assert (args.deepspeed_config - is None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config" + assert (args.deepspeed_config is + None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config" args.deepspeed_config = args.deepscale_config args.deepscale_config = None diff --git a/deepspeed/autotuning/autotuner.py b/deepspeed/autotuning/autotuner.py index a72b3c951e97..dfd195bc37eb 100755 --- a/deepspeed/autotuning/autotuner.py +++ b/deepspeed/autotuning/autotuner.py @@ -248,8 +248,8 @@ def mp_size(self): return self.autotuning_config.mp_size def max_train_micro_batch_size_per_gpu(self): - if self.max_train_batch_size() and self.max_train_batch_size( - ) > 0: # if the user specifies a max_train_batch_size + if self.max_train_batch_size( + ) and self.max_train_batch_size() > 0: # if the user specifies a max_train_batch_size max_train_micro_batch_size = self.max_train_batch_size() * self.mp_size() // ( self.exp_num_gpus * self.exp_num_nodes) # gradient accumulation steps >=1 return min(self.autotuning_config.max_train_micro_batch_size_per_gpu, max_train_micro_batch_size) @@ -964,8 +964,8 @@ def 
get_min_max_micro_batch_size(self, stage, min_micro_batch_size, calculated_m low = mid + 1 self.update_records(tuning_space_name, exp, metric_val, 1) used_micro_batch_sizes.append(mid) - if prev_metric_val and ((metric_val - prev_metric_val) / - prev_metric_val) < METRIC_PERCENT_DIFF_CONST: + if prev_metric_val and ( + (metric_val - prev_metric_val) / prev_metric_val) < METRIC_PERCENT_DIFF_CONST: logger.info(f"performance plateaus at mbs = {low}") break prev_metric_val = metric_val @@ -1026,8 +1026,8 @@ def get_tuning_micro_batch_size_list(self, min_micro_batch_size, max_micro_batch # NUM_GPUS=$(( ${NUM_WORKERS} * ${NUM_GPUS_PER_WORKER} )) # DP_SIZE=$(( ${NUM_GPUS} / (${PP_SIZE} * ${MP_SIZE}) )) # GRAD_ACC_STEPS=$(( ${TARGET_GLOBAL_BATCH_SIZE} / (${BATCH_SIZE} * ${DP_SIZE}) )) - if self.max_train_batch_size() and self.max_train_batch_size( - ) > 0: # if the user specifies a max_train_batch_size + if self.max_train_batch_size( + ) and self.max_train_batch_size() > 0: # if the user specifies a max_train_batch_size max_train_batch_size_per_gpu = self.max_train_batch_size() * self.mp_size() // (self.exp_num_gpus * self.exp_num_nodes) else: diff --git a/deepspeed/elasticity/elastic_agent.py b/deepspeed/elasticity/elastic_agent.py index 0a8cf09db8a4..039b999dfeca 100644 --- a/deepspeed/elasticity/elastic_agent.py +++ b/deepspeed/elasticity/elastic_agent.py @@ -158,8 +158,8 @@ def _invoke_run(self, role: str = "default") -> RunResult: f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish.") self._exit_barrier() return run_result - elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED} or len(participants) > len( - rdzv_handler._state_holder.state.participants): + elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED + } or len(participants) > len(rdzv_handler._state_holder.state.participants): if self._remaining_restarts > 0: log.info(f"[{role}] Worker group {state.name}. 
" f"{self._remaining_restarts}/{spec.max_restarts} attempts left;" diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 29f8e1c22945..64dc5479940c 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -486,10 +486,9 @@ def conv2d_parallel_shard_weights(model, rank, world_size): if not dist.is_initialized() or dist.get_rank() == 0: print("Saving tp-sharded checkpoints") torch.save( - OrderedDict({ - k: v - for k, v in dict(replaced_module.state_dict()).items() if transformer_name not in k - }), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}') + OrderedDict({k: v + for k, v in dict(replaced_module.state_dict()).items() + if transformer_name not in k}), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}') dtype_reprs = { torch.float32: 'float32', diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 19f90b4d7c3b..b49b4a8b6086 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -1010,8 +1010,8 @@ def _do_error_check(self): self.gradient_accumulation_steps), "DeepSpeedConfig: {} is not defined".format(GRADIENT_ACCUMULATION_STEPS) if self.zero_enabled: - assert (self.zero_optimization_stage - <= ZeroStageEnum.max_stage), "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format( + assert (self.zero_optimization_stage <= + ZeroStageEnum.max_stage), "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format( ZeroStageEnum.max_stage) if self.fp16_master_weights_and_gradients: diff --git a/deepspeed/runtime/eigenvalue.py b/deepspeed/runtime/eigenvalue.py index 2361c21d756c..df63854dd1ca 100755 --- a/deepspeed/runtime/eigenvalue.py +++ b/deepspeed/runtime/eigenvalue.py @@ -110,8 +110,8 @@ def compute_eigenvalue(self, module, device=None, scale=1.0): eigenvalue_current, eigenvalue_previous = 1., 0. 
while (i < self.max_iter) and abs(eigenvalue_current) > 0 and (abs( - (eigenvalue_current - eigenvalue_previous) / eigenvalue_current) - >= self.tol): # test convergence criteria + (eigenvalue_current - eigenvalue_previous) / eigenvalue_current) >= + self.tol): # test convergence criteria eigenvalue_previous = eigenvalue_current Hv = torch.autograd.grad(grads, params, grad_outputs=v, only_inputs=True, retain_graph=True) diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index b62d087cfbba..c627846b743c 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -632,10 +632,9 @@ def _aggregate_total_loss(self): self.dp_group_loss = losses[0].clone().detach() agg_loss = losses[1].clone().detach() if additional_losses is not None: - self.agg_additional_losses = OrderedDict({ - name: losses[2 + i].clone().detach() - for i, name in enumerate(additional_losses.keys()) - }) + self.agg_additional_losses = OrderedDict( + {name: losses[2 + i].clone().detach() + for i, name in enumerate(additional_losses.keys())}) return agg_loss def set_dataloader(self, loader): diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index efa35219c01a..2c01c3475a70 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -257,8 +257,8 @@ def has_overflow(self, params, has_moe_params=None): elif self.mpu is not None: if self.deepspeed is not None: using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') - if (using_pipeline and self.deepspeed.pipeline_enable_backward_allreduce - is False) or (not using_pipeline and self.deepspeed.enable_backward_allreduce is False): + if (using_pipeline and self.deepspeed.pipeline_enable_backward_allreduce is False) or ( + not using_pipeline and self.deepspeed.enable_backward_allreduce is False): dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_data_parallel_group()) dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_model_parallel_group()) elif self.deepspeed is not None and self.deepspeed.enable_backward_allreduce is False: diff --git a/tests/unit/runtime/zero/test_zero_context.py b/tests/unit/runtime/zero/test_zero_context.py index 1d4fcd60022c..ec9e9e94aeaf 100644 --- a/tests/unit/runtime/zero/test_zero_context.py +++ b/tests/unit/runtime/zero/test_zero_context.py @@ -218,9 +218,9 @@ def test_throughput_calculation(self): engine.tput_timer.stop(global_step=global_step) duration = engine.tput_timer.end_time - engine.tput_timer.start_time # step elapsed time is reset after gradient accumulation steps - assert engine.tput_timer.step_elapsed_time == (0 if engine.tput_timer.global_step_count - != engine.tput_timer.start_step else current_duration + - duration) + assert engine.tput_timer.step_elapsed_time == ( + 0 if engine.tput_timer.global_step_count != engine.tput_timer.start_step else current_duration + + duration) assert engine.tput_timer.total_elapsed_time == total_duration + duration def test_ext_param_getattr(self): From 0beff2479933e7978aa5051f2fec73380a4788b6 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Thu, 29 Aug 2024 16:54:17 -0400 Subject: [PATCH 04/54] add FPDT --- .../activation_checkpointing/checkpointing.py | 3 +- deepspeed/sequence/fpdt_layer.py | 1165 +++++++++++++++++ deepspeed/sequence/layer.py | 37 +- 3 files changed, 1202 insertions(+), 3 deletions(-) create mode 100644 deepspeed/sequence/fpdt_layer.py diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py 
b/deepspeed/runtime/activation_checkpointing/checkpointing.py index f955cf5ebcad..866b444a35a5 100644 --- a/deepspeed/runtime/activation_checkpointing/checkpointing.py +++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py @@ -369,7 +369,8 @@ def is_activation_to_checkpoint(item): Is an activation to be checkpointed """ global mp_size - return torch.is_tensor(item) and item.is_floating_point() and item.numel() >= mp_size + extra_flag = (not hasattr(item, 'no_checkpointing')) or (hasattr(item, 'no_checkpointing') and item.no_checkpointing == False) + return torch.is_tensor(item) and item.is_floating_point() and item.numel() >= mp_size and extra_flag def partition_activations(args, cpu_checkpoint, contiguous_checkpoint): diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py new file mode 100644 index 000000000000..7c88468fc947 --- /dev/null +++ b/deepspeed/sequence/fpdt_layer.py @@ -0,0 +1,1165 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from os import kill +from turtle import heading +from unittest import skip +import torch + +from typing import Optional, Any, Tuple +from torch import Tensor +from torch.nn import Module + +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator + +from packaging import version +from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward +from einops import rearrange + + +def _rotate_half(x): + """ + change sign so the last dimension becomes [-odd, +even] + """ + x = rearrange(x, '... (j d) -> ... j d', j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def _rotate_half_backward(x): + """ + change sign so the last dimension becomes [-odd, +even] + """ + x = rearrange(x, '... (j d) -> ... 
j d', j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((x2, -x1), dim=-1) + + +def apply_rotary_pos_emb(t, freqs_cos, freqs_sin): + """ + input tensor t is of shape [seq_length, ..., dim] + rotary positional embeding tensor freqs is of shape [seq_length, ..., dim] + check https://kexue.fm/archives/8265 for detailed formulas + """ + rot_dim = freqs_cos.shape[-1] + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + t = (t * freqs_cos) + (_rotate_half(t) * freqs_sin) + + res = t if t_pass.shape[-1] == 0 else torch.cat((t, t_pass), dim=-1) + return res + + +def apply_rotary_pos_emb_backward(grad_output, freqs_cos, freqs_sin): + rot_dim = freqs_cos.shape[-1] + grad, grad_pass = grad_output[..., :rot_dim], grad_output[..., rot_dim:] + grad_t = (grad * freqs_cos) + (_rotate_half_backward(grad * freqs_sin)) + grad = grad_t if grad_pass.shape[-1] == 0 else torch.cat((grad_t, grad_pass), dim=-1) + return grad + + +# @torch.jit.script +def _update_out_and_lse( + out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + + block_out = block_out.to(torch.float32) + block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + + new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + + out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out + + lse = new_lse + return out, lse + + +def update_out_and_lse( + out: Optional[torch.Tensor], + lse: Optional[torch.Tensor], + block_out: torch.Tensor, + block_lse: torch.Tensor, + slice_=None, +) -> Tuple[torch.Tensor, torch.Tensor]: + if out is None: + if slice_ is not None: + raise RuntimeError("first update_out_and_lse should not pass slice_ args") + out = block_out.to(torch.float32) + lse = block_lse.permute(0, 2, 1).contiguous().unsqueeze(dim=-1).contiguous() + elif slice_ is not None: + slice_out, slice_lse = out[slice_], lse[slice_] + slice_out, slice_lse = _update_out_and_lse( + slice_out, slice_lse, block_out, block_lse + ) + out[slice_], lse[slice_] = slice_out, slice_lse + else: + out, lse = _update_out_and_lse(out, lse, block_out, block_lse) + return out, lse + + +def single_all_to_all(input_, scatter_idx, gather_idx, group): + seq_world_size = dist.get_world_size(group) + if scatter_idx < 2: + bs, global_seq_len, local_head, head_dim = input_.shape + input_t = input_.reshape( + [bs, seq_world_size, global_seq_len // seq_world_size, local_head, head_dim] + ).contiguous() + input_t = input_t.permute(1, 0, 2, 3, 4).contiguous() + else: + bs, local_seq_len, total_head, head_dim = input_.shape + input_t = input_.reshape( + [bs, local_seq_len, seq_world_size, total_head // seq_world_size, head_dim] + ).contiguous() + input_t = input_t.permute(2, 0, 1, 3, 4).contiguous() + + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=group) + + if scatter_idx < 2: + output = output.permute(1, 2, 0, 3, 4).contiguous() + output = output.reshape(bs, global_seq_len // seq_world_size, seq_world_size * local_head, head_dim).contiguous() + else: + output = output.permute(1, 0, 2, 3, 4).contiguous() + output = output.reshape(bs, seq_world_size * local_seq_len, total_head // seq_world_size, head_dim).contiguous() + return output + + +class FPDT_InputConstruct(torch.nn.Module): + def __init__( + self, + tokens, labels, loss_mask, attention_mask, 
position_ids, args, sp_size, sp_rank + ) -> None: + + super(FPDT_InputConstruct, self).__init__() + self.tokens = tokens + self.labels = labels + self.loss_mask = loss_mask + self.attention_mask = attention_mask + self.position_ids = position_ids + global_seq_len = tokens.shape[1] + batch_size = tokens.shape[0] + assert global_seq_len % sp_size == 0 + assert global_seq_len % args.ds_sequence_parallel_fpdt_chunk_size == 0 + num_chunk_per_gpu = global_seq_len // args.ds_sequence_parallel_fpdt_chunk_size + local_seq_len = global_seq_len // sp_size + assert local_seq_len % num_chunk_per_gpu == 0 + + self.num_chunk_per_gpu = num_chunk_per_gpu + self.chunk_size = local_seq_len // num_chunk_per_gpu + self.sp_size = sp_size + self.sp_rank = sp_rank + self.global_seq_len = global_seq_len + self.local_seq_len = local_seq_len + self.batch_size = batch_size + self.device = tokens.device + + def generate(self): + device = self.device + totalChunks = self.global_seq_len // self.chunk_size + token_chunk_idx = torch.arange(self.global_seq_len, device=device, dtype=torch.int) // self.chunk_size + chunk_to_gpu = torch.arange(totalChunks, device=device, dtype=torch.int) + chunk_to_gpu = chunk_to_gpu.reshape(self.num_chunk_per_gpu, -1).t().contiguous() + + gather_chunk = chunk_to_gpu.flatten().unsqueeze(1).contiguous() + mask = gather_chunk == token_chunk_idx + + indices = mask.nonzero(as_tuple=False) + gather_indices = indices[:, 0] + token_chunk_indices = indices[:, 1] + indices = torch.cat([token_chunk_indices[gather_indices == i] for i in range(gather_chunk.shape[0])]) + load_balanced_loss_mask = self.loss_mask[:, indices] + + indices = indices.reshape(-1, self.chunk_size)[self.num_chunk_per_gpu*self.sp_rank:self.num_chunk_per_gpu*(self.sp_rank + 1)].flatten().contiguous() + load_balanced_tokens = self.tokens[:, indices] + load_balanced_labels = self.labels[:, indices] + + load_balanced_attention_mask = self.attention_mask if self.attention_mask is not None else None + load_balanced_position_ids = self.position_ids[:, indices] + + return load_balanced_tokens, load_balanced_labels, load_balanced_loss_mask, load_balanced_attention_mask, load_balanced_position_ids + + +class _FPDTGPUAttentionImpl_(torch.autograd.Function): + generate_vmap_rule = False + + @staticmethod + def forward(ctx: Any, + layernorm_output, + attention_mask, + inference_params, + rotary_pos_emb, + spg, + scatter_idx, + gather_idx, + hidden_size, + projection_size, + hidden_size_per_attention_head, + kv_projection_size, + qkv_linear_weight, + qkv_linear_bias, + dropout, + num_chunks=8, cpu_offloading=True): + + do_save = layernorm_output.requires_grad + + pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3) + ctx.pos_emb_cos = pos_emb_cos + ctx.pos_emb_sin = pos_emb_sin + + with torch.no_grad(): + per_gpu_seq_len = layernorm_output.shape[0] + chunk_size = per_gpu_seq_len // num_chunks + assert chunk_size * num_chunks == per_gpu_seq_len + assert attention_mask is None + ctx.num_chunks = num_chunks + ctx.cpu_offloading = cpu_offloading + ctx.spg = spg + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + + device = get_accelerator().current_device_name() + ctx.device = device + ctx.dtype = layernorm_output.dtype + ctx.projection_size = projection_size + ctx.kv_projection_size = kv_projection_size + + global_q = [] + global_k = [] + global_v = [] + + ctx.softmax_scale = hidden_size_per_attention_head ** (-0.5) + + ctx.dropout_p = dropout + ctx.window_size = (-1, -1) + 
ctx.alibi_slopes = None + + batch_size = layernorm_output.shape[1] + + global_o = [None for _ in range(num_chunks)] + global_lse = [None for _ in range(num_chunks)] + + for i in range(num_chunks): + + st = chunk_size * i + ed = st + chunk_size + + qkv_chunk = torch.matmul(layernorm_output[st:ed], qkv_linear_weight.t()) + qkv_linear_bias + + q_chunk = qkv_chunk[:, :, :projection_size].contiguous().reshape(qkv_chunk.shape[0], qkv_chunk.shape[1], -1, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + q_chunk = single_all_to_all(q_chunk, scatter_idx, gather_idx, spg) + global_q_chunk_len = q_chunk.shape[1] + q_chunk = apply_rotary_pos_emb(q_chunk, pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) + global_q.append(q_chunk) + + k_chunk = qkv_chunk[:, :, projection_size:projection_size+kv_projection_size].contiguous().reshape(qkv_chunk.shape[0], qkv_chunk.shape[1], -1, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + k_chunk = single_all_to_all(k_chunk, scatter_idx, gather_idx, spg) + k_chunk = apply_rotary_pos_emb(k_chunk, pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) + global_k.append(k_chunk) + + v_chunk = qkv_chunk[:, :, projection_size+kv_projection_size:].contiguous().reshape(qkv_chunk.shape[0], qkv_chunk.shape[1], -1, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + v_chunk = single_all_to_all(v_chunk, scatter_idx, gather_idx, spg) + global_v.append(v_chunk) + + for k_i in range(len(global_k)): + causal_chunk = i == k_i + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( + global_q[i], + global_k[k_i], + global_v[k_i], + ctx.dropout_p, + ctx.softmax_scale, + causal=causal_chunk, + window_size=ctx.window_size, + softcap=0.0, + alibi_slopes=ctx.alibi_slopes, + return_softmax=False + ) + global_o[i], global_lse[i] = update_out_and_lse(global_o[i], global_lse[i], block_out, block_lse) + + global_o[i] = global_o[i].to(q_chunk.dtype) + + output = [None for i in range(num_chunks)] + + for i in range(num_chunks): + global_lse[i] = global_lse[i][:, :, :, 0].permute(0, 2, 1).contiguous() + output[i] = single_all_to_all(global_o[i].to(ctx.dtype).contiguous(), gather_idx, scatter_idx, spg) + output = torch.cat(output, dim=1) + + head_dim = output.shape[-1] + + if do_save: + ctx.save_for_backward(layernorm_output) + ctx.global_q = global_q + ctx.global_k = global_k + ctx.global_v = global_v + ctx.attn_output = global_o + ctx.attn_lse = global_lse + ctx.head_dim = head_dim + ctx.batch_size = batch_size + + ctx.qkv_linear_weight = qkv_linear_weight + ctx.qkv_linear_bias = qkv_linear_bias + + return output + + + @staticmethod + def backward(ctx, grad_output): + + num_chunks = ctx.num_chunks + device = ctx.device + dtype = ctx.dtype + spg = ctx.spg + scatter_idx = ctx.scatter_idx + gather_idx = ctx.gather_idx + softmax_scale = ctx.softmax_scale + dropout_p = ctx.dropout_p + window_size = ctx.window_size + alibi_slopes = ctx.alibi_slopes + + projection_size = ctx.projection_size + kv_projection_size = ctx.kv_projection_size + + layernorm_output = ctx.saved_tensors[0] + + global_q = ctx.global_q + global_k = ctx.global_k + global_v = ctx.global_v + attn_output = ctx.attn_output + lse = ctx.attn_lse + + qkv_linear_weight = ctx.qkv_linear_weight + qkv_linear_bias = ctx.qkv_linear_bias + + input_chunk_size = 
layernorm_output.shape[0] // num_chunks + grad_layernorm_output = [torch.zeros((input_chunk_size, layernorm_output.shape[1], layernorm_output.shape[2]), device=device, dtype=dtype) for _ in range(num_chunks)] + + grad_global_attn_output = [] + chunk_size = grad_output.shape[1] // num_chunks + + for i in range(num_chunks): + st = chunk_size * i + ed = st + chunk_size + grad_global_attn_output.append(single_all_to_all(grad_output[:, st:ed].contiguous(), scatter_idx, gather_idx, spg)) + + del grad_output + + dq = [torch.zeros(global_q[0].shape, dtype=torch.float, device=device) for _ in range(num_chunks)] + dk = [torch.zeros(global_q[0].shape, dtype=torch.float, device=device) for _ in range(num_chunks)] + dv = [torch.zeros(global_q[0].shape, dtype=torch.float, device=device) for _ in range(num_chunks)] + + grad_qkv_linear_weight = torch.zeros(qkv_linear_weight.shape, device=qkv_linear_weight.device, dtype=torch.float) + grad_qkv_linear_bias = torch.zeros(qkv_linear_bias.shape, device=qkv_linear_weight.device, dtype=torch.float) + + for i in range(num_chunks): + k_chunk = global_k[i] + v_chunk = global_v[i] + + for q_i in range(num_chunks): + no_computation = q_i < i + if no_computation: + continue + + causal_chunk = q_i == i + + q_chunk = global_q[q_i] + attn_output_chunk = attn_output[q_i] + lse_chunk = lse[q_i] + dout = grad_global_attn_output[q_i] + + dq_this = torch.zeros(global_q[0].shape, dtype=dtype, device=device) + dk_this = torch.zeros(global_k[0].shape, dtype=dtype, device=device) + dv_this = torch.zeros(global_v[0].shape, dtype=dtype, device=device) + + _flash_attn_backward( + dout, + q_chunk, + k_chunk, + v_chunk, + attn_output_chunk, + lse_chunk, + dq_this, + dk_this, + dv_this, + dropout_p, + softmax_scale, + causal_chunk, + window_size, + softcap=0.0, + alibi_slopes=alibi_slopes, + deterministic=False, + rng_state=None + ) + + dq[q_i].add_(dq_this.to(torch.float)) + dk[i].add_(dk_this.to(torch.float)) + dv[i].add_(dv_this.to(torch.float)) + + dk_seq_len = dk[i].shape[1] + dk[i] = apply_rotary_pos_emb_backward(dk[i].to(dtype), ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) + dv[i] = dv[i].to(dtype) + dk[i] = single_all_to_all(dk[i].contiguous(), gather_idx, scatter_idx, spg) + dv[i] = single_all_to_all(dv[i].contiguous(), gather_idx, scatter_idx, spg) + + input_st = i * input_chunk_size + input_ed = input_st + input_chunk_size + + input_chunk = layernorm_output[input_st:input_ed].reshape(-1, layernorm_output.shape[-1]) + + dk[i] = dk[i].flatten(2).permute(1, 0, 2) + dv[i] = dv[i].flatten(2).permute(1, 0, 2) + l, b = dk[i].shape[0], dk[i].shape[1] + grad_qkv_linear_weight[projection_size:projection_size+kv_projection_size].add_(torch.matmul(dk[i].reshape(l*b, -1).t(), input_chunk)) + grad_qkv_linear_weight[projection_size+kv_projection_size:].add_(torch.matmul(dv[i].reshape(l*b, -1).t(), input_chunk)) + grad_qkv_linear_bias[projection_size:projection_size+kv_projection_size].add_(dk[i].sum(0).sum(0)) + grad_qkv_linear_bias[projection_size+kv_projection_size:].add_(dv[i].sum(0).sum(0)) + + grad_layernorm_output[i].add_(torch.matmul(dk[i], qkv_linear_weight[projection_size:projection_size+kv_projection_size])) + grad_layernorm_output[i].add_(torch.matmul(dv[i], qkv_linear_weight[projection_size+kv_projection_size:])) + + dk[i] = None + dv[i] = None + + for i in range(num_chunks): + dq_seq_len = dq[i].shape[1] + dq[i] = apply_rotary_pos_emb_backward(dq[i].to(dtype), ctx.pos_emb_cos[:, dq_seq_len * i:dq_seq_len * (i 
+ 1)], ctx.pos_emb_sin[:, dq_seq_len * i:dq_seq_len * (i + 1)]) + + dq[i] = single_all_to_all(dq[i].to(dtype).contiguous(), gather_idx, scatter_idx, spg) + + input_chunk = layernorm_output[:input_chunk_size].reshape(-1, layernorm_output.shape[-1]) + layernorm_output = layernorm_output[input_chunk_size:] + + dq[i] = dq[i].flatten(2).permute(1, 0, 2) + l, b = dq[i].shape[0], dq[i].shape[1] + grad_qkv_linear_weight[:projection_size].add_(torch.matmul(dq[i].reshape(l*b, -1).t(), input_chunk)) + grad_qkv_linear_bias[:projection_size].add_(dq[i].sum(0).sum(0)) + + grad_layernorm_output[i].add_(torch.matmul(dq[i], qkv_linear_weight[:projection_size])) + + dq[i] = None + + + return torch.cat(grad_layernorm_output, dim=0).to(dtype), None, None, None, None, None, None, None, None, None, None, grad_qkv_linear_weight.to(dtype), grad_qkv_linear_bias.to(dtype), None, None, None + + + +class SequenceChunk: + def __init__(self, chunk: torch.Tensor, device=None, is_in_use=False): + + self.chunk_shape = chunk.shape + self.chunk_dtype = chunk.dtype + self.device = chunk.device if device is None else device + + cpu_chunk = torch.empty(chunk.shape, dtype=chunk.dtype, device='cpu', pin_memory=True) + if chunk.is_cuda: + cpu_chunk.copy_(chunk, non_blocking=True) + else: + cpu_chunk = chunk + + self.cpu_chunk = cpu_chunk + + self.gpu_chunk = chunk if is_in_use else None + + def load_to_gpu(self): + assert self.gpu_chunk is None + if self.gpu_chunk is not None: + pass + else: + gpu_chunk = torch.empty(self.chunk_shape, device=self.device, dtype=self.chunk_dtype) + gpu_chunk.copy_(self.cpu_chunk, non_blocking=True) + self.gpu_chunk = gpu_chunk + + def get_gpu_chunk(self): + assert self.gpu_chunk is not None and self.gpu_chunk.device == self.device + return self.gpu_chunk + + def check_gpu_chunk(self,): + assert (self.gpu_chunk is not None) and (self.gpu_chunk.device == self.device), f"gpu_chunk {self.gpu_chunk is not None} shound be on {self.device}, but it is now on {self.gpu_chunk.device}" + return True + + def offload(self): + assert self.gpu_chunk is not None and self.gpu_chunk.device == self.device + del self.gpu_chunk + self.gpu_chunk = None + + def overwrite_to_cpu(self): + assert self.gpu_chunk is not None and self.gpu_chunk.device == self.device + self.cpu_chunk.copy_(self.gpu_chunk, non_blocking=True) + + +class _FPDTGPUOffloadingAttentionImpl_(torch.autograd.Function): + generate_vmap_rule = False + + @staticmethod + def forward(ctx: Any, + layernorm_output, + attention_mask, + inference_params, + rotary_pos_emb, + spg, + scatter_idx, + gather_idx, + hidden_size, + projection_size, + hidden_size_per_attention_head, + kv_projection_size, + qkv_linear_weight, + qkv_linear_bias, + dropout, + num_chunks=8, cpu_offloading=True): + + do_save = layernorm_output.requires_grad + + pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3) + ctx.pos_emb_cos = pos_emb_cos + ctx.pos_emb_sin = pos_emb_sin + with torch.no_grad(): + per_gpu_seq_len = layernorm_output.shape[0] + chunk_size = per_gpu_seq_len // num_chunks + assert chunk_size * num_chunks == per_gpu_seq_len + assert attention_mask is None + ctx.num_chunks = num_chunks + ctx.cpu_offloading = cpu_offloading + ctx.spg = spg + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + + ctx.chunk_size = chunk_size + device = get_accelerator().current_device_name() + ctx.device = device + ctx.dtype = layernorm_output.dtype + ctx.projection_size = projection_size + ctx.kv_projection_size = kv_projection_size + + 
global_q = [] + global_k = [] + global_v = [] + + ctx.softmax_scale = hidden_size_per_attention_head ** (-0.5) + + ctx.dropout_p = dropout + ctx.window_size = (-1, -1) + ctx.alibi_slopes = None + + batch_size = layernorm_output.shape[1] + + global_o = [] + global_lse = [] + + layernorm_output_cpu = [] + final_output = [] + + offload_stream = get_accelerator().Stream() + general_offload_stream = get_accelerator().Stream() + compute_stream = get_accelerator().default_stream() + + q_compute_chunk_idx = 0 + kv_compute_chunk_idx = 0 + for i in range(num_chunks): + + qkv_chunk = torch.matmul(layernorm_output[:chunk_size], qkv_linear_weight.t()) + qkv_linear_bias # torch.Size([18126, 1, 12288]) + + with torch.cuda.stream(general_offload_stream): + layernorm_output_cpu.append(SequenceChunk(layernorm_output[:chunk_size])) + + layernorm_output = layernorm_output[chunk_size:] + + q_chunk = qkv_chunk[:, :, :projection_size].contiguous().reshape(qkv_chunk.shape[0], qkv_chunk.shape[1], -1, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + q_chunk = single_all_to_all(q_chunk, scatter_idx, gather_idx, spg) + global_q_chunk_len = q_chunk.shape[1] + + k_chunk = qkv_chunk[:, :, projection_size:projection_size+kv_projection_size].contiguous().reshape(qkv_chunk.shape[0], qkv_chunk.shape[1], -1, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + k_chunk = single_all_to_all(k_chunk, scatter_idx, gather_idx, spg) + + v_chunk = qkv_chunk[:, :, projection_size+kv_projection_size:].contiguous().reshape(qkv_chunk.shape[0], qkv_chunk.shape[1], -1, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + v_chunk = single_all_to_all(v_chunk, scatter_idx, gather_idx, spg) + + torch.distributed.barrier() # torch.cuda.synchronize() + + pos_emb_cos_chunk = pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)] + pos_emb_sin_chunk = pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)] + + q_chunk = apply_rotary_pos_emb(q_chunk, pos_emb_cos_chunk, pos_emb_sin_chunk) + k_chunk = apply_rotary_pos_emb(k_chunk, pos_emb_cos_chunk, pos_emb_sin_chunk) + + compute_stream.wait_stream(offload_stream) + compute_stream.synchronize() + with torch.cuda.stream(offload_stream): + global_q.append(SequenceChunk(q_chunk, is_in_use=True)) + global_k.append(SequenceChunk(k_chunk, is_in_use=True)) + global_v.append(SequenceChunk(v_chunk, is_in_use=True)) + + del qkv_chunk + + cur_attn_output = None + cur_attn_lse = None + for k_i in range(len(global_k)): + causal_chunk = i == k_i + with torch.cuda.stream(compute_stream): + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( + global_q[q_compute_chunk_idx].get_gpu_chunk(), + global_k[kv_compute_chunk_idx].get_gpu_chunk(), + global_v[kv_compute_chunk_idx].get_gpu_chunk(), + ctx.dropout_p, + ctx.softmax_scale, + causal=causal_chunk, + window_size=ctx.window_size, + softcap=0.0, + alibi_slopes=ctx.alibi_slopes, + return_softmax=False + ) + cur_attn_output, cur_attn_lse = update_out_and_lse(cur_attn_output, cur_attn_lse, block_out, block_lse) + + can_offload_kv = True + if k_i != (len(global_k) - 1) or i != (num_chunks - 1): + if k_i != (len(global_k) - 1): + next_kv_compute_chunk_idx = k_i + 1 + else: + next_kv_compute_chunk_idx = 0 + + if next_kv_compute_chunk_idx == kv_compute_chunk_idx: + can_offload_kv = False + else: + if next_kv_compute_chunk_idx != (len(global_k) - 1): + with torch.cuda.stream(offload_stream): + global_k[next_kv_compute_chunk_idx].load_to_gpu() + 
global_v[next_kv_compute_chunk_idx].load_to_gpu() + + if i == num_chunks - 1 and k_i == num_chunks - 1: + with torch.cuda.stream(offload_stream): + global_q[0].load_to_gpu() + global_k[0].load_to_gpu() + global_v[0].load_to_gpu() + global_o[0].load_to_gpu() + global_lse[0].load_to_gpu() + + compute_stream.wait_stream(offload_stream) + compute_stream.synchronize() + + if can_offload_kv: + global_k[kv_compute_chunk_idx].offload() + global_v[kv_compute_chunk_idx].offload() + kv_compute_chunk_idx = next_kv_compute_chunk_idx + + global_q[q_compute_chunk_idx].offload() + q_compute_chunk_idx += 1 + + all2all_output = single_all_to_all(cur_attn_output.to(ctx.dtype).contiguous(), gather_idx, scatter_idx, spg) + final_output.append(all2all_output) + with torch.cuda.stream(general_offload_stream): + global_o.append(SequenceChunk(cur_attn_output.to(ctx.dtype))) + global_lse.append(SequenceChunk(cur_attn_lse[:, :, :, 0].permute(0, 2, 1).contiguous())) + + compute_stream.wait_stream(general_offload_stream) + compute_stream.synchronize() + + final_output = torch.cat(final_output, dim=1) + + head_dim = final_output.shape[-1] + + if do_save: + ctx.layernorm_output = layernorm_output_cpu + ctx.global_q = global_q + ctx.global_k = global_k + ctx.global_v = global_v + ctx.attn_output = global_o + ctx.attn_lse = global_lse + ctx.head_dim = head_dim + ctx.batch_size = batch_size + + ctx.qkv_linear_weight = qkv_linear_weight + ctx.qkv_linear_bias = qkv_linear_bias + + return final_output + + + @staticmethod + def backward(ctx, grad_output): + num_chunks = ctx.num_chunks + device = grad_output.device + dtype = ctx.dtype + spg = ctx.spg + scatter_idx = ctx.scatter_idx + gather_idx = ctx.gather_idx + softmax_scale = ctx.softmax_scale + dropout_p = ctx.dropout_p + window_size = ctx.window_size + alibi_slopes = ctx.alibi_slopes + + projection_size = ctx.projection_size + kv_projection_size = ctx.kv_projection_size + + layernorm_output = ctx.layernorm_output + + global_q = ctx.global_q + global_k = ctx.global_k + global_v = ctx.global_v + attn_output = ctx.attn_output + lse = ctx.attn_lse + + qkv_linear_weight = ctx.qkv_linear_weight + qkv_linear_bias = ctx.qkv_linear_bias + + offload_stream = get_accelerator().Stream() + general_offload_stream = torch.cuda.Stream() + compute_stream = get_accelerator().default_stream() + + chunk_size = grad_output.shape[1] // num_chunks + assert chunk_size == layernorm_output[0].cpu_chunk.shape[0] + + grad_layernorm_output = [torch.zeros(layernorm_output[0].chunk_shape, device=device, dtype=dtype) for _ in range(num_chunks)] + + grad_global_attn_output = [None for _ in range(num_chunks)] + + q_compute_chunk_idx = 0 + kv_compute_chunk_idx = 0 + last_q_accum_idx = 0 + + with torch.cuda.stream(general_offload_stream): + layernorm_output[0].load_to_gpu() + grad_qkv_linear_weight = torch.zeros(qkv_linear_weight.shape, device=qkv_linear_weight.device, dtype=torch.float) + grad_qkv_linear_bias = torch.zeros(qkv_linear_bias.shape, device=qkv_linear_weight.device, dtype=torch.float) + + grad_global_attn_output_chunk = single_all_to_all(grad_output[:, :chunk_size].contiguous(), scatter_idx, gather_idx, spg) + torch.cuda.synchronize() + grad_output = grad_output[:, chunk_size:] + + with torch.cuda.stream(offload_stream): + grad_global_attn_output[0] = SequenceChunk(grad_global_attn_output_chunk, is_in_use=True) + dq = [SequenceChunk(torch.zeros(global_q[0].chunk_shape, dtype=torch.float, device=device), is_in_use=True)] + [SequenceChunk(torch.zeros(global_q[0].chunk_shape, dtype=torch.float, 
device='cpu', pin_memory=True), device) for _ in range(num_chunks - 1)] + dk_accum = torch.zeros(global_k[0].chunk_shape, dtype=torch.float, device=device) + dv_accum = torch.zeros(global_v[0].chunk_shape, dtype=torch.float, device=device) + + for i in range(num_chunks): + for q_i in range(num_chunks): + no_computation = q_i < i + if no_computation: + continue + + causal_chunk = q_i == i + + dq_this = torch.zeros(global_q[0].chunk_shape, dtype=dtype, device=device) + dk_this = torch.zeros(global_k[0].chunk_shape, dtype=dtype, device=device) + dv_this = torch.zeros(global_v[0].chunk_shape, dtype=dtype, device=device) + + with torch.cuda.stream(compute_stream): + _flash_attn_backward( + grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk(), + global_q[q_compute_chunk_idx].get_gpu_chunk(), + global_k[kv_compute_chunk_idx].get_gpu_chunk(), + global_v[kv_compute_chunk_idx].get_gpu_chunk(), + attn_output[q_compute_chunk_idx].get_gpu_chunk(), + lse[q_compute_chunk_idx].get_gpu_chunk(), + dq_this, + dk_this, + dv_this, + dropout_p, + softmax_scale, + causal_chunk, + window_size, + softcap=0.0, + alibi_slopes=alibi_slopes, + deterministic=False, + rng_state=None + ) + + if i != (len(global_k) - 1): + if q_i != (len(global_q) - 1): + next_q_compute_chunk_idx = q_i + 1 + else: + next_q_compute_chunk_idx = i + 1 + + can_offload_q = True + + if next_q_compute_chunk_idx == q_compute_chunk_idx: + can_offload_q = False + else: + with torch.cuda.stream(offload_stream): + if i > 0 or q_i > 0: + if can_offload_q and last_q_accum_idx != i: # the first q chunk calculate in the loop will be sent out, therefore we do not offload it + dq[last_q_accum_idx].offload() + dq[next_q_compute_chunk_idx].load_to_gpu() + global_q[next_q_compute_chunk_idx].load_to_gpu() + attn_output[next_q_compute_chunk_idx].load_to_gpu() + lse[next_q_compute_chunk_idx].load_to_gpu() + if grad_global_attn_output[next_q_compute_chunk_idx] is not None: + grad_global_attn_output[next_q_compute_chunk_idx].load_to_gpu() + + if grad_global_attn_output[next_q_compute_chunk_idx] is None: + grad_global_attn_output_chunk = single_all_to_all(grad_output[:, :chunk_size].contiguous(), scatter_idx, gather_idx, spg) + torch.distributed.barrier() + grad_output = grad_output[:, chunk_size:] + grad_global_attn_output[next_q_compute_chunk_idx] = SequenceChunk(grad_global_attn_output_chunk, is_in_use=True) + + compute_stream.wait_stream(offload_stream) + compute_stream.synchronize() + + with torch.cuda.stream(compute_stream): + dq[q_compute_chunk_idx].check_gpu_chunk() + dq[q_compute_chunk_idx].gpu_chunk.add_(dq_this) + dk_accum.add_(dk_this) + dv_accum.add_(dv_this) + + offload_stream.wait_stream(compute_stream) + with torch.cuda.stream(offload_stream): + dq[q_compute_chunk_idx].overwrite_to_cpu() + + if can_offload_q: + global_q[q_compute_chunk_idx].offload() + attn_output[q_compute_chunk_idx].offload() + lse[q_compute_chunk_idx].offload() + grad_global_attn_output[q_compute_chunk_idx].offload() + + last_q_accum_idx = q_compute_chunk_idx + q_compute_chunk_idx = next_q_compute_chunk_idx + + compute_stream.wait_stream(offload_stream) + compute_stream.synchronize() + + dk_seq_len = dk_accum.shape[1] + dq_accum = apply_rotary_pos_emb_backward(dq[kv_compute_chunk_idx].get_gpu_chunk().to(dtype), ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) + dk_accum = apply_rotary_pos_emb_backward(dk_accum.to(dtype), ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], ctx.pos_emb_sin[:, dk_seq_len * 
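# --- Illustrative sketch (not part of the patch) ---------------------------------
# Why the chunked loops above can skip query/key chunk pairs outright and only
# pass causal=True for the diagonal pair: restricted to (query chunk i, key
# chunk k), a causal mask is fully visible for k < i, lower-triangular for
# k == i, and empty for k > i -- the `q_i < i` skip and `causal_chunk = i == k_i`
# logic. Toy sizes only.
import torch

seq_len, chunk = 8, 4
full_causal = torch.ones(seq_len, seq_len).tril().bool()

for i in range(seq_len // chunk):            # query chunk
    for k in range(seq_len // chunk):        # key chunk
        block = full_causal[i * chunk:(i + 1) * chunk, k * chunk:(k + 1) * chunk]
        if k < i:
            assert block.all()                                        # attend with no mask
        elif k == i:
            assert torch.equal(block, torch.ones(chunk, chunk).tril().bool())
        else:
            assert not block.any()                                    # no work: pair is skipped
# ----------------------------------------------------------------------------------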
i:dk_seq_len * (i + 1)]) + dv_accum = dv_accum.to(dtype) + + dq_accum = single_all_to_all(dq_accum.contiguous(), gather_idx, scatter_idx, spg) + dk_accum = single_all_to_all(dk_accum.contiguous(), gather_idx, scatter_idx, spg) + dv_accum = single_all_to_all(dv_accum.contiguous(), gather_idx, scatter_idx, spg) + + general_offload_stream.synchronize() + compute_stream.wait_stream(general_offload_stream) + torch.distributed.barrier() # torch.cuda.synchronize() + + with torch.cuda.stream(compute_stream): + input_chunk = layernorm_output[i].get_gpu_chunk().reshape(-1, layernorm_output[i].chunk_shape[-1]) + + dq_accum = dq_accum.flatten(2).permute(1, 0, 2) + dk_accum = dk_accum.flatten(2).permute(1, 0, 2) + dv_accum = dv_accum.flatten(2).permute(1, 0, 2) + + l, b = dk_accum.shape[0], dk_accum.shape[1] + + grad_qkv_linear_weight[:projection_size].add_(torch.matmul(dq_accum.reshape(l*b, -1).t(), input_chunk)) + grad_qkv_linear_weight[projection_size:projection_size+kv_projection_size].add_(torch.matmul(dk_accum.reshape(l*b, -1).t(), input_chunk)) + grad_qkv_linear_weight[projection_size+kv_projection_size:].add_(torch.matmul(dv_accum.reshape(l*b, -1).t(), input_chunk)) + + grad_qkv_linear_bias[:projection_size].add_(dq_accum.sum(0).sum(0)) + grad_qkv_linear_bias[projection_size:projection_size+kv_projection_size].add_(dk_accum.sum(0).sum(0)) + grad_qkv_linear_bias[projection_size+kv_projection_size:].add_(dv_accum.sum(0).sum(0)) + + grad_layernorm_output[i].add_(torch.matmul(dq_accum, qkv_linear_weight[:projection_size])) + grad_layernorm_output[i].add_(torch.matmul(dk_accum, qkv_linear_weight[projection_size:projection_size+kv_projection_size])) + grad_layernorm_output[i].add_(torch.matmul(dv_accum, qkv_linear_weight[projection_size+kv_projection_size:])) + + del dq_accum, dk_accum, dv_accum + dk_accum = torch.zeros(global_k[i].chunk_shape, dtype=torch.float, device=device) + dv_accum = torch.zeros(global_v[i].chunk_shape, dtype=torch.float, device=device) + dq[kv_compute_chunk_idx].offload() + dq[kv_compute_chunk_idx] = None + + if i != (len(global_k) - 1): + next_kv_compute_chunk_idx = kv_compute_chunk_idx + 1 + with torch.cuda.stream(offload_stream): + global_k[next_kv_compute_chunk_idx].load_to_gpu() + global_v[next_kv_compute_chunk_idx].load_to_gpu() + + with torch.cuda.stream(general_offload_stream): + layernorm_output[next_kv_compute_chunk_idx].load_to_gpu() + + compute_stream.wait_stream(offload_stream) + compute_stream.synchronize() + + layernorm_output[kv_compute_chunk_idx].offload() + global_k[kv_compute_chunk_idx].offload() + global_v[kv_compute_chunk_idx].offload() + kv_compute_chunk_idx = next_kv_compute_chunk_idx + + return torch.cat(grad_layernorm_output, dim=0).to(dtype), None, None, None, None, None, None, None, None, None, None, grad_qkv_linear_weight.to(dtype), grad_qkv_linear_bias.to(dtype), None, None, None + + + +class FPDT_Attention(torch.nn.Module): + def __init__( + self, + config, + first_weight, + first_bias, + second_weight, + second_bias, + sequence_process_group, + gather_idx: int = 0, + scatter_idx: int = 2, + return_bias=True, + chunk_size=65536, + enable_offloading=True + ) -> None: + + super(FPDT_Attention, self).__init__() + self.spg = sequence_process_group + self.scatter_idx = scatter_idx + self.gather_idx = gather_idx + self.config = config + + self.projection_size = config.kv_channels * config.num_attention_heads + self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads + self.kv_projection_size = config.kv_channels * 
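# --- Illustrative sketch (not part of the patch) ---------------------------------
# Why accumulating grad_w += dY_chunk^T @ X_chunk and grad_b += dY_chunk.sum()
# over sequence chunks (the grad_qkv_linear_weight / grad_qkv_linear_bias updates
# above) reproduces the gradient of one large projection Y = X @ W^T + b.
# Toy shapes; all names here are hypothetical.
import torch

torch.manual_seed(0)
seq, hidden, out_dim, num_chunks = 8, 4, 6, 2
x = torch.randn(seq, hidden)
w = torch.randn(out_dim, hidden, requires_grad=True)
b = torch.randn(out_dim, requires_grad=True)
dy = torch.randn(seq, out_dim)

# Reference gradients from a single, un-chunked projection.
(x @ w.t() + b).backward(dy)

# Chunked accumulation, mirroring the per-chunk updates in the backward above.
grad_w, grad_b = torch.zeros_like(w), torch.zeros_like(b)
chunk = seq // num_chunks
for i in range(num_chunks):
    dy_c, x_c = dy[i * chunk:(i + 1) * chunk], x[i * chunk:(i + 1) * chunk]
    grad_w.add_(dy_c.t() @ x_c)
    grad_b.add_(dy_c.sum(0))

assert torch.allclose(grad_w, w.grad, atol=1e-5)
assert torch.allclose(grad_b, b.grad, atol=1e-5)
# ----------------------------------------------------------------------------------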
config.num_key_value_heads + self.hidden_size = config.hidden_size + + self.qkv_linear_weight = first_weight + self.qkv_linear_bias = first_bias + self.qkv_dense_weight = second_weight + self.qkv_dense_bias = second_bias + + self.reture_bias = return_bias + self.dropout = config.attention_dropout + + self.chunk_size = chunk_size + self.double_buffer = enable_offloading + + def forward(self, + layernorm_output, + attention_mask, + inference_params, + rotary_pos_emb, + cpu_offloading=True) -> Tensor: + """ forward + + Arguments: + query (Tensor): query input to the layer + key (Tensor): key input to the layer + value (Tensor): value input to the layer + args: other args + + Returns: + * output (Tensor): context output + """ + self.num_chunks_attn = layernorm_output.shape[0] * dist.get_world_size(self.spg) // self.chunk_size + + if not cpu_offloading: + output = _FPDTGPUAttentionImpl_.apply( + layernorm_output, + attention_mask, + inference_params, + rotary_pos_emb, + self.spg, + self.scatter_idx, + self.gather_idx, + self.hidden_size, + self.projection_size, + self.hidden_size_per_attention_head, + self.kv_projection_size, + self.qkv_linear_weight, + self.qkv_linear_bias, + self.dropout, + self.num_chunks_attn, cpu_offloading) + else: + output = _FPDTGPUOffloadingAttentionImpl_.apply( + layernorm_output, + attention_mask, + inference_params, + rotary_pos_emb, + self.spg, + self.scatter_idx, + self.gather_idx, + self.hidden_size, + self.projection_size, + self.hidden_size_per_attention_head, + self.kv_projection_size, + self.qkv_linear_weight, + self.qkv_linear_bias, + self.dropout, + self.num_chunks_attn, cpu_offloading) + + output = output.flatten(2).permute(1, 0, 2).contiguous() + + output = torch.matmul(output, self.qkv_dense_weight.t()) + if not self.reture_bias: + output += self.qkv_dense_bias + return output, self.qkv_dense_bias if self.reture_bias else None + + +@torch.jit.script +def bias_gelu(x): + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + +@torch.jit.script +def bias_gelu_back(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + return ff*g + + +class FPDT_FFN(torch.autograd.Function): + generate_vmap_rule = False + + @staticmethod + def forward(ctx: Any, x, w1, b1, w2, b2, add_bias, chunk_size): + do_save = x.requires_grad + ctx.add_bias = add_bias + device = x.device + + with torch.no_grad(): + num_chunk = x.shape[0] // chunk_size + ctx.num_chunk = num_chunk + result = torch.empty(x.shape, device=device, dtype=x.dtype) + assert chunk_size * num_chunk == x.shape[0] + for i in range(num_chunk): + st = i * chunk_size + ed = st + chunk_size + x_ = torch.matmul(x[st:ed], w1.t()) + b1 + x_ = bias_gelu(x_) + if add_bias: + result[st:ed] = torch.matmul(x_, w2.t()) + b2 + else: + result[st:ed] = torch.matmul(x_, w2.t()) + + del x_ + + if do_save: + ctx.device = device + ctx.dtype = x.dtype + ctx.save_for_backward(x, w1, b1, w2, b2) + ctx.grad_x_shape = x.shape + return result.to(x.dtype), b2 if not add_bias else None + + @staticmethod + def backward(ctx, grad_output, grad_bias): + x, w1, b1, w2, b2 = ctx.saved_tensors + device = ctx.device + dtype = ctx.dtype + add_bias = ctx.add_bias + + num_chunk = ctx.num_chunk + chunk_size = x.shape[0] // num_chunk + assert chunk_size * num_chunk == grad_output.shape[0] + + grad_w2 = torch.zeros(w2.shape, device=device, dtype=torch.float) + 
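# --- Illustrative sketch (not part of the patch) ---------------------------------
# Checking the closed-form derivative used by bias_gelu_back above against
# autograd for the tanh GeLU approximation, where 0.79788456 ~ sqrt(2/pi) and
# 0.1070322243 ~ sqrt(2/pi) * 3 * 0.044715. Function names here are hypothetical.
import torch

def gelu_tanh(x):
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

def gelu_tanh_grad(g, x):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff * g

x = torch.randn(64, dtype=torch.double, requires_grad=True)
g = torch.randn(64, dtype=torch.double)
gelu_tanh(x).backward(g)
assert torch.allclose(gelu_tanh_grad(g, x.detach()), x.grad, atol=1e-6)
# ----------------------------------------------------------------------------------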
grad_b2 = torch.zeros(b2.shape, device=device, dtype=torch.float) + grad_w1 = torch.zeros(w1.shape, device=device, dtype=torch.float) + grad_b1 = torch.zeros(b1.shape, device=device, dtype=torch.float) + + for i in range(num_chunk): + st = i * chunk_size + ed = st + chunk_size + x_chunk = x[st:ed] + + before_act = (torch.matmul(x_chunk, w1.t()) + b1) + before_act_2 = before_act ** 2 + tanh_out = torch.tanh(0.79788456 * before_act * (1 + 0.044715 * before_act_2)) + ff = 0.5 * before_act * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * before_act_2)) + 0.5 * (1 + tanh_out) + grad_w2.add_(torch.matmul(grad_output[st:ed].reshape(-1, grad_output.shape[2]).t(), (before_act * 0.5 * (1 + tanh_out)).reshape(-1, before_act.shape[2]))) + del before_act, before_act_2, tanh_out + + grad_inter = torch.matmul(grad_output[st:ed], w2) * ff + del ff + + grad_w1.add_(torch.matmul(grad_inter.reshape(-1, grad_inter.shape[2]).t(), x_chunk.reshape(-1, x.shape[2]))) + grad_b1.add_(grad_inter.sum(0).sum(0)) + + x[st:ed].copy_(torch.matmul(grad_inter, w1)) + + del grad_inter + + if add_bias: + grad_b2.add_(grad_output[st:ed].sum(0).sum(0)) + + return x, grad_w1.to(dtype), grad_b1.to(dtype), grad_w2.to(dtype), grad_b2.to(dtype), None, None + + +class FPDT_LogitsLoss(torch.autograd.Function): + generate_vmap_rule = False + + @staticmethod + def forward(ctx: Any, lm_output, labels, logit_weights, rank, spg_size, spg, num_chunk): + labels = labels.t() + chunk_size = lm_output.shape[0] // num_chunk + assert chunk_size * num_chunk == lm_output.shape[0] + batch_size, local_seq_len = lm_output.shape[1], lm_output.shape[0] + loss = torch.empty((batch_size, local_seq_len), dtype=torch.float, device=lm_output.device) + + ctx.num_chunk = num_chunk + ctx.chunk_size = chunk_size + ctx.device = lm_output.device + ctx.dtype = lm_output.dtype + + ctx.rank = rank + ctx.local_seq_len = local_seq_len + with torch.no_grad(): + for i in range(num_chunk): + st = i * chunk_size + ed = st + chunk_size + logits_chunk = torch.matmul(lm_output[st:ed], logit_weights.t()).float() + + vocab_size = logits_chunk.size(2) + # nll + softmax = torch.nn.functional.softmax(logits_chunk, dim=-1) + loss_chunk = torch.nn.functional.nll_loss(softmax.log().reshape(-1, vocab_size).contiguous(), labels[st:ed, :].reshape(-1).contiguous(), reduction='none') + loss[:, st:ed] = loss_chunk.reshape(chunk_size, batch_size).t() + + del logits_chunk + ctx.save_for_backward(lm_output.to('cpu'), labels) + ctx.logit_weights = logit_weights + + seqlen = local_seq_len * spg_size + batch_size = loss.size(0) + loss = loss.t().contiguous() + loss_all = torch.empty(seqlen, batch_size, dtype=loss.dtype, device=loss.device).contiguous() + + if version.parse(torch.__version__) >= version.parse('1.13'): + torch.distributed.all_gather_into_tensor(loss_all, loss, group=spg) + else: + torch.distributed._all_gather_base(loss_all, loss, group=spg) + + return loss_all + + @staticmethod + def backward(ctx, grad_output): + lm_output, labels = ctx.saved_tensors + logit_weights = ctx.logit_weights + device = ctx.device + dtype = ctx.dtype + num_chunk = ctx.num_chunk + chunk_size = ctx.chunk_size + + rank = ctx.rank + local_seq_len = ctx.local_seq_len + + grad_output = grad_output[rank*local_seq_len:(rank+1)*local_seq_len] + grad_lm_output = [None for _ in range(num_chunk)] + grad_logit_weights = torch.zeros(logit_weights.shape, device=grad_output.device, dtype=torch.float) + for i in range(num_chunk): + st = i * chunk_size + ed = st + chunk_size + lm_output_chunk = 
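# --- Illustrative sketch (not part of the patch) ---------------------------------
# The identity the chunked logits backward relies on:
# d NLL(log_softmax(z), y) / dz = softmax(z) - one_hot(y), which is why the
# backward below only subtracts 1 at the label column of the softmax before
# scaling by the upstream loss gradient. Toy vocabulary and sequence sizes.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
tokens, vocab = 5, 11
logits = torch.randn(tokens, vocab, dtype=torch.double, requires_grad=True)
labels = torch.randint(vocab, (tokens,))

F.nll_loss(F.log_softmax(logits, dim=-1), labels, reduction='sum').backward()

manual = F.softmax(logits.detach(), dim=-1)
manual[torch.arange(tokens), labels] -= 1
assert torch.allclose(manual, logits.grad, atol=1e-12)
# ----------------------------------------------------------------------------------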
lm_output[st:ed].to(device) + logits_chunk = torch.matmul(lm_output_chunk, logit_weights.t()).float() + + # nll + softmax = torch.nn.functional.softmax(logits_chunk, dim=-1) + vocab_size = logits_chunk.size(2) + + grad_input = softmax + grad_2d = grad_input.reshape(-1, vocab_size).contiguous() + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], + device=device) + + grad_2d[arange_1d, labels[st:ed, :].reshape(-1).contiguous()] -= 1 + grad_input.mul_(grad_output[:chunk_size, :].unsqueeze(dim=-1)) + grad_input = grad_input.to(dtype) + + grad_output = grad_output[chunk_size:].contiguous() + + grad_lm_output_chunk = torch.matmul(grad_input, logit_weights) + grad_lm_output[i] = grad_lm_output_chunk + + grad_logit_weights.add_(torch.matmul(grad_input.reshape(-1, grad_input.shape[2]).t(), lm_output_chunk.reshape(-1, lm_output_chunk.shape[2]))) + + return torch.cat(grad_lm_output, dim=0).to(dtype), None, grad_logit_weights.to(dtype), None, None, None, None \ No newline at end of file diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index e809fe1118b5..b55205fc7e33 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -10,6 +10,34 @@ import deepspeed.comm as dist from deepspeed.accelerator import get_accelerator +from einops import rearrange + + +def _rotate_half(x): + """ + change sign so the last dimension becomes [-odd, +even] + """ + x = rearrange(x, '... (j d) -> ... j d', j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(t, freqs_cos, freqs_sin): + """ + input tensor t is of shape [seq_length, ..., dim] + rotary positional embeding tensor freqs is of shape [seq_length, ..., dim] + check https://kexue.fm/archives/8265 for detailed formulas + """ + rot_dim = freqs_cos.shape[-1] + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + t = (t * freqs_cos) + (_rotate_half(t) * freqs_sin) + + res = t if t_pass.shape[-1] == 0 else torch.cat((t, t_pass), dim=-1) + return res def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim): @@ -168,7 +196,8 @@ def __init__( self.gather_idx = gather_idx self.sp_overlap_comm = False self.overlap_handles = None - self.sp_stream = sp_stream + self.sp_stream = None + # self.sp_stream = sp_stream if sp_stream is not None: self.overlap_handles = {} self.sp_overlap_comm = True @@ -178,7 +207,7 @@ def layer_sync(self, layer): if self.sp_overlap_comm and hasattr(layer, 'done_event'): self.dafult_stream.wait_event(layer.done_event) - def forward(self, query: Tensor, key: Tensor, value: Tensor, batch_dim_idx: int, *args: Any, **kwargs) -> Tensor: + def forward(self, query: Tensor, key: Tensor, value: Tensor, batch_dim_idx: int, rotary_pos_emb=None, *args: Any, **kwargs) -> Tensor: """ forward Arguments: @@ -233,6 +262,10 @@ def pre_hook_fun(grad): grad_fn_k.register_prehook(bwd_hook(layer_type='k')) #out shape : e.g., [s:h/p:] + if rotary_pos_emb is not None: + pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3) + query_layer = apply_rotary_pos_emb(query_layer, pos_emb_cos, pos_emb_sin) + key_layer = apply_rotary_pos_emb(key_layer, pos_emb_cos, pos_emb_sin) context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs) From 69f38928fb3cbc0fd94830b9510d7da28c004ce3 Mon Sep 
17 00:00:00 2001 From: Jinghan Yao Date: Tue, 24 Sep 2024 13:50:01 -0400 Subject: [PATCH 05/54] modify streams --- deepspeed/sequence/fpdt_layer.py | 39 +++++++++++++++----------------- deepspeed/sequence/layer.py | 3 +-- 2 files changed, 19 insertions(+), 23 deletions(-) diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py index 7c88468fc947..133a18e381da 100644 --- a/deepspeed/sequence/fpdt_layer.py +++ b/deepspeed/sequence/fpdt_layer.py @@ -64,7 +64,6 @@ def apply_rotary_pos_emb_backward(grad_output, freqs_cos, freqs_sin): return grad -# @torch.jit.script def _update_out_and_lse( out: torch.Tensor, lse: torch.Tensor, @@ -568,7 +567,7 @@ def forward(ctx: Any, qkv_chunk = torch.matmul(layernorm_output[:chunk_size], qkv_linear_weight.t()) + qkv_linear_bias # torch.Size([18126, 1, 12288]) - with torch.cuda.stream(general_offload_stream): + with get_accelerator().stream(general_offload_stream): layernorm_output_cpu.append(SequenceChunk(layernorm_output[:chunk_size])) layernorm_output = layernorm_output[chunk_size:] @@ -583,7 +582,7 @@ def forward(ctx: Any, v_chunk = qkv_chunk[:, :, projection_size+kv_projection_size:].contiguous().reshape(qkv_chunk.shape[0], qkv_chunk.shape[1], -1, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd v_chunk = single_all_to_all(v_chunk, scatter_idx, gather_idx, spg) - torch.distributed.barrier() # torch.cuda.synchronize() + torch.distributed.barrier() # get_accelerator().synchronize() pos_emb_cos_chunk = pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)] pos_emb_sin_chunk = pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)] @@ -593,7 +592,7 @@ def forward(ctx: Any, compute_stream.wait_stream(offload_stream) compute_stream.synchronize() - with torch.cuda.stream(offload_stream): + with get_accelerator().stream(offload_stream): global_q.append(SequenceChunk(q_chunk, is_in_use=True)) global_k.append(SequenceChunk(k_chunk, is_in_use=True)) global_v.append(SequenceChunk(v_chunk, is_in_use=True)) @@ -604,7 +603,7 @@ def forward(ctx: Any, cur_attn_lse = None for k_i in range(len(global_k)): causal_chunk = i == k_i - with torch.cuda.stream(compute_stream): + with get_accelerator().stream(compute_stream): block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( global_q[q_compute_chunk_idx].get_gpu_chunk(), global_k[kv_compute_chunk_idx].get_gpu_chunk(), @@ -630,12 +629,12 @@ def forward(ctx: Any, can_offload_kv = False else: if next_kv_compute_chunk_idx != (len(global_k) - 1): - with torch.cuda.stream(offload_stream): + with get_accelerator().stream(offload_stream): global_k[next_kv_compute_chunk_idx].load_to_gpu() global_v[next_kv_compute_chunk_idx].load_to_gpu() if i == num_chunks - 1 and k_i == num_chunks - 1: - with torch.cuda.stream(offload_stream): + with get_accelerator().stream(offload_stream): global_q[0].load_to_gpu() global_k[0].load_to_gpu() global_v[0].load_to_gpu() @@ -655,7 +654,7 @@ def forward(ctx: Any, all2all_output = single_all_to_all(cur_attn_output.to(ctx.dtype).contiguous(), gather_idx, scatter_idx, spg) final_output.append(all2all_output) - with torch.cuda.stream(general_offload_stream): + with get_accelerator().stream(general_offload_stream): global_o.append(SequenceChunk(cur_attn_output.to(ctx.dtype))) global_lse.append(SequenceChunk(cur_attn_lse[:, :, :, 0].permute(0, 2, 1).contiguous())) @@ -724,16 +723,16 @@ def backward(ctx, grad_output): kv_compute_chunk_idx = 0 last_q_accum_idx = 0 - with torch.cuda.stream(general_offload_stream): + with 
get_accelerator().stream(general_offload_stream): layernorm_output[0].load_to_gpu() grad_qkv_linear_weight = torch.zeros(qkv_linear_weight.shape, device=qkv_linear_weight.device, dtype=torch.float) grad_qkv_linear_bias = torch.zeros(qkv_linear_bias.shape, device=qkv_linear_weight.device, dtype=torch.float) grad_global_attn_output_chunk = single_all_to_all(grad_output[:, :chunk_size].contiguous(), scatter_idx, gather_idx, spg) - torch.cuda.synchronize() + get_accelerator().synchronize() grad_output = grad_output[:, chunk_size:] - with torch.cuda.stream(offload_stream): + with get_accelerator().stream(offload_stream): grad_global_attn_output[0] = SequenceChunk(grad_global_attn_output_chunk, is_in_use=True) dq = [SequenceChunk(torch.zeros(global_q[0].chunk_shape, dtype=torch.float, device=device), is_in_use=True)] + [SequenceChunk(torch.zeros(global_q[0].chunk_shape, dtype=torch.float, device='cpu', pin_memory=True), device) for _ in range(num_chunks - 1)] dk_accum = torch.zeros(global_k[0].chunk_shape, dtype=torch.float, device=device) @@ -751,7 +750,7 @@ def backward(ctx, grad_output): dk_this = torch.zeros(global_k[0].chunk_shape, dtype=dtype, device=device) dv_this = torch.zeros(global_v[0].chunk_shape, dtype=dtype, device=device) - with torch.cuda.stream(compute_stream): + with get_accelerator().stream(compute_stream): _flash_attn_backward( grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk(), global_q[q_compute_chunk_idx].get_gpu_chunk(), @@ -783,7 +782,7 @@ def backward(ctx, grad_output): if next_q_compute_chunk_idx == q_compute_chunk_idx: can_offload_q = False else: - with torch.cuda.stream(offload_stream): + with get_accelerator().stream(offload_stream): if i > 0 or q_i > 0: if can_offload_q and last_q_accum_idx != i: # the first q chunk calculate in the loop will be sent out, therefore we do not offload it dq[last_q_accum_idx].offload() @@ -803,14 +802,14 @@ def backward(ctx, grad_output): compute_stream.wait_stream(offload_stream) compute_stream.synchronize() - with torch.cuda.stream(compute_stream): + with get_accelerator().stream(compute_stream): dq[q_compute_chunk_idx].check_gpu_chunk() dq[q_compute_chunk_idx].gpu_chunk.add_(dq_this) dk_accum.add_(dk_this) dv_accum.add_(dv_this) offload_stream.wait_stream(compute_stream) - with torch.cuda.stream(offload_stream): + with get_accelerator().stream(offload_stream): dq[q_compute_chunk_idx].overwrite_to_cpu() if can_offload_q: @@ -836,9 +835,9 @@ def backward(ctx, grad_output): general_offload_stream.synchronize() compute_stream.wait_stream(general_offload_stream) - torch.distributed.barrier() # torch.cuda.synchronize() + torch.distributed.barrier() # get_accelerator().synchronize() - with torch.cuda.stream(compute_stream): + with get_accelerator().stream(compute_stream): input_chunk = layernorm_output[i].get_gpu_chunk().reshape(-1, layernorm_output[i].chunk_shape[-1]) dq_accum = dq_accum.flatten(2).permute(1, 0, 2) @@ -867,11 +866,11 @@ def backward(ctx, grad_output): if i != (len(global_k) - 1): next_kv_compute_chunk_idx = kv_compute_chunk_idx + 1 - with torch.cuda.stream(offload_stream): + with get_accelerator().stream(offload_stream): global_k[next_kv_compute_chunk_idx].load_to_gpu() global_v[next_kv_compute_chunk_idx].load_to_gpu() - with torch.cuda.stream(general_offload_stream): + with get_accelerator().stream(general_offload_stream): layernorm_output[next_kv_compute_chunk_idx].load_to_gpu() compute_stream.wait_stream(offload_stream) @@ -986,11 +985,9 @@ def forward(self, return output, self.qkv_dense_bias if 
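# --- Illustrative sketch (not part of the patch) ---------------------------------
# The device-agnostic stream pattern this "modify streams" patch switches to:
# DeepSpeed's get_accelerator() replaces direct torch.cuda calls so the same
# offloading code can target other accelerators. The matmul is only a placeholder
# workload for the side stream.
import torch
from deepspeed.accelerator import get_accelerator

side_stream = get_accelerator().Stream()
x = torch.randn(1024, 1024, device=get_accelerator().current_device_name())

with get_accelerator().stream(side_stream):
    y = torch.matmul(x, x)           # queued on the side stream

# Make the default (compute) stream wait until the side stream's work is visible.
get_accelerator().default_stream().wait_stream(side_stream)
z = y + 1                            # safe to consume y on the default stream now
# ----------------------------------------------------------------------------------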
self.reture_bias else None -@torch.jit.script def bias_gelu(x): return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) -@torch.jit.script def bias_gelu_back(g, x): tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index b55205fc7e33..ff52f616b934 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -196,8 +196,7 @@ def __init__( self.gather_idx = gather_idx self.sp_overlap_comm = False self.overlap_handles = None - self.sp_stream = None - # self.sp_stream = sp_stream + self.sp_stream = sp_stream if sp_stream is not None: self.overlap_handles = {} self.sp_overlap_comm = True From 8ef9f5aafded30437bbaa67bca5d05af7bff3915 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Tue, 24 Sep 2024 13:52:34 -0400 Subject: [PATCH 06/54] modify streams --- deepspeed/sequence/fpdt_layer.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py index 133a18e381da..89c9bd26a37b 100644 --- a/deepspeed/sequence/fpdt_layer.py +++ b/deepspeed/sequence/fpdt_layer.py @@ -21,29 +21,18 @@ def _rotate_half(x): - """ - change sign so the last dimension becomes [-odd, +even] - """ x = rearrange(x, '... (j d) -> ... j d', j=2) x1, x2 = x.unbind(dim=-2) return torch.cat((-x2, x1), dim=-1) def _rotate_half_backward(x): - """ - change sign so the last dimension becomes [-odd, +even] - """ x = rearrange(x, '... (j d) -> ... j d', j=2) x1, x2 = x.unbind(dim=-2) return torch.cat((x2, -x1), dim=-1) def apply_rotary_pos_emb(t, freqs_cos, freqs_sin): - """ - input tensor t is of shape [seq_length, ..., dim] - rotary positional embeding tensor freqs is of shape [seq_length, ..., dim] - check https://kexue.fm/archives/8265 for detailed formulas - """ rot_dim = freqs_cos.shape[-1] # ideally t_pass is empty so rotary pos embedding is applied to all tensor t t, t_pass = t[..., :rot_dim], t[..., rot_dim:] @@ -929,17 +918,6 @@ def forward(self, inference_params, rotary_pos_emb, cpu_offloading=True) -> Tensor: - """ forward - - Arguments: - query (Tensor): query input to the layer - key (Tensor): key input to the layer - value (Tensor): value input to the layer - args: other args - - Returns: - * output (Tensor): context output - """ self.num_chunks_attn = layernorm_output.shape[0] * dist.get_world_size(self.spg) // self.chunk_size if not cpu_offloading: From a55d1f5d89c63e15c2a72ba282a27048a9740584 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Sun, 6 Oct 2024 20:57:07 -0400 Subject: [PATCH 07/54] remove duplication of alltoall --- deepspeed/sequence/fpdt_layer.py | 89 ++++++++++++++++---------------- 1 file changed, 44 insertions(+), 45 deletions(-) diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py index 89c9bd26a37b..d4e45e8a64a2 100644 --- a/deepspeed/sequence/fpdt_layer.py +++ b/deepspeed/sequence/fpdt_layer.py @@ -18,21 +18,33 @@ from packaging import version from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward from einops import rearrange +from .layer import single_all_to_all def _rotate_half(x): + """ + change sign so the last dimension becomes [-odd, +even] + """ x = rearrange(x, '... (j d) -> ... j d', j=2) x1, x2 = x.unbind(dim=-2) return torch.cat((-x2, x1), dim=-1) def _rotate_half_backward(x): + """ + change sign so the last dimension becomes [-odd, +even] + """ x = rearrange(x, '... (j d) -> ... 
j d', j=2) x1, x2 = x.unbind(dim=-2) return torch.cat((x2, -x1), dim=-1) def apply_rotary_pos_emb(t, freqs_cos, freqs_sin): + """ + input tensor t is of shape [seq_length, ..., dim] + rotary positional embeding tensor freqs is of shape [seq_length, ..., dim] + check https://kexue.fm/archives/8265 for detailed formulas + """ rot_dim = freqs_cos.shape[-1] # ideally t_pass is empty so rotary pos embedding is applied to all tensor t t, t_pass = t[..., :rot_dim], t[..., rot_dim:] @@ -53,6 +65,7 @@ def apply_rotary_pos_emb_backward(grad_output, freqs_cos, freqs_sin): return grad +# @torch.jit.script def _update_out_and_lse( out: torch.Tensor, lse: torch.Tensor, @@ -94,33 +107,6 @@ def update_out_and_lse( return out, lse -def single_all_to_all(input_, scatter_idx, gather_idx, group): - seq_world_size = dist.get_world_size(group) - if scatter_idx < 2: - bs, global_seq_len, local_head, head_dim = input_.shape - input_t = input_.reshape( - [bs, seq_world_size, global_seq_len // seq_world_size, local_head, head_dim] - ).contiguous() - input_t = input_t.permute(1, 0, 2, 3, 4).contiguous() - else: - bs, local_seq_len, total_head, head_dim = input_.shape - input_t = input_.reshape( - [bs, local_seq_len, seq_world_size, total_head // seq_world_size, head_dim] - ).contiguous() - input_t = input_t.permute(2, 0, 1, 3, 4).contiguous() - - output = torch.empty_like(input_t) - dist.all_to_all_single(output, input_t, group=group) - - if scatter_idx < 2: - output = output.permute(1, 2, 0, 3, 4).contiguous() - output = output.reshape(bs, global_seq_len // seq_world_size, seq_world_size * local_head, head_dim).contiguous() - else: - output = output.permute(1, 0, 2, 3, 4).contiguous() - output = output.reshape(bs, seq_world_size * local_seq_len, total_head // seq_world_size, head_dim).contiguous() - return output - - class FPDT_InputConstruct(torch.nn.Module): def __init__( self, @@ -243,18 +229,18 @@ def forward(ctx: Any, qkv_chunk = torch.matmul(layernorm_output[st:ed], qkv_linear_weight.t()) + qkv_linear_bias q_chunk = qkv_chunk[:, :, :projection_size].contiguous().reshape(qkv_chunk.shape[0], qkv_chunk.shape[1], -1, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd - q_chunk = single_all_to_all(q_chunk, scatter_idx, gather_idx, spg) + q_chunk = single_all_to_all(q_chunk, scatter_idx, gather_idx, 0, spg) global_q_chunk_len = q_chunk.shape[1] q_chunk = apply_rotary_pos_emb(q_chunk, pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) global_q.append(q_chunk) k_chunk = qkv_chunk[:, :, projection_size:projection_size+kv_projection_size].contiguous().reshape(qkv_chunk.shape[0], qkv_chunk.shape[1], -1, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd - k_chunk = single_all_to_all(k_chunk, scatter_idx, gather_idx, spg) + k_chunk = single_all_to_all(k_chunk, scatter_idx, gather_idx, 0, spg) k_chunk = apply_rotary_pos_emb(k_chunk, pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) global_k.append(k_chunk) v_chunk = qkv_chunk[:, :, projection_size+kv_projection_size:].contiguous().reshape(qkv_chunk.shape[0], qkv_chunk.shape[1], -1, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd - v_chunk = single_all_to_all(v_chunk, scatter_idx, gather_idx, spg) + v_chunk = single_all_to_all(v_chunk, scatter_idx, gather_idx, 0, spg) global_v.append(v_chunk) for k_i in 
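# --- Illustrative sketch (not part of the patch) ---------------------------------
# Checking the hand-written rotary-embedding backward (the transposed rotate-half
# used by apply_rotary_pos_emb_backward) against autograd for the forward
# t * cos + rotate_half(t) * sin. Toy shapes; the t_pass passthrough is omitted.
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotate_half_backward(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x2, -x1), dim=-1)

torch.manual_seed(0)
t = torch.randn(2, 4, 3, 8, dtype=torch.double, requires_grad=True)   # [batch, seq, heads, dim]
angles = torch.randn(1, 4, 1, 8, dtype=torch.double)                  # broadcast over batch/heads
cos, sin = angles.cos(), angles.sin()
g = torch.randn_like(t)

(t * cos + rotate_half(t) * sin).backward(g)
grad_manual = g * cos + rotate_half_backward(g * sin)
assert torch.allclose(grad_manual, t.grad)
# ----------------------------------------------------------------------------------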
range(len(global_k)): @@ -279,7 +265,7 @@ def forward(ctx: Any, for i in range(num_chunks): global_lse[i] = global_lse[i][:, :, :, 0].permute(0, 2, 1).contiguous() - output[i] = single_all_to_all(global_o[i].to(ctx.dtype).contiguous(), gather_idx, scatter_idx, spg) + output[i] = single_all_to_all(global_o[i].to(ctx.dtype).contiguous(), gather_idx, scatter_idx, 0, spg) output = torch.cat(output, dim=1) head_dim = output.shape[-1] @@ -337,7 +323,7 @@ def backward(ctx, grad_output): for i in range(num_chunks): st = chunk_size * i ed = st + chunk_size - grad_global_attn_output.append(single_all_to_all(grad_output[:, st:ed].contiguous(), scatter_idx, gather_idx, spg)) + grad_global_attn_output.append(single_all_to_all(grad_output[:, st:ed].contiguous(), scatter_idx, gather_idx, 0, spg)) del grad_output @@ -395,8 +381,8 @@ def backward(ctx, grad_output): dk_seq_len = dk[i].shape[1] dk[i] = apply_rotary_pos_emb_backward(dk[i].to(dtype), ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) dv[i] = dv[i].to(dtype) - dk[i] = single_all_to_all(dk[i].contiguous(), gather_idx, scatter_idx, spg) - dv[i] = single_all_to_all(dv[i].contiguous(), gather_idx, scatter_idx, spg) + dk[i] = single_all_to_all(dk[i].contiguous(), gather_idx, scatter_idx, 0, spg) + dv[i] = single_all_to_all(dv[i].contiguous(), gather_idx, scatter_idx, 0, spg) input_st = i * input_chunk_size input_ed = input_st + input_chunk_size @@ -421,7 +407,7 @@ def backward(ctx, grad_output): dq_seq_len = dq[i].shape[1] dq[i] = apply_rotary_pos_emb_backward(dq[i].to(dtype), ctx.pos_emb_cos[:, dq_seq_len * i:dq_seq_len * (i + 1)], ctx.pos_emb_sin[:, dq_seq_len * i:dq_seq_len * (i + 1)]) - dq[i] = single_all_to_all(dq[i].to(dtype).contiguous(), gather_idx, scatter_idx, spg) + dq[i] = single_all_to_all(dq[i].to(dtype).contiguous(), gather_idx, scatter_idx, 0, spg) input_chunk = layernorm_output[:input_chunk_size].reshape(-1, layernorm_output.shape[-1]) layernorm_output = layernorm_output[input_chunk_size:] @@ -562,16 +548,16 @@ def forward(ctx: Any, layernorm_output = layernorm_output[chunk_size:] q_chunk = qkv_chunk[:, :, :projection_size].contiguous().reshape(qkv_chunk.shape[0], qkv_chunk.shape[1], -1, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd - q_chunk = single_all_to_all(q_chunk, scatter_idx, gather_idx, spg) + q_chunk = single_all_to_all(q_chunk, scatter_idx, gather_idx, 0, spg) global_q_chunk_len = q_chunk.shape[1] k_chunk = qkv_chunk[:, :, projection_size:projection_size+kv_projection_size].contiguous().reshape(qkv_chunk.shape[0], qkv_chunk.shape[1], -1, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd - k_chunk = single_all_to_all(k_chunk, scatter_idx, gather_idx, spg) + k_chunk = single_all_to_all(k_chunk, scatter_idx, gather_idx, 0, spg) v_chunk = qkv_chunk[:, :, projection_size+kv_projection_size:].contiguous().reshape(qkv_chunk.shape[0], qkv_chunk.shape[1], -1, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd - v_chunk = single_all_to_all(v_chunk, scatter_idx, gather_idx, spg) + v_chunk = single_all_to_all(v_chunk, scatter_idx, gather_idx, 0, spg) - torch.distributed.barrier() # get_accelerator().synchronize() + torch.distributed.barrier() # torch.cuda.synchronize() pos_emb_cos_chunk = pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)] pos_emb_sin_chunk = pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)] @@ -641,7 +627,7 @@ def forward(ctx: 
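# --- Illustrative sketch (not part of the patch) ---------------------------------
# The shape bookkeeping behind single_all_to_all(..., batch_dim_idx=0) in the
# q/k/v path above (scatter heads, gather sequence): each rank's
# [b, s_local, n_total, d] chunk becomes [b, P * s_local, n_total / P, d] after
# the exchange, where P is the sequence-parallel world size. No communication is
# performed here; an identity copy stands in for dist.all_to_all_single.
import torch

P, b, s_local, n_total, d = 4, 2, 8, 12, 5
x = torch.randn(b, s_local, n_total, d)                       # all heads, local sequence shard

# Pre-all-to-all layout: one leading slice per destination rank (one head group each).
send = x.reshape(b, s_local, P, n_total // P, d).permute(2, 0, 1, 3, 4).contiguous()
assert send.shape == (P, b, s_local, n_total // P, d)

recv = send.clone()                                           # stand-in for the all-to-all

# Post-all-to-all layout: one head group over the whole gathered sequence.
out = recv.permute(1, 0, 2, 3, 4).reshape(b, P * s_local, n_total // P, d)
assert out.shape == (b, P * s_local, n_total // P, d)
# ----------------------------------------------------------------------------------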
Any, global_q[q_compute_chunk_idx].offload() q_compute_chunk_idx += 1 - all2all_output = single_all_to_all(cur_attn_output.to(ctx.dtype).contiguous(), gather_idx, scatter_idx, spg) + all2all_output = single_all_to_all(cur_attn_output.to(ctx.dtype).contiguous(), gather_idx, scatter_idx, 0, spg) final_output.append(all2all_output) with get_accelerator().stream(general_offload_stream): global_o.append(SequenceChunk(cur_attn_output.to(ctx.dtype))) @@ -717,7 +703,7 @@ def backward(ctx, grad_output): grad_qkv_linear_weight = torch.zeros(qkv_linear_weight.shape, device=qkv_linear_weight.device, dtype=torch.float) grad_qkv_linear_bias = torch.zeros(qkv_linear_bias.shape, device=qkv_linear_weight.device, dtype=torch.float) - grad_global_attn_output_chunk = single_all_to_all(grad_output[:, :chunk_size].contiguous(), scatter_idx, gather_idx, spg) + grad_global_attn_output_chunk = single_all_to_all(grad_output[:, :chunk_size].contiguous(), scatter_idx, gather_idx, 0, spg) get_accelerator().synchronize() grad_output = grad_output[:, chunk_size:] @@ -783,7 +769,7 @@ def backward(ctx, grad_output): grad_global_attn_output[next_q_compute_chunk_idx].load_to_gpu() if grad_global_attn_output[next_q_compute_chunk_idx] is None: - grad_global_attn_output_chunk = single_all_to_all(grad_output[:, :chunk_size].contiguous(), scatter_idx, gather_idx, spg) + grad_global_attn_output_chunk = single_all_to_all(grad_output[:, :chunk_size].contiguous(), scatter_idx, gather_idx, 0, spg) torch.distributed.barrier() grad_output = grad_output[:, chunk_size:] grad_global_attn_output[next_q_compute_chunk_idx] = SequenceChunk(grad_global_attn_output_chunk, is_in_use=True) @@ -818,9 +804,9 @@ def backward(ctx, grad_output): dk_accum = apply_rotary_pos_emb_backward(dk_accum.to(dtype), ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) dv_accum = dv_accum.to(dtype) - dq_accum = single_all_to_all(dq_accum.contiguous(), gather_idx, scatter_idx, spg) - dk_accum = single_all_to_all(dk_accum.contiguous(), gather_idx, scatter_idx, spg) - dv_accum = single_all_to_all(dv_accum.contiguous(), gather_idx, scatter_idx, spg) + dq_accum = single_all_to_all(dq_accum.contiguous(), gather_idx, scatter_idx, 0, spg) + dk_accum = single_all_to_all(dk_accum.contiguous(), gather_idx, scatter_idx, 0, spg) + dv_accum = single_all_to_all(dv_accum.contiguous(), gather_idx, scatter_idx, 0, spg) general_offload_stream.synchronize() compute_stream.wait_stream(general_offload_stream) @@ -918,6 +904,17 @@ def forward(self, inference_params, rotary_pos_emb, cpu_offloading=True) -> Tensor: + """ forward + + Arguments: + query (Tensor): query input to the layer + key (Tensor): key input to the layer + value (Tensor): value input to the layer + args: other args + + Returns: + * output (Tensor): context output + """ self.num_chunks_attn = layernorm_output.shape[0] * dist.get_world_size(self.spg) // self.chunk_size if not cpu_offloading: @@ -963,9 +960,11 @@ def forward(self, return output, self.qkv_dense_bias if self.reture_bias else None +@torch.jit.script def bias_gelu(x): return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) +@torch.jit.script def bias_gelu_back(g, x): tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 From 6bfd76f424f357170fcb70e4bfcc7b64eb16c04a Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Sun, 6 Oct 2024 21:02:47 -0400 Subject: [PATCH 08/54] remove duplication of pos --- 
deepspeed/sequence/fpdt_layer.py | 29 +---------------------------- 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py index d4e45e8a64a2..39a166709259 100644 --- a/deepspeed/sequence/fpdt_layer.py +++ b/deepspeed/sequence/fpdt_layer.py @@ -18,16 +18,7 @@ from packaging import version from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward from einops import rearrange -from .layer import single_all_to_all - - -def _rotate_half(x): - """ - change sign so the last dimension becomes [-odd, +even] - """ - x = rearrange(x, '... (j d) -> ... j d', j=2) - x1, x2 = x.unbind(dim=-2) - return torch.cat((-x2, x1), dim=-1) +from .layer import single_all_to_all, apply_rotary_pos_emb def _rotate_half_backward(x): @@ -39,23 +30,6 @@ def _rotate_half_backward(x): return torch.cat((x2, -x1), dim=-1) -def apply_rotary_pos_emb(t, freqs_cos, freqs_sin): - """ - input tensor t is of shape [seq_length, ..., dim] - rotary positional embeding tensor freqs is of shape [seq_length, ..., dim] - check https://kexue.fm/archives/8265 for detailed formulas - """ - rot_dim = freqs_cos.shape[-1] - # ideally t_pass is empty so rotary pos embedding is applied to all tensor t - t, t_pass = t[..., :rot_dim], t[..., rot_dim:] - - # first part is cosine component - # second part is sine component, need to change signs with _rotate_half method - t = (t * freqs_cos) + (_rotate_half(t) * freqs_sin) - - res = t if t_pass.shape[-1] == 0 else torch.cat((t, t_pass), dim=-1) - return res - def apply_rotary_pos_emb_backward(grad_output, freqs_cos, freqs_sin): rot_dim = freqs_cos.shape[-1] @@ -65,7 +39,6 @@ def apply_rotary_pos_emb_backward(grad_output, freqs_cos, freqs_sin): return grad -# @torch.jit.script def _update_out_and_lse( out: torch.Tensor, lse: torch.Tensor, From 4eeadca61289e1d2ec2ba070f741e4c58a3363b5 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Sun, 6 Oct 2024 21:13:25 -0400 Subject: [PATCH 09/54] fix format --- .../activation_checkpointing/checkpointing.py | 3 +- deepspeed/sequence/fpdt_layer.py | 571 +++++++++--------- deepspeed/sequence/layer.py | 9 +- 3 files changed, 306 insertions(+), 277 deletions(-) diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py index 866b444a35a5..85506a1532cc 100644 --- a/deepspeed/runtime/activation_checkpointing/checkpointing.py +++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py @@ -369,7 +369,8 @@ def is_activation_to_checkpoint(item): Is an activation to be checkpointed """ global mp_size - extra_flag = (not hasattr(item, 'no_checkpointing')) or (hasattr(item, 'no_checkpointing') and item.no_checkpointing == False) + extra_flag = (not hasattr(item, 'no_checkpointing')) or (hasattr(item, 'no_checkpointing') + and item.no_checkpointing == False) return torch.is_tensor(item) and item.is_floating_point() and item.numel() >= mp_size and extra_flag diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py index 39a166709259..611d992d809d 100644 --- a/deepspeed/sequence/fpdt_layer.py +++ b/deepspeed/sequence/fpdt_layer.py @@ -3,14 +3,10 @@ # DeepSpeed Team -from os import kill -from turtle import heading -from unittest import skip import torch from typing import Optional, Any, Tuple from torch import Tensor -from torch.nn import Module import deepspeed.comm as dist from deepspeed.accelerator import get_accelerator @@ -22,15 +18,11 @@ def _rotate_half_backward(x): - 
""" - change sign so the last dimension becomes [-odd, +even] - """ x = rearrange(x, '... (j d) -> ... j d', j=2) x1, x2 = x.unbind(dim=-2) return torch.cat((x2, -x1), dim=-1) - def apply_rotary_pos_emb_backward(grad_output, freqs_cos, freqs_sin): rot_dim = freqs_cos.shape[-1] grad, grad_pass = grad_output[..., :rot_dim], grad_output[..., rot_dim:] @@ -71,9 +63,7 @@ def update_out_and_lse( lse = block_lse.permute(0, 2, 1).contiguous().unsqueeze(dim=-1).contiguous() elif slice_ is not None: slice_out, slice_lse = out[slice_], lse[slice_] - slice_out, slice_lse = _update_out_and_lse( - slice_out, slice_lse, block_out, block_lse - ) + slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse) out[slice_], lse[slice_] = slice_out, slice_lse else: out, lse = _update_out_and_lse(out, lse, block_out, block_lse) @@ -81,10 +71,8 @@ def update_out_and_lse( class FPDT_InputConstruct(torch.nn.Module): - def __init__( - self, - tokens, labels, loss_mask, attention_mask, position_ids, args, sp_size, sp_rank - ) -> None: + + def __init__(self, tokens, labels, loss_mask, attention_mask, position_ids, args, sp_size, sp_rank) -> None: super(FPDT_InputConstruct, self).__init__() self.tokens = tokens @@ -108,30 +96,31 @@ def __init__( self.local_seq_len = local_seq_len self.batch_size = batch_size self.device = tokens.device - + def generate(self): device = self.device totalChunks = self.global_seq_len // self.chunk_size token_chunk_idx = torch.arange(self.global_seq_len, device=device, dtype=torch.int) // self.chunk_size chunk_to_gpu = torch.arange(totalChunks, device=device, dtype=torch.int) chunk_to_gpu = chunk_to_gpu.reshape(self.num_chunk_per_gpu, -1).t().contiguous() - + gather_chunk = chunk_to_gpu.flatten().unsqueeze(1).contiguous() mask = gather_chunk == token_chunk_idx - + indices = mask.nonzero(as_tuple=False) gather_indices = indices[:, 0] token_chunk_indices = indices[:, 1] indices = torch.cat([token_chunk_indices[gather_indices == i] for i in range(gather_chunk.shape[0])]) load_balanced_loss_mask = self.loss_mask[:, indices] - indices = indices.reshape(-1, self.chunk_size)[self.num_chunk_per_gpu*self.sp_rank:self.num_chunk_per_gpu*(self.sp_rank + 1)].flatten().contiguous() + indices = indices.reshape(-1, self.chunk_size)[self.num_chunk_per_gpu * self.sp_rank:self.num_chunk_per_gpu * + (self.sp_rank + 1)].flatten().contiguous() load_balanced_tokens = self.tokens[:, indices] load_balanced_labels = self.labels[:, indices] load_balanced_attention_mask = self.attention_mask if self.attention_mask is not None else None load_balanced_position_ids = self.position_ids[:, indices] - + return load_balanced_tokens, load_balanced_labels, load_balanced_loss_mask, load_balanced_attention_mask, load_balanced_position_ids @@ -139,11 +128,11 @@ class _FPDTGPUAttentionImpl_(torch.autograd.Function): generate_vmap_rule = False @staticmethod - def forward(ctx: Any, + def forward(ctx: Any, layernorm_output, attention_mask, inference_params, - rotary_pos_emb, + rotary_pos_emb, spg, scatter_idx, gather_idx, @@ -154,14 +143,15 @@ def forward(ctx: Any, qkv_linear_weight, qkv_linear_bias, dropout, - num_chunks=8, cpu_offloading=True): - + num_chunks=8, + cpu_offloading=True): + do_save = layernorm_output.requires_grad pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3) ctx.pos_emb_cos = pos_emb_cos ctx.pos_emb_sin = pos_emb_sin - + with torch.no_grad(): per_gpu_seq_len = layernorm_output.shape[0] chunk_size = per_gpu_seq_len // num_chunks @@ 
-183,14 +173,14 @@ def forward(ctx: Any, global_k = [] global_v = [] - ctx.softmax_scale = hidden_size_per_attention_head ** (-0.5) - + ctx.softmax_scale = hidden_size_per_attention_head**(-0.5) + ctx.dropout_p = dropout ctx.window_size = (-1, -1) ctx.alibi_slopes = None batch_size = layernorm_output.shape[1] - + global_o = [None for _ in range(num_chunks)] global_lse = [None for _ in range(num_chunks)] @@ -198,40 +188,48 @@ def forward(ctx: Any, st = chunk_size * i ed = st + chunk_size - + qkv_chunk = torch.matmul(layernorm_output[st:ed], qkv_linear_weight.t()) + qkv_linear_bias - - q_chunk = qkv_chunk[:, :, :projection_size].contiguous().reshape(qkv_chunk.shape[0], qkv_chunk.shape[1], -1, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + + q_chunk = qkv_chunk[:, :, :projection_size].contiguous().reshape( + qkv_chunk.shape[0], qkv_chunk.shape[1], -1, + hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd q_chunk = single_all_to_all(q_chunk, scatter_idx, gather_idx, 0, spg) global_q_chunk_len = q_chunk.shape[1] - q_chunk = apply_rotary_pos_emb(q_chunk, pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) + q_chunk = apply_rotary_pos_emb(q_chunk, + pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], + pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) global_q.append(q_chunk) - - k_chunk = qkv_chunk[:, :, projection_size:projection_size+kv_projection_size].contiguous().reshape(qkv_chunk.shape[0], qkv_chunk.shape[1], -1, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + + k_chunk = qkv_chunk[:, :, projection_size:projection_size + kv_projection_size].contiguous().reshape( + qkv_chunk.shape[0], qkv_chunk.shape[1], -1, + hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd k_chunk = single_all_to_all(k_chunk, scatter_idx, gather_idx, 0, spg) - k_chunk = apply_rotary_pos_emb(k_chunk, pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) + k_chunk = apply_rotary_pos_emb(k_chunk, + pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], + pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) global_k.append(k_chunk) - v_chunk = qkv_chunk[:, :, projection_size+kv_projection_size:].contiguous().reshape(qkv_chunk.shape[0], qkv_chunk.shape[1], -1, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + v_chunk = qkv_chunk[:, :, projection_size + kv_projection_size:].contiguous().reshape( + qkv_chunk.shape[0], qkv_chunk.shape[1], -1, + hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd v_chunk = single_all_to_all(v_chunk, scatter_idx, gather_idx, 0, spg) global_v.append(v_chunk) - + for k_i in range(len(global_k)): causal_chunk = i == k_i - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( - global_q[i], - global_k[k_i], - global_v[k_i], - ctx.dropout_p, - ctx.softmax_scale, - causal=causal_chunk, - window_size=ctx.window_size, - softcap=0.0, - alibi_slopes=ctx.alibi_slopes, - return_softmax=False - ) + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(global_q[i], + global_k[k_i], + global_v[k_i], + ctx.dropout_p, + ctx.softmax_scale, + causal=causal_chunk, + window_size=ctx.window_size, + softcap=0.0, + alibi_slopes=ctx.alibi_slopes, + return_softmax=False) global_o[i], 
global_lse[i] = update_out_and_lse(global_o[i], global_lse[i], block_out, block_lse) - + global_o[i] = global_o[i].to(q_chunk.dtype) output = [None for i in range(num_chunks)] @@ -242,7 +240,7 @@ def forward(ctx: Any, output = torch.cat(output, dim=1) head_dim = output.shape[-1] - + if do_save: ctx.save_for_backward(layernorm_output) ctx.global_q = global_q @@ -258,9 +256,8 @@ def forward(ctx: Any, return output - @staticmethod - def backward(ctx, grad_output): + def backward(ctx, grad_output): num_chunks = ctx.num_chunks device = ctx.device @@ -286,25 +283,32 @@ def backward(ctx, grad_output): qkv_linear_weight = ctx.qkv_linear_weight qkv_linear_bias = ctx.qkv_linear_bias - + input_chunk_size = layernorm_output.shape[0] // num_chunks - grad_layernorm_output = [torch.zeros((input_chunk_size, layernorm_output.shape[1], layernorm_output.shape[2]), device=device, dtype=dtype) for _ in range(num_chunks)] - + grad_layernorm_output = [ + torch.zeros((input_chunk_size, layernorm_output.shape[1], layernorm_output.shape[2]), + device=device, + dtype=dtype) for _ in range(num_chunks) + ] + grad_global_attn_output = [] chunk_size = grad_output.shape[1] // num_chunks for i in range(num_chunks): st = chunk_size * i ed = st + chunk_size - grad_global_attn_output.append(single_all_to_all(grad_output[:, st:ed].contiguous(), scatter_idx, gather_idx, 0, spg)) - + grad_global_attn_output.append( + single_all_to_all(grad_output[:, st:ed].contiguous(), scatter_idx, gather_idx, 0, spg)) + del grad_output dq = [torch.zeros(global_q[0].shape, dtype=torch.float, device=device) for _ in range(num_chunks)] dk = [torch.zeros(global_q[0].shape, dtype=torch.float, device=device) for _ in range(num_chunks)] dv = [torch.zeros(global_q[0].shape, dtype=torch.float, device=device) for _ in range(num_chunks)] - grad_qkv_linear_weight = torch.zeros(qkv_linear_weight.shape, device=qkv_linear_weight.device, dtype=torch.float) + grad_qkv_linear_weight = torch.zeros(qkv_linear_weight.shape, + device=qkv_linear_weight.device, + dtype=torch.float) grad_qkv_linear_bias = torch.zeros(qkv_linear_bias.shape, device=qkv_linear_weight.device, dtype=torch.float) for i in range(num_chunks): @@ -326,60 +330,66 @@ def backward(ctx, grad_output): dq_this = torch.zeros(global_q[0].shape, dtype=dtype, device=device) dk_this = torch.zeros(global_k[0].shape, dtype=dtype, device=device) dv_this = torch.zeros(global_v[0].shape, dtype=dtype, device=device) - - _flash_attn_backward( - dout, - q_chunk, - k_chunk, - v_chunk, - attn_output_chunk, - lse_chunk, - dq_this, - dk_this, - dv_this, - dropout_p, - softmax_scale, - causal_chunk, - window_size, - softcap=0.0, - alibi_slopes=alibi_slopes, - deterministic=False, - rng_state=None - ) - + + _flash_attn_backward(dout, + q_chunk, + k_chunk, + v_chunk, + attn_output_chunk, + lse_chunk, + dq_this, + dk_this, + dv_this, + dropout_p, + softmax_scale, + causal_chunk, + window_size, + softcap=0.0, + alibi_slopes=alibi_slopes, + deterministic=False, + rng_state=None) + dq[q_i].add_(dq_this.to(torch.float)) dk[i].add_(dk_this.to(torch.float)) dv[i].add_(dv_this.to(torch.float)) dk_seq_len = dk[i].shape[1] - dk[i] = apply_rotary_pos_emb_backward(dk[i].to(dtype), ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) + dk[i] = apply_rotary_pos_emb_backward(dk[i].to(dtype), + ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], + ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) dv[i] = dv[i].to(dtype) dk[i] = 
single_all_to_all(dk[i].contiguous(), gather_idx, scatter_idx, 0, spg) dv[i] = single_all_to_all(dv[i].contiguous(), gather_idx, scatter_idx, 0, spg) input_st = i * input_chunk_size input_ed = input_st + input_chunk_size - + input_chunk = layernorm_output[input_st:input_ed].reshape(-1, layernorm_output.shape[-1]) dk[i] = dk[i].flatten(2).permute(1, 0, 2) dv[i] = dv[i].flatten(2).permute(1, 0, 2) l, b = dk[i].shape[0], dk[i].shape[1] - grad_qkv_linear_weight[projection_size:projection_size+kv_projection_size].add_(torch.matmul(dk[i].reshape(l*b, -1).t(), input_chunk)) - grad_qkv_linear_weight[projection_size+kv_projection_size:].add_(torch.matmul(dv[i].reshape(l*b, -1).t(), input_chunk)) - grad_qkv_linear_bias[projection_size:projection_size+kv_projection_size].add_(dk[i].sum(0).sum(0)) - grad_qkv_linear_bias[projection_size+kv_projection_size:].add_(dv[i].sum(0).sum(0)) - - grad_layernorm_output[i].add_(torch.matmul(dk[i], qkv_linear_weight[projection_size:projection_size+kv_projection_size])) - grad_layernorm_output[i].add_(torch.matmul(dv[i], qkv_linear_weight[projection_size+kv_projection_size:])) + grad_qkv_linear_weight[projection_size:projection_size + kv_projection_size].add_( + torch.matmul(dk[i].reshape(l * b, -1).t(), input_chunk)) + grad_qkv_linear_weight[projection_size + kv_projection_size:].add_( + torch.matmul(dv[i].reshape(l * b, -1).t(), input_chunk)) + grad_qkv_linear_bias[projection_size:projection_size + kv_projection_size].add_(dk[i].sum(0).sum(0)) + grad_qkv_linear_bias[projection_size + kv_projection_size:].add_(dv[i].sum(0).sum(0)) + + grad_layernorm_output[i].add_( + torch.matmul(dk[i], qkv_linear_weight[projection_size:projection_size + kv_projection_size])) + grad_layernorm_output[i].add_(torch.matmul(dv[i], + qkv_linear_weight[projection_size + kv_projection_size:])) dk[i] = None dv[i] = None - + for i in range(num_chunks): dq_seq_len = dq[i].shape[1] - dq[i] = apply_rotary_pos_emb_backward(dq[i].to(dtype), ctx.pos_emb_cos[:, dq_seq_len * i:dq_seq_len * (i + 1)], ctx.pos_emb_sin[:, dq_seq_len * i:dq_seq_len * (i + 1)]) - + dq[i] = apply_rotary_pos_emb_backward(dq[i].to(dtype), + ctx.pos_emb_cos[:, dq_seq_len * i:dq_seq_len * (i + 1)], + ctx.pos_emb_sin[:, dq_seq_len * i:dq_seq_len * (i + 1)]) + dq[i] = single_all_to_all(dq[i].to(dtype).contiguous(), gather_idx, scatter_idx, 0, spg) input_chunk = layernorm_output[:input_chunk_size].reshape(-1, layernorm_output.shape[-1]) @@ -387,25 +397,27 @@ def backward(ctx, grad_output): dq[i] = dq[i].flatten(2).permute(1, 0, 2) l, b = dq[i].shape[0], dq[i].shape[1] - grad_qkv_linear_weight[:projection_size].add_(torch.matmul(dq[i].reshape(l*b, -1).t(), input_chunk)) + grad_qkv_linear_weight[:projection_size].add_(torch.matmul(dq[i].reshape(l * b, -1).t(), input_chunk)) grad_qkv_linear_bias[:projection_size].add_(dq[i].sum(0).sum(0)) grad_layernorm_output[i].add_(torch.matmul(dq[i], qkv_linear_weight[:projection_size])) dq[i] = None - - - return torch.cat(grad_layernorm_output, dim=0).to(dtype), None, None, None, None, None, None, None, None, None, None, grad_qkv_linear_weight.to(dtype), grad_qkv_linear_bias.to(dtype), None, None, None + return torch.cat( + grad_layernorm_output, + dim=0).to(dtype), None, None, None, None, None, None, None, None, None, None, grad_qkv_linear_weight.to( + dtype), grad_qkv_linear_bias.to(dtype), None, None, None class SequenceChunk: + def __init__(self, chunk: torch.Tensor, device=None, is_in_use=False): - + self.chunk_shape = chunk.shape self.chunk_dtype = chunk.dtype self.device = chunk.device 
if device is None else device - + cpu_chunk = torch.empty(chunk.shape, dtype=chunk.dtype, device='cpu', pin_memory=True) if chunk.is_cuda: cpu_chunk.copy_(chunk, non_blocking=True) @@ -415,7 +427,7 @@ def __init__(self, chunk: torch.Tensor, device=None, is_in_use=False): self.cpu_chunk = cpu_chunk self.gpu_chunk = chunk if is_in_use else None - + def load_to_gpu(self): assert self.gpu_chunk is None if self.gpu_chunk is not None: @@ -428,9 +440,11 @@ def load_to_gpu(self): def get_gpu_chunk(self): assert self.gpu_chunk is not None and self.gpu_chunk.device == self.device return self.gpu_chunk - - def check_gpu_chunk(self,): - assert (self.gpu_chunk is not None) and (self.gpu_chunk.device == self.device), f"gpu_chunk {self.gpu_chunk is not None} shound be on {self.device}, but it is now on {self.gpu_chunk.device}" + + def check_gpu_chunk(self, ): + assert (self.gpu_chunk is not None) and ( + self.gpu_chunk.device == self.device + ), f"gpu_chunk {self.gpu_chunk is not None} shound be on {self.device}, but it is now on {self.gpu_chunk.device}" return True def offload(self): @@ -447,11 +461,11 @@ class _FPDTGPUOffloadingAttentionImpl_(torch.autograd.Function): generate_vmap_rule = False @staticmethod - def forward(ctx: Any, + def forward(ctx: Any, layernorm_output, attention_mask, inference_params, - rotary_pos_emb, + rotary_pos_emb, spg, scatter_idx, gather_idx, @@ -462,8 +476,9 @@ def forward(ctx: Any, qkv_linear_weight, qkv_linear_bias, dropout, - num_chunks=8, cpu_offloading=True): - + num_chunks=8, + cpu_offloading=True): + do_save = layernorm_output.requires_grad pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3) @@ -491,14 +506,14 @@ def forward(ctx: Any, global_k = [] global_v = [] - ctx.softmax_scale = hidden_size_per_attention_head ** (-0.5) - + ctx.softmax_scale = hidden_size_per_attention_head**(-0.5) + ctx.dropout_p = dropout ctx.window_size = (-1, -1) ctx.alibi_slopes = None batch_size = layernorm_output.shape[1] - + global_o = [] global_lse = [] @@ -513,31 +528,38 @@ def forward(ctx: Any, kv_compute_chunk_idx = 0 for i in range(num_chunks): - qkv_chunk = torch.matmul(layernorm_output[:chunk_size], qkv_linear_weight.t()) + qkv_linear_bias # torch.Size([18126, 1, 12288]) + qkv_chunk = torch.matmul(layernorm_output[:chunk_size], + qkv_linear_weight.t()) + qkv_linear_bias # torch.Size([18126, 1, 12288]) with get_accelerator().stream(general_offload_stream): layernorm_output_cpu.append(SequenceChunk(layernorm_output[:chunk_size])) layernorm_output = layernorm_output[chunk_size:] - q_chunk = qkv_chunk[:, :, :projection_size].contiguous().reshape(qkv_chunk.shape[0], qkv_chunk.shape[1], -1, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + q_chunk = qkv_chunk[:, :, :projection_size].contiguous().reshape( + qkv_chunk.shape[0], qkv_chunk.shape[1], -1, + hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd q_chunk = single_all_to_all(q_chunk, scatter_idx, gather_idx, 0, spg) global_q_chunk_len = q_chunk.shape[1] - - k_chunk = qkv_chunk[:, :, projection_size:projection_size+kv_projection_size].contiguous().reshape(qkv_chunk.shape[0], qkv_chunk.shape[1], -1, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + + k_chunk = qkv_chunk[:, :, projection_size:projection_size + kv_projection_size].contiguous().reshape( + qkv_chunk.shape[0], qkv_chunk.shape[1], -1, + hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd k_chunk = 
single_all_to_all(k_chunk, scatter_idx, gather_idx, 0, spg) - v_chunk = qkv_chunk[:, :, projection_size+kv_projection_size:].contiguous().reshape(qkv_chunk.shape[0], qkv_chunk.shape[1], -1, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd + v_chunk = qkv_chunk[:, :, projection_size + kv_projection_size:].contiguous().reshape( + qkv_chunk.shape[0], qkv_chunk.shape[1], -1, + hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd v_chunk = single_all_to_all(v_chunk, scatter_idx, gather_idx, 0, spg) - - torch.distributed.barrier() # torch.cuda.synchronize() - + + torch.distributed.barrier() + pos_emb_cos_chunk = pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)] pos_emb_sin_chunk = pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)] q_chunk = apply_rotary_pos_emb(q_chunk, pos_emb_cos_chunk, pos_emb_sin_chunk) k_chunk = apply_rotary_pos_emb(k_chunk, pos_emb_cos_chunk, pos_emb_sin_chunk) - + compute_stream.wait_stream(offload_stream) compute_stream.synchronize() with get_accelerator().stream(offload_stream): @@ -553,18 +575,18 @@ def forward(ctx: Any, causal_chunk = i == k_i with get_accelerator().stream(compute_stream): block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( - global_q[q_compute_chunk_idx].get_gpu_chunk(), - global_k[kv_compute_chunk_idx].get_gpu_chunk(), - global_v[kv_compute_chunk_idx].get_gpu_chunk(), - ctx.dropout_p, - ctx.softmax_scale, - causal=causal_chunk, - window_size=ctx.window_size, - softcap=0.0, - alibi_slopes=ctx.alibi_slopes, - return_softmax=False - ) - cur_attn_output, cur_attn_lse = update_out_and_lse(cur_attn_output, cur_attn_lse, block_out, block_lse) + global_q[q_compute_chunk_idx].get_gpu_chunk(), + global_k[kv_compute_chunk_idx].get_gpu_chunk(), + global_v[kv_compute_chunk_idx].get_gpu_chunk(), + ctx.dropout_p, + ctx.softmax_scale, + causal=causal_chunk, + window_size=ctx.window_size, + softcap=0.0, + alibi_slopes=ctx.alibi_slopes, + return_softmax=False) + cur_attn_output, cur_attn_lse = update_out_and_lse(cur_attn_output, cur_attn_lse, block_out, + block_lse) can_offload_kv = True if k_i != (len(global_k) - 1) or i != (num_chunks - 1): @@ -591,16 +613,17 @@ def forward(ctx: Any, compute_stream.wait_stream(offload_stream) compute_stream.synchronize() - + if can_offload_kv: global_k[kv_compute_chunk_idx].offload() global_v[kv_compute_chunk_idx].offload() kv_compute_chunk_idx = next_kv_compute_chunk_idx - + global_q[q_compute_chunk_idx].offload() q_compute_chunk_idx += 1 - - all2all_output = single_all_to_all(cur_attn_output.to(ctx.dtype).contiguous(), gather_idx, scatter_idx, 0, spg) + + all2all_output = single_all_to_all( + cur_attn_output.to(ctx.dtype).contiguous(), gather_idx, scatter_idx, 0, spg) final_output.append(all2all_output) with get_accelerator().stream(general_offload_stream): global_o.append(SequenceChunk(cur_attn_output.to(ctx.dtype))) @@ -610,9 +633,9 @@ def forward(ctx: Any, compute_stream.synchronize() final_output = torch.cat(final_output, dim=1) - + head_dim = final_output.shape[-1] - + if do_save: ctx.layernorm_output = layernorm_output_cpu ctx.global_q = global_q @@ -628,9 +651,8 @@ def forward(ctx: Any, return final_output - @staticmethod - def backward(ctx, grad_output): + def backward(ctx, grad_output): num_chunks = ctx.num_chunks device = grad_output.device dtype = ctx.dtype @@ -646,7 +668,7 @@ def backward(ctx, grad_output): kv_projection_size = ctx.kv_projection_size layernorm_output = ctx.layernorm_output - + global_q = ctx.global_q 
global_k = ctx.global_k global_v = ctx.global_v @@ -655,34 +677,46 @@ def backward(ctx, grad_output): qkv_linear_weight = ctx.qkv_linear_weight qkv_linear_bias = ctx.qkv_linear_bias - + offload_stream = get_accelerator().Stream() - general_offload_stream = torch.cuda.Stream() + general_offload_stream = get_accelerator().Stream() compute_stream = get_accelerator().default_stream() chunk_size = grad_output.shape[1] // num_chunks assert chunk_size == layernorm_output[0].cpu_chunk.shape[0] - grad_layernorm_output = [torch.zeros(layernorm_output[0].chunk_shape, device=device, dtype=dtype) for _ in range(num_chunks)] - + grad_layernorm_output = [ + torch.zeros(layernorm_output[0].chunk_shape, device=device, dtype=dtype) for _ in range(num_chunks) + ] + grad_global_attn_output = [None for _ in range(num_chunks)] q_compute_chunk_idx = 0 kv_compute_chunk_idx = 0 last_q_accum_idx = 0 - + with get_accelerator().stream(general_offload_stream): layernorm_output[0].load_to_gpu() - grad_qkv_linear_weight = torch.zeros(qkv_linear_weight.shape, device=qkv_linear_weight.device, dtype=torch.float) - grad_qkv_linear_bias = torch.zeros(qkv_linear_bias.shape, device=qkv_linear_weight.device, dtype=torch.float) - - grad_global_attn_output_chunk = single_all_to_all(grad_output[:, :chunk_size].contiguous(), scatter_idx, gather_idx, 0, spg) + grad_qkv_linear_weight = torch.zeros(qkv_linear_weight.shape, + device=qkv_linear_weight.device, + dtype=torch.float) + grad_qkv_linear_bias = torch.zeros(qkv_linear_bias.shape, + device=qkv_linear_weight.device, + dtype=torch.float) + + grad_global_attn_output_chunk = single_all_to_all(grad_output[:, :chunk_size].contiguous(), scatter_idx, + gather_idx, 0, spg) get_accelerator().synchronize() grad_output = grad_output[:, chunk_size:] with get_accelerator().stream(offload_stream): grad_global_attn_output[0] = SequenceChunk(grad_global_attn_output_chunk, is_in_use=True) - dq = [SequenceChunk(torch.zeros(global_q[0].chunk_shape, dtype=torch.float, device=device), is_in_use=True)] + [SequenceChunk(torch.zeros(global_q[0].chunk_shape, dtype=torch.float, device='cpu', pin_memory=True), device) for _ in range(num_chunks - 1)] + dq = [ + SequenceChunk(torch.zeros(global_q[0].chunk_shape, dtype=torch.float, device=device), is_in_use=True) + ] + [ + SequenceChunk(torch.zeros(global_q[0].chunk_shape, dtype=torch.float, device='cpu', pin_memory=True), + device) for _ in range(num_chunks - 1) + ] dk_accum = torch.zeros(global_k[0].chunk_shape, dtype=torch.float, device=device) dv_accum = torch.zeros(global_v[0].chunk_shape, dtype=torch.float, device=device) @@ -697,27 +731,25 @@ def backward(ctx, grad_output): dq_this = torch.zeros(global_q[0].chunk_shape, dtype=dtype, device=device) dk_this = torch.zeros(global_k[0].chunk_shape, dtype=dtype, device=device) dv_this = torch.zeros(global_v[0].chunk_shape, dtype=dtype, device=device) - + with get_accelerator().stream(compute_stream): - _flash_attn_backward( - grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk(), - global_q[q_compute_chunk_idx].get_gpu_chunk(), - global_k[kv_compute_chunk_idx].get_gpu_chunk(), - global_v[kv_compute_chunk_idx].get_gpu_chunk(), - attn_output[q_compute_chunk_idx].get_gpu_chunk(), - lse[q_compute_chunk_idx].get_gpu_chunk(), - dq_this, - dk_this, - dv_this, - dropout_p, - softmax_scale, - causal_chunk, - window_size, - softcap=0.0, - alibi_slopes=alibi_slopes, - deterministic=False, - rng_state=None - ) + _flash_attn_backward(grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk(), + 
global_q[q_compute_chunk_idx].get_gpu_chunk(), + global_k[kv_compute_chunk_idx].get_gpu_chunk(), + global_v[kv_compute_chunk_idx].get_gpu_chunk(), + attn_output[q_compute_chunk_idx].get_gpu_chunk(), + lse[q_compute_chunk_idx].get_gpu_chunk(), + dq_this, + dk_this, + dv_this, + dropout_p, + softmax_scale, + causal_chunk, + window_size, + softcap=0.0, + alibi_slopes=alibi_slopes, + deterministic=False, + rng_state=None) if i != (len(global_k) - 1): if q_i != (len(global_q) - 1): @@ -728,11 +760,11 @@ def backward(ctx, grad_output): can_offload_q = True if next_q_compute_chunk_idx == q_compute_chunk_idx: - can_offload_q = False + can_offload_q = False else: with get_accelerator().stream(offload_stream): if i > 0 or q_i > 0: - if can_offload_q and last_q_accum_idx != i: # the first q chunk calculate in the loop will be sent out, therefore we do not offload it + if can_offload_q and last_q_accum_idx != i: # the first q chunk calculate in the loop will be sent out, therefore we do not offload it dq[last_q_accum_idx].offload() dq[next_q_compute_chunk_idx].load_to_gpu() global_q[next_q_compute_chunk_idx].load_to_gpu() @@ -740,19 +772,21 @@ def backward(ctx, grad_output): lse[next_q_compute_chunk_idx].load_to_gpu() if grad_global_attn_output[next_q_compute_chunk_idx] is not None: grad_global_attn_output[next_q_compute_chunk_idx].load_to_gpu() - + if grad_global_attn_output[next_q_compute_chunk_idx] is None: - grad_global_attn_output_chunk = single_all_to_all(grad_output[:, :chunk_size].contiguous(), scatter_idx, gather_idx, 0, spg) + grad_global_attn_output_chunk = single_all_to_all(grad_output[:, :chunk_size].contiguous(), + scatter_idx, gather_idx, 0, spg) torch.distributed.barrier() grad_output = grad_output[:, chunk_size:] - grad_global_attn_output[next_q_compute_chunk_idx] = SequenceChunk(grad_global_attn_output_chunk, is_in_use=True) + grad_global_attn_output[next_q_compute_chunk_idx] = SequenceChunk( + grad_global_attn_output_chunk, is_in_use=True) compute_stream.wait_stream(offload_stream) compute_stream.synchronize() - + with get_accelerator().stream(compute_stream): dq[q_compute_chunk_idx].check_gpu_chunk() - dq[q_compute_chunk_idx].gpu_chunk.add_(dq_this) + dq[q_compute_chunk_idx].gpu_chunk.add_(dq_this) dk_accum.add_(dk_this) dv_accum.add_(dv_this) @@ -765,25 +799,29 @@ def backward(ctx, grad_output): attn_output[q_compute_chunk_idx].offload() lse[q_compute_chunk_idx].offload() grad_global_attn_output[q_compute_chunk_idx].offload() - + last_q_accum_idx = q_compute_chunk_idx q_compute_chunk_idx = next_q_compute_chunk_idx - + compute_stream.wait_stream(offload_stream) compute_stream.synchronize() - + dk_seq_len = dk_accum.shape[1] - dq_accum = apply_rotary_pos_emb_backward(dq[kv_compute_chunk_idx].get_gpu_chunk().to(dtype), ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) - dk_accum = apply_rotary_pos_emb_backward(dk_accum.to(dtype), ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) + dq_accum = apply_rotary_pos_emb_backward(dq[kv_compute_chunk_idx].get_gpu_chunk().to(dtype), + ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], + ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) + dk_accum = apply_rotary_pos_emb_backward(dk_accum.to(dtype), + ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], + ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) dv_accum = dv_accum.to(dtype) dq_accum = single_all_to_all(dq_accum.contiguous(), gather_idx, 
scatter_idx, 0, spg) dk_accum = single_all_to_all(dk_accum.contiguous(), gather_idx, scatter_idx, 0, spg) dv_accum = single_all_to_all(dv_accum.contiguous(), gather_idx, scatter_idx, 0, spg) - + general_offload_stream.synchronize() compute_stream.wait_stream(general_offload_stream) - torch.distributed.barrier() # get_accelerator().synchronize() + torch.distributed.barrier() with get_accelerator().stream(compute_stream): input_chunk = layernorm_output[i].get_gpu_chunk().reshape(-1, layernorm_output[i].chunk_shape[-1]) @@ -791,20 +829,25 @@ def backward(ctx, grad_output): dq_accum = dq_accum.flatten(2).permute(1, 0, 2) dk_accum = dk_accum.flatten(2).permute(1, 0, 2) dv_accum = dv_accum.flatten(2).permute(1, 0, 2) - + l, b = dk_accum.shape[0], dk_accum.shape[1] - grad_qkv_linear_weight[:projection_size].add_(torch.matmul(dq_accum.reshape(l*b, -1).t(), input_chunk)) - grad_qkv_linear_weight[projection_size:projection_size+kv_projection_size].add_(torch.matmul(dk_accum.reshape(l*b, -1).t(), input_chunk)) - grad_qkv_linear_weight[projection_size+kv_projection_size:].add_(torch.matmul(dv_accum.reshape(l*b, -1).t(), input_chunk)) + grad_qkv_linear_weight[:projection_size].add_( + torch.matmul(dq_accum.reshape(l * b, -1).t(), input_chunk)) + grad_qkv_linear_weight[projection_size:projection_size + kv_projection_size].add_( + torch.matmul(dk_accum.reshape(l * b, -1).t(), input_chunk)) + grad_qkv_linear_weight[projection_size + kv_projection_size:].add_( + torch.matmul(dv_accum.reshape(l * b, -1).t(), input_chunk)) grad_qkv_linear_bias[:projection_size].add_(dq_accum.sum(0).sum(0)) - grad_qkv_linear_bias[projection_size:projection_size+kv_projection_size].add_(dk_accum.sum(0).sum(0)) - grad_qkv_linear_bias[projection_size+kv_projection_size:].add_(dv_accum.sum(0).sum(0)) + grad_qkv_linear_bias[projection_size:projection_size + kv_projection_size].add_(dk_accum.sum(0).sum(0)) + grad_qkv_linear_bias[projection_size + kv_projection_size:].add_(dv_accum.sum(0).sum(0)) grad_layernorm_output[i].add_(torch.matmul(dq_accum, qkv_linear_weight[:projection_size])) - grad_layernorm_output[i].add_(torch.matmul(dk_accum, qkv_linear_weight[projection_size:projection_size+kv_projection_size])) - grad_layernorm_output[i].add_(torch.matmul(dv_accum, qkv_linear_weight[projection_size+kv_projection_size:])) + grad_layernorm_output[i].add_( + torch.matmul(dk_accum, qkv_linear_weight[projection_size:projection_size + kv_projection_size])) + grad_layernorm_output[i].add_( + torch.matmul(dv_accum, qkv_linear_weight[projection_size + kv_projection_size:])) del dq_accum, dk_accum, dv_accum dk_accum = torch.zeros(global_k[i].chunk_shape, dtype=torch.float, device=device) @@ -828,33 +871,34 @@ def backward(ctx, grad_output): global_k[kv_compute_chunk_idx].offload() global_v[kv_compute_chunk_idx].offload() kv_compute_chunk_idx = next_kv_compute_chunk_idx - - return torch.cat(grad_layernorm_output, dim=0).to(dtype), None, None, None, None, None, None, None, None, None, None, grad_qkv_linear_weight.to(dtype), grad_qkv_linear_bias.to(dtype), None, None, None + return torch.cat( + grad_layernorm_output, + dim=0).to(dtype), None, None, None, None, None, None, None, None, None, None, grad_qkv_linear_weight.to( + dtype), grad_qkv_linear_bias.to(dtype), None, None, None class FPDT_Attention(torch.nn.Module): - def __init__( - self, - config, - first_weight, - first_bias, - second_weight, - second_bias, - sequence_process_group, - gather_idx: int = 0, - scatter_idx: int = 2, - return_bias=True, - chunk_size=65536, - 
enable_offloading=True - ) -> None: + + def __init__(self, + config, + first_weight, + first_bias, + second_weight, + second_bias, + sequence_process_group, + gather_idx: int = 0, + scatter_idx: int = 2, + return_bias=True, + chunk_size=65536, + enable_offloading=True) -> None: super(FPDT_Attention, self).__init__() self.spg = sequence_process_group self.scatter_idx = scatter_idx self.gather_idx = gather_idx self.config = config - + self.projection_size = config.kv_channels * config.num_attention_heads self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads self.kv_projection_size = config.kv_channels * config.num_key_value_heads @@ -869,64 +913,32 @@ def __init__( self.dropout = config.attention_dropout self.chunk_size = chunk_size - self.double_buffer = enable_offloading + self.double_buffer = enable_offloading - def forward(self, + def forward(self, layernorm_output, attention_mask, inference_params, - rotary_pos_emb, + rotary_pos_emb, cpu_offloading=True) -> Tensor: - """ forward - - Arguments: - query (Tensor): query input to the layer - key (Tensor): key input to the layer - value (Tensor): value input to the layer - args: other args - - Returns: - * output (Tensor): context output - """ self.num_chunks_attn = layernorm_output.shape[0] * dist.get_world_size(self.spg) // self.chunk_size if not cpu_offloading: - output = _FPDTGPUAttentionImpl_.apply( - layernorm_output, - attention_mask, - inference_params, - rotary_pos_emb, - self.spg, - self.scatter_idx, - self.gather_idx, - self.hidden_size, - self.projection_size, - self.hidden_size_per_attention_head, - self.kv_projection_size, - self.qkv_linear_weight, - self.qkv_linear_bias, - self.dropout, - self.num_chunks_attn, cpu_offloading) + output = _FPDTGPUAttentionImpl_.apply(layernorm_output, attention_mask, inference_params, rotary_pos_emb, + self.spg, self.scatter_idx, self.gather_idx, self.hidden_size, + self.projection_size, self.hidden_size_per_attention_head, + self.kv_projection_size, self.qkv_linear_weight, + self.qkv_linear_bias, self.dropout, self.num_chunks_attn, + cpu_offloading) else: output = _FPDTGPUOffloadingAttentionImpl_.apply( - layernorm_output, - attention_mask, - inference_params, - rotary_pos_emb, - self.spg, - self.scatter_idx, - self.gather_idx, - self.hidden_size, - self.projection_size, - self.hidden_size_per_attention_head, - self.kv_projection_size, - self.qkv_linear_weight, - self.qkv_linear_bias, - self.dropout, + layernorm_output, attention_mask, inference_params, rotary_pos_emb, self.spg, self.scatter_idx, + self.gather_idx, self.hidden_size, self.projection_size, self.hidden_size_per_attention_head, + self.kv_projection_size, self.qkv_linear_weight, self.qkv_linear_bias, self.dropout, self.num_chunks_attn, cpu_offloading) output = output.flatten(2).permute(1, 0, 2).contiguous() - + output = torch.matmul(output, self.qkv_dense_weight.t()) if not self.reture_bias: output += self.qkv_dense_bias @@ -935,14 +947,15 @@ def forward(self, @torch.jit.script def bias_gelu(x): - return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + @torch.jit.script def bias_gelu_back(g, x): tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) - return ff*g + return ff * g class FPDT_FFN(torch.autograd.Function): @@ -979,7 +992,7 @@ def 
forward(ctx: Any, x, w1, b1, w2, b2, add_bias, chunk_size): return result.to(x.dtype), b2 if not add_bias else None @staticmethod - def backward(ctx, grad_output, grad_bias): + def backward(ctx, grad_output, grad_bias): x, w1, b1, w2, b2 = ctx.saved_tensors device = ctx.device dtype = ctx.dtype @@ -988,7 +1001,7 @@ def backward(ctx, grad_output, grad_bias): num_chunk = ctx.num_chunk chunk_size = x.shape[0] // num_chunk assert chunk_size * num_chunk == grad_output.shape[0] - + grad_w2 = torch.zeros(w2.shape, device=device, dtype=torch.float) grad_b2 = torch.zeros(b2.shape, device=device, dtype=torch.float) grad_w1 = torch.zeros(w1.shape, device=device, dtype=torch.float) @@ -1000,22 +1013,26 @@ def backward(ctx, grad_output, grad_bias): x_chunk = x[st:ed] before_act = (torch.matmul(x_chunk, w1.t()) + b1) - before_act_2 = before_act ** 2 + before_act_2 = before_act**2 tanh_out = torch.tanh(0.79788456 * before_act * (1 + 0.044715 * before_act_2)) - ff = 0.5 * before_act * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * before_act_2)) + 0.5 * (1 + tanh_out) - grad_w2.add_(torch.matmul(grad_output[st:ed].reshape(-1, grad_output.shape[2]).t(), (before_act * 0.5 * (1 + tanh_out)).reshape(-1, before_act.shape[2]))) + ff = 0.5 * before_act * ((1 - tanh_out * tanh_out) * + (0.79788456 + 0.1070322243 * before_act_2)) + 0.5 * (1 + tanh_out) + grad_w2.add_( + torch.matmul(grad_output[st:ed].reshape(-1, grad_output.shape[2]).t(), + (before_act * 0.5 * (1 + tanh_out)).reshape(-1, before_act.shape[2]))) del before_act, before_act_2, tanh_out grad_inter = torch.matmul(grad_output[st:ed], w2) * ff del ff - - grad_w1.add_(torch.matmul(grad_inter.reshape(-1, grad_inter.shape[2]).t(), x_chunk.reshape(-1, x.shape[2]))) + + grad_w1.add_(torch.matmul( + grad_inter.reshape(-1, grad_inter.shape[2]).t(), x_chunk.reshape(-1, x.shape[2]))) grad_b1.add_(grad_inter.sum(0).sum(0)) x[st:ed].copy_(torch.matmul(grad_inter, w1)) del grad_inter - + if add_bias: grad_b2.add_(grad_output[st:ed].sum(0).sum(0)) @@ -1049,13 +1066,15 @@ def forward(ctx: Any, lm_output, labels, logit_weights, rank, spg_size, spg, num vocab_size = logits_chunk.size(2) # nll softmax = torch.nn.functional.softmax(logits_chunk, dim=-1) - loss_chunk = torch.nn.functional.nll_loss(softmax.log().reshape(-1, vocab_size).contiguous(), labels[st:ed, :].reshape(-1).contiguous(), reduction='none') + loss_chunk = torch.nn.functional.nll_loss(softmax.log().reshape(-1, vocab_size).contiguous(), + labels[st:ed, :].reshape(-1).contiguous(), + reduction='none') loss[:, st:ed] = loss_chunk.reshape(chunk_size, batch_size).t() del logits_chunk ctx.save_for_backward(lm_output.to('cpu'), labels) ctx.logit_weights = logit_weights - + seqlen = local_seq_len * spg_size batch_size = loss.size(0) loss = loss.t().contiguous() @@ -1065,11 +1084,11 @@ def forward(ctx: Any, lm_output, labels, logit_weights, rank, spg_size, spg, num torch.distributed.all_gather_into_tensor(loss_all, loss, group=spg) else: torch.distributed._all_gather_base(loss_all, loss, group=spg) - + return loss_all @staticmethod - def backward(ctx, grad_output): + def backward(ctx, grad_output): lm_output, labels = ctx.saved_tensors logit_weights = ctx.logit_weights device = ctx.device @@ -1080,7 +1099,7 @@ def backward(ctx, grad_output): rank = ctx.rank local_seq_len = ctx.local_seq_len - grad_output = grad_output[rank*local_seq_len:(rank+1)*local_seq_len] + grad_output = grad_output[rank * local_seq_len:(rank + 1) * local_seq_len] grad_lm_output = [None for _ in range(num_chunk)] grad_logit_weights 
= torch.zeros(logit_weights.shape, device=grad_output.device, dtype=torch.float) for i in range(num_chunk): @@ -1088,25 +1107,27 @@ def backward(ctx, grad_output): ed = st + chunk_size lm_output_chunk = lm_output[st:ed].to(device) logits_chunk = torch.matmul(lm_output_chunk, logit_weights.t()).float() - + # nll softmax = torch.nn.functional.softmax(logits_chunk, dim=-1) vocab_size = logits_chunk.size(2) grad_input = softmax grad_2d = grad_input.reshape(-1, vocab_size).contiguous() - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], - device=device) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=device) grad_2d[arange_1d, labels[st:ed, :].reshape(-1).contiguous()] -= 1 grad_input.mul_(grad_output[:chunk_size, :].unsqueeze(dim=-1)) grad_input = grad_input.to(dtype) - + grad_output = grad_output[chunk_size:].contiguous() grad_lm_output_chunk = torch.matmul(grad_input, logit_weights) grad_lm_output[i] = grad_lm_output_chunk - grad_logit_weights.add_(torch.matmul(grad_input.reshape(-1, grad_input.shape[2]).t(), lm_output_chunk.reshape(-1, lm_output_chunk.shape[2]))) - - return torch.cat(grad_lm_output, dim=0).to(dtype), None, grad_logit_weights.to(dtype), None, None, None, None \ No newline at end of file + grad_logit_weights.add_( + torch.matmul( + grad_input.reshape(-1, grad_input.shape[2]).t(), + lm_output_chunk.reshape(-1, lm_output_chunk.shape[2]))) + + return torch.cat(grad_lm_output, dim=0).to(dtype), None, grad_logit_weights.to(dtype), None, None, None, None diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index ff52f616b934..7e86582df595 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -206,7 +206,14 @@ def layer_sync(self, layer): if self.sp_overlap_comm and hasattr(layer, 'done_event'): self.dafult_stream.wait_event(layer.done_event) - def forward(self, query: Tensor, key: Tensor, value: Tensor, batch_dim_idx: int, rotary_pos_emb=None, *args: Any, **kwargs) -> Tensor: + def forward(self, + query: Tensor, + key: Tensor, + value: Tensor, + batch_dim_idx: int, + rotary_pos_emb=None, + *args: Any, + **kwargs) -> Tensor: """ forward Arguments: From 128286cafe255ea788911f1d39da33b18d522353 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Thu, 10 Oct 2024 19:00:08 -0400 Subject: [PATCH 10/54] fix format and add unit test for fpdt --- deepspeed/sequence/fpdt_layer.py | 15 ++--- .../unit/sequence_parallelism/test_ulysses.py | 57 +++++++++++++++++++ 2 files changed, 63 insertions(+), 9 deletions(-) diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py index 611d992d809d..a4efce2e62ce 100644 --- a/deepspeed/sequence/fpdt_layer.py +++ b/deepspeed/sequence/fpdt_layer.py @@ -11,7 +11,6 @@ import deepspeed.comm as dist from deepspeed.accelerator import get_accelerator -from packaging import version from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward from einops import rearrange from .layer import single_all_to_all, apply_rotary_pos_emb @@ -419,7 +418,8 @@ def __init__(self, chunk: torch.Tensor, device=None, is_in_use=False): self.device = chunk.device if device is None else device cpu_chunk = torch.empty(chunk.shape, dtype=chunk.dtype, device='cpu', pin_memory=True) - if chunk.is_cuda: + + if get_accelerator().on_accelerator(chunk): cpu_chunk.copy_(chunk, non_blocking=True) else: cpu_chunk = chunk @@ -552,7 +552,7 @@ def forward(ctx: Any, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd v_chunk = single_all_to_all(v_chunk, 
scatter_idx, gather_idx, 0, spg) - torch.distributed.barrier() + dist.barrier() pos_emb_cos_chunk = pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)] pos_emb_sin_chunk = pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)] @@ -776,7 +776,7 @@ def backward(ctx, grad_output): if grad_global_attn_output[next_q_compute_chunk_idx] is None: grad_global_attn_output_chunk = single_all_to_all(grad_output[:, :chunk_size].contiguous(), scatter_idx, gather_idx, 0, spg) - torch.distributed.barrier() + dist.barrier() grad_output = grad_output[:, chunk_size:] grad_global_attn_output[next_q_compute_chunk_idx] = SequenceChunk( grad_global_attn_output_chunk, is_in_use=True) @@ -821,7 +821,7 @@ def backward(ctx, grad_output): general_offload_stream.synchronize() compute_stream.wait_stream(general_offload_stream) - torch.distributed.barrier() + dist.barrier() with get_accelerator().stream(compute_stream): input_chunk = layernorm_output[i].get_gpu_chunk().reshape(-1, layernorm_output[i].chunk_shape[-1]) @@ -1080,10 +1080,7 @@ def forward(ctx: Any, lm_output, labels, logit_weights, rank, spg_size, spg, num loss = loss.t().contiguous() loss_all = torch.empty(seqlen, batch_size, dtype=loss.dtype, device=loss.device).contiguous() - if version.parse(torch.__version__) >= version.parse('1.13'): - torch.distributed.all_gather_into_tensor(loss_all, loss, group=spg) - else: - torch.distributed._all_gather_base(loss_all, loss, group=spg) + dist.allgather_fn(loss_all, loss, group=spg) return loss_all diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 915c89e0b00a..3a955726fb22 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -5,11 +5,14 @@ import pytest import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter import deepspeed.comm as dist from deepspeed import initialize from transformers import AutoModel from unit.common import DistributedTest from deepspeed.sequence.layer import _SeqAllToAll +from deepspeed.sequence.fpdt_layer import _FPDTGPUOffloadingAttentionImpl_ from unit.util import skip_on_arch @@ -75,3 +78,57 @@ def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_ # Check outputs are the same as input for i in range(1, len(outputs)): assert torch.allclose(input_tensor, outputs[i]), f"Outputs differ for sequence dim {seq_dims[i]}" + + +@pytest.mark.parametrize("d0", [1, 4]) #batch dimension +@pytest.mark.parametrize("d1", [2048, 4096]) #sequence dimension +@pytest.mark.parametrize("chunk_size", [512, 1024]) #size of chunk +@pytest.mark.parametrize("num_heads", [4, 8]) +@pytest.mark.parametrize("head_dim", [16, 32]) +class TestFPDTAttention(): + + def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, chunk_size: int, head_dim: int, + num_heads: int) -> None: + skip_on_arch(min_arch=8) + model = AutoModel.from_pretrained('bert-base-uncased') + ds_engine, _, _, _ = initialize( + model=model, + config_params={ + "train_batch_size": 8, + "data_parallel_size": 1, + "sequence_parallel_size": 1 + }, + ) + #3D tensor : l, b, d + dim = head_dim * num_heads + input_tensor = torch.randn(d1, d0, dim, device=ds_engine.device) + spg = ds_engine.seq_parallel_group + + qkv_linear_weight = Parameter(torch.empty(dim + 2 * dim, dim, device=ds_engine.device, dtype=torch.half)) + + qkv_linear_bias = Parameter(torch.empty(dim + 2 * dim, device=ds_engine.device, dtype=torch.half)) + + 
num_chunks_attn = input_tensor.shape[0] * dist.get_world_size(spg) // chunk_size + fpdt_output = _FPDTGPUOffloadingAttentionImpl_.apply(input_tensor, None, None, None, spg, 2, 0, dim, dim, + head_dim, dim, qkv_linear_weight, qkv_linear_bias, 0, + num_chunks_attn, True) + + # baseline + qkv = torch.matmul(input_tensor, qkv_linear_weight.t()) + qkv_linear_bias + q = qkv[:, :, :dim].contiguous().reshape(qkv.shape[0], qkv.shape[1], -1, head_dim).permute(1, 2, 0, + 3).contiguous() + k = qkv[:, :, dim:dim * 2].contiguous().reshape(qkv.shape[0], qkv.shape[1], -1, + head_dim).permute(1, 2, 0, 3).contiguous() + v = qkv[:, :, dim * 2:dim * 3].contiguous().reshape(qkv.shape[0], qkv.shape[1], -1, + head_dim).permute(1, 2, 0, + 3).contiguous() # b, nhead, l, d + + scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(dim, dtype=torch.half)) + + causal_mask = torch.triu(torch.ones(d1, d1), diagonal=1).bool() + causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(causal_mask, float('-inf')) + attn_weights = F.softmax(scores, dim=-1) + output = torch.matmul(attn_weights, v) + + assert torch.allclose(fpdt_output, output) From ebea5b03e802a46f601622f0f3ac35de79434d37 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Thu, 10 Oct 2024 19:21:43 -0400 Subject: [PATCH 11/54] add einops --- requirements/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 296398f680cc..edad582307ac 100755 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,3 +8,4 @@ py-cpuinfo pydantic>=2.0.0 torch tqdm +einops \ No newline at end of file From 5c8eec8665c6ba1e314d12779fd5b722be41c5e9 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Thu, 10 Oct 2024 20:59:25 -0400 Subject: [PATCH 12/54] add flashattn --- requirements/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index edad582307ac..a042afaea8e5 100755 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,4 +8,5 @@ py-cpuinfo pydantic>=2.0.0 torch tqdm -einops \ No newline at end of file +einops +flash-attn \ No newline at end of file From 764a572772c0b35f21a57f780fd3cd81b86fb799 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Mon, 14 Oct 2024 09:48:03 -0400 Subject: [PATCH 13/54] add requirements for flash-attn in FPDT --- deepspeed/sequence/fpdt_layer.py | 10 +++++++++- requirements/requirements.txt | 3 +-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py index a4efce2e62ce..df69f4e7934e 100644 --- a/deepspeed/sequence/fpdt_layer.py +++ b/deepspeed/sequence/fpdt_layer.py @@ -11,7 +11,12 @@ import deepspeed.comm as dist from deepspeed.accelerator import get_accelerator -from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward +try: + from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward +except ImportError: + _flash_attn_forward = None + _flash_attn_backward = None + from einops import rearrange from .layer import single_all_to_all, apply_rotary_pos_emb @@ -894,6 +899,9 @@ def __init__(self, enable_offloading=True) -> None: super(FPDT_Attention, self).__init__() + if _flash_attn_forward is None or _flash_attn_backward is None: + raise ImportError("DeepSpeed FPDT requires flash-attn 2.6.3. 
Please install it with `pip install flash-attn --no-build-isolation`.") + self.spg = sequence_process_group self.scatter_idx = scatter_idx self.gather_idx = gather_idx diff --git a/requirements/requirements.txt b/requirements/requirements.txt index a042afaea8e5..edad582307ac 100755 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,5 +8,4 @@ py-cpuinfo pydantic>=2.0.0 torch tqdm -einops -flash-attn \ No newline at end of file +einops \ No newline at end of file From 534cb937792d0c9cc1ad14273e1d76a85d30c69a Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 5 Nov 2024 00:04:09 +0000 Subject: [PATCH 14/54] skip test when fa is unavailable --- deepspeed/sequence/fpdt_layer.py | 6 ++++-- tests/unit/sequence_parallelism/test_ulysses.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py index df69f4e7934e..f82ef688cd2d 100644 --- a/deepspeed/sequence/fpdt_layer.py +++ b/deepspeed/sequence/fpdt_layer.py @@ -900,8 +900,10 @@ def __init__(self, super(FPDT_Attention, self).__init__() if _flash_attn_forward is None or _flash_attn_backward is None: - raise ImportError("DeepSpeed FPDT requires flash-attn 2.6.3. Please install it with `pip install flash-attn --no-build-isolation`.") - + raise ImportError( + "DeepSpeed FPDT requires flash-attn 2.6.3. Please install it with `pip install flash-attn --no-build-isolation`." + ) + self.spg = sequence_process_group self.scatter_idx = scatter_idx self.gather_idx = gather_idx diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 3a955726fb22..5e4cd0f47115 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -90,6 +90,16 @@ class TestFPDTAttention(): def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, chunk_size: int, head_dim: int, num_heads: int) -> None: skip_on_arch(min_arch=8) + + try: + from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward + except ImportError: + _flash_attn_forward = None + _flash_attn_backward = None + + if _flash_attn_forward is None or _flash_attn_backward is None: + pytest.skip("Flash Attention is not available.") + model = AutoModel.from_pretrained('bert-base-uncased') ds_engine, _, _, _ = initialize( model=model, From 972ddda9dd7078e9f6d0d8690ebec3971e082e5a Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 5 Nov 2024 00:05:37 +0000 Subject: [PATCH 15/54] formatting --- deepspeed/sequence/fpdt_layer.py | 4 ++-- requirements/requirements.txt | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py index f82ef688cd2d..f6c812e14aff 100644 --- a/deepspeed/sequence/fpdt_layer.py +++ b/deepspeed/sequence/fpdt_layer.py @@ -329,13 +329,13 @@ def backward(ctx, grad_output): q_chunk = global_q[q_i] attn_output_chunk = attn_output[q_i] lse_chunk = lse[q_i] - dout = grad_global_attn_output[q_i] + d_out = grad_global_attn_output[q_i] dq_this = torch.zeros(global_q[0].shape, dtype=dtype, device=device) dk_this = torch.zeros(global_k[0].shape, dtype=dtype, device=device) dv_this = torch.zeros(global_v[0].shape, dtype=dtype, device=device) - _flash_attn_backward(dout, + _flash_attn_backward(d_out, q_chunk, k_chunk, v_chunk, diff --git a/requirements/requirements.txt b/requirements/requirements.txt index edad582307ac..1af4c69c5807 100755 --- 
a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,3 +1,4 @@ +einops hjson msgpack ninja @@ -8,4 +9,3 @@ py-cpuinfo pydantic>=2.0.0 torch tqdm -einops \ No newline at end of file From 37bc6942fbf0523dd020348c45c2a1394a94cd2f Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 5 Nov 2024 00:32:22 +0000 Subject: [PATCH 16/54] add workflow to run a6000 tests --- .github/workflows/nv-flash-attn.yml | 58 +++++++++++++++++++ requirements/requirements-flash_attn.txt | 1 + setup.py | 1 + .../unit/sequence_parallelism/test_ulysses.py | 9 +-- 4 files changed, 65 insertions(+), 4 deletions(-) create mode 100644 .github/workflows/nv-flash-attn.yml create mode 100755 requirements/requirements-flash_attn.txt diff --git a/.github/workflows/nv-flash-attn.yml b/.github/workflows/nv-flash-attn.yml new file mode 100644 index 000000000000..08e57ea6f668 --- /dev/null +++ b/.github/workflows/nv-flash-attn.yml @@ -0,0 +1,58 @@ +name: nv-flash-attn + +on: + workflow_dispatch: + pull_request: + paths: + - 'deepspeed/sequence/**' + - '.github/workflows/nv-flash-attn.yml' + merge_group: + branches: [ master ] +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + unit-tests: + runs-on: [self-hosted, nvidia, a6000] + container: + image: nvcr.io/nvidia/pytorch:23.03-py3 + ports: + - 80 + options: --gpus all --shm-size "8G" + + steps: + - uses: actions/checkout@v4 + + - id: setup-venv + uses: ./.github/workflows/setup-venv + + - name: Install pytorch + run: | + pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/cu121 + python -c "import torch; print('torch:', torch.__version__, torch)" + python -c "import torch; print('CUDA available:', torch.cuda.is_available())" + + - name: Install transformers + run: | + git clone https://github.com/huggingface/transformers + cd transformers + # if needed switch to the last known good SHA until transformers@master is fixed + # git checkout 1cc453d33 + git rev-parse --short HEAD + pip install . + + - name: Install deepspeed + run: | + pip install .[dev,flash_attn] + ds_report + + - name: Python environment + run: | + pip list + + - name: Unit tests + run: | + unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch + cd tests + pytest $PYTEST_OPTS --forked -n 4 unit/sequence_parallelism/test_ulysses.py --torch_ver="2.5" --cuda_ver="12.1" diff --git a/requirements/requirements-flash_attn.txt b/requirements/requirements-flash_attn.txt new file mode 100755 index 000000000000..23b905b5fc30 --- /dev/null +++ b/requirements/requirements-flash_attn.txt @@ -0,0 +1 @@ +pip install flash-attn --no-build-isolation diff --git a/setup.py b/setup.py index e39d8c7e05a3..8c0182273482 100755 --- a/setup.py +++ b/setup.py @@ -91,6 +91,7 @@ def get_env_if_set(key, default: typing.Any = ""): 'inf': fetch_requirements('requirements/requirements-inf.txt'), 'sd': fetch_requirements('requirements/requirements-sd.txt'), 'triton': fetch_requirements('requirements/requirements-triton.txt'), + 'flash_attn': fetch_requirements('requirements/requirements-flash_attn.txt'), } # Only install pynvml on nvidia gpus. 
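
For context on the FPDT consistency test touched in this patch: it checks the chunked, offloaded attention output against a dense causal-attention baseline computed with plain PyTorch. The sketch below is illustrative only and not part of the patch; the helper name and the tolerance values are assumptions. Half-precision attention kernels rarely match a dense reference exactly, so such comparisons normally pass explicit atol/rtol to torch.allclose.

import torch
import torch.nn.functional as F

def dense_causal_attention(q, k, v, softmax_scale):
    # q, k, v: [batch, num_heads, seq_len, head_dim]
    scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
    seq_len = q.shape[-2]
    # mask out positions j > i to enforce causality
    causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=q.device), diagonal=1).bool()
    scores = scores.masked_fill(causal_mask, float('-inf'))
    attn_weights = F.softmax(scores.float(), dim=-1).to(q.dtype)
    return torch.matmul(attn_weights, v)

# Example comparison against a chunked implementation (tolerance values are assumptions):
# ref = dense_causal_attention(q, k, v, head_dim ** -0.5)
# assert torch.allclose(chunked_out, ref, atol=1e-2, rtol=1e-2)
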
diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 5e4cd0f47115..96616355b559 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -18,12 +18,12 @@ #Use mesh device to create data and sequence parallel group class TestUlyssesUtils(DistributedTest): - world_size = 4 + world_size = 2 def test_mesh_device_creation(self) -> None: skip_on_arch(min_arch=8) model = AutoModel.from_pretrained('bert-base-uncased') - sp_size = 2 + sp_size = 1 dp_size = 2 ds_engine, _, _, _ = initialize( model=model, @@ -46,7 +46,7 @@ def test_mesh_device_creation(self) -> None: @pytest.mark.parametrize("num_heads", [4, 8]) @pytest.mark.parametrize("head_dim", [16, 32]) class TestUlyssesAll2All(DistributedTest): - world_size = 4 + world_size = 2 def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_heads: int) -> None: skip_on_arch(min_arch=8) @@ -85,7 +85,8 @@ def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_ @pytest.mark.parametrize("chunk_size", [512, 1024]) #size of chunk @pytest.mark.parametrize("num_heads", [4, 8]) @pytest.mark.parametrize("head_dim", [16, 32]) -class TestFPDTAttention(): +class TestFPDTAttention(DistributedTest): + world_size = 1 def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, chunk_size: int, head_dim: int, num_heads: int) -> None: From ac7baf676843c1a75da1c376172f461f471ee337 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 5 Nov 2024 18:07:30 +0000 Subject: [PATCH 17/54] revert world sizes for tests --- tests/unit/sequence_parallelism/test_ulysses.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index d15dea6d7ae3..183cef4f0ec8 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -21,12 +21,12 @@ class TestUlyssesUtils(DistributedTest): - world_size = 2 + world_size = 4 def test_mesh_device_creation(self) -> None: skip_on_arch(min_arch=8) model = AutoModel.from_pretrained('bert-base-uncased') - sp_size = 1 + sp_size = 2 dp_size = 2 ds_engine, _, _, _ = initialize( model=model, @@ -49,7 +49,7 @@ def test_mesh_device_creation(self) -> None: @pytest.mark.parametrize("num_heads", [4, 8]) @pytest.mark.parametrize("head_dim", [16, 32]) class TestUlyssesAll2All(DistributedTest): - world_size = 2 + world_size = 4 def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_heads: int) -> None: skip_on_arch(min_arch=8) From 893552945f66345241a3531824938985771d49ab Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 5 Nov 2024 18:49:02 +0000 Subject: [PATCH 18/54] update workflow --- .github/workflows/nv-flash-attn.yml | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/.github/workflows/nv-flash-attn.yml b/.github/workflows/nv-flash-attn.yml index 08e57ea6f668..a689f64d9430 100644 --- a/.github/workflows/nv-flash-attn.yml +++ b/.github/workflows/nv-flash-attn.yml @@ -6,8 +6,6 @@ on: paths: - 'deepspeed/sequence/**' - '.github/workflows/nv-flash-attn.yml' - merge_group: - branches: [ master ] concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true @@ -24,35 +22,28 @@ jobs: steps: - uses: actions/checkout@v4 - - id: setup-venv - uses: ./.github/workflows/setup-venv - - - name: Install pytorch + - 
name: Check container state run: | - pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/cu121 + ldd --version + nvcc --version + nvidia-smi python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" - - name: Install transformers run: | - git clone https://github.com/huggingface/transformers + git clone --depth=1 https://github.com/huggingface/transformers cd transformers - # if needed switch to the last known good SHA until transformers@master is fixed - # git checkout 1cc453d33 git rev-parse --short HEAD - pip install . - + python -m pip install . - name: Install deepspeed run: | - pip install .[dev,flash_attn] + python -m pip install .[dev,flash_attn] ds_report - - name: Python environment run: | - pip list - + python -m pip list - name: Unit tests run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - pytest $PYTEST_OPTS --forked -n 4 unit/sequence_parallelism/test_ulysses.py --torch_ver="2.5" --cuda_ver="12.1" + python -m pytest --color=yes --durations=0 --verbose -rF unit/sequence_parallelism/test_ulysses.py --torch_ver="2.3" --cuda_ver="12" From edd2e0514e33d7dbc27c3a1cba26720d85ee2a82 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 5 Nov 2024 19:22:18 +0000 Subject: [PATCH 19/54] update image version --- .github/workflows/nv-flash-attn.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/nv-flash-attn.yml b/.github/workflows/nv-flash-attn.yml index a689f64d9430..1bad45fd93f3 100644 --- a/.github/workflows/nv-flash-attn.yml +++ b/.github/workflows/nv-flash-attn.yml @@ -14,7 +14,7 @@ jobs: unit-tests: runs-on: [self-hosted, nvidia, a6000] container: - image: nvcr.io/nvidia/pytorch:23.03-py3 + image: nvcr.io/nvidia/pytorch:24.03-py3 ports: - 80 options: --gpus all --shm-size "8G" From 464d117fe90660ee5db72fc82970e399581c11ae Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 5 Nov 2024 19:23:14 +0000 Subject: [PATCH 20/54] remove --no-build-isolation --- requirements/requirements-flash_attn.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements-flash_attn.txt b/requirements/requirements-flash_attn.txt index 23b905b5fc30..0d623275746c 100755 --- a/requirements/requirements-flash_attn.txt +++ b/requirements/requirements-flash_attn.txt @@ -1 +1 @@ -pip install flash-attn --no-build-isolation +pip install flash-attn From 7389f6635746cfeea01b0e8a7c9203fddd1ef9c7 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 5 Nov 2024 19:28:49 +0000 Subject: [PATCH 21/54] remove requirements file for flash-attn --- .github/workflows/nv-flash-attn.yml | 5 ++++- requirements/requirements-flash_attn.txt | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) delete mode 100755 requirements/requirements-flash_attn.txt diff --git a/.github/workflows/nv-flash-attn.yml b/.github/workflows/nv-flash-attn.yml index 1bad45fd93f3..de06ae885d75 100644 --- a/.github/workflows/nv-flash-attn.yml +++ b/.github/workflows/nv-flash-attn.yml @@ -37,8 +37,11 @@ jobs: python -m pip install . 
- name: Install deepspeed run: | - python -m pip install .[dev,flash_attn] + python -m pip install .[dev] ds_report + - name: Install FlashAttention + run: | + python -m pip flash-attn - name: Python environment run: | python -m pip list diff --git a/requirements/requirements-flash_attn.txt b/requirements/requirements-flash_attn.txt deleted file mode 100755 index 0d623275746c..000000000000 --- a/requirements/requirements-flash_attn.txt +++ /dev/null @@ -1 +0,0 @@ -pip install flash-attn From 5f859be985b7744e035a56d5852b68c55fca3db8 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 5 Nov 2024 19:31:38 +0000 Subject: [PATCH 22/54] remove flash-attn requirements from setup.py --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index 8c0182273482..e39d8c7e05a3 100755 --- a/setup.py +++ b/setup.py @@ -91,7 +91,6 @@ def get_env_if_set(key, default: typing.Any = ""): 'inf': fetch_requirements('requirements/requirements-inf.txt'), 'sd': fetch_requirements('requirements/requirements-sd.txt'), 'triton': fetch_requirements('requirements/requirements-triton.txt'), - 'flash_attn': fetch_requirements('requirements/requirements-flash_attn.txt'), } # Only install pynvml on nvidia gpus. From 56cb64761f5993f132f5ba7b0e0098af5e654f81 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 5 Nov 2024 19:33:51 +0000 Subject: [PATCH 23/54] fix pip command --- .github/workflows/nv-flash-attn.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/nv-flash-attn.yml b/.github/workflows/nv-flash-attn.yml index de06ae885d75..a7455ff4e5d8 100644 --- a/.github/workflows/nv-flash-attn.yml +++ b/.github/workflows/nv-flash-attn.yml @@ -41,7 +41,7 @@ jobs: ds_report - name: Install FlashAttention run: | - python -m pip flash-attn + python -m pip install flash-attn - name: Python environment run: | python -m pip list From 164f459ec589d89f961368cad3cd40afb9bf626e Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Tue, 5 Nov 2024 14:58:42 -0500 Subject: [PATCH 24/54] modify unit test for fpdt --- deepspeed/sequence/fpdt_layer.py | 36 ++++++++++++------- .../unit/sequence_parallelism/test_ulysses.py | 2 +- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py index f6c812e14aff..fd925e7b1234 100644 --- a/deepspeed/sequence/fpdt_layer.py +++ b/deepspeed/sequence/fpdt_layer.py @@ -486,9 +486,13 @@ def forward(ctx: Any, do_save = layernorm_output.requires_grad - pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3) - ctx.pos_emb_cos = pos_emb_cos - ctx.pos_emb_sin = pos_emb_sin + if rotary_pos_emb is not None: + pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3) + ctx.pos_emb_cos = pos_emb_cos + ctx.pos_emb_sin = pos_emb_sin + else: + ctx.pos_emb_cos = None + ctx.pos_emb_sin = None with torch.no_grad(): per_gpu_seq_len = layernorm_output.shape[0] chunk_size = per_gpu_seq_len // num_chunks @@ -559,11 +563,12 @@ def forward(ctx: Any, dist.barrier() - pos_emb_cos_chunk = pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)] - pos_emb_sin_chunk = pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)] + if ctx.pos_emb_cos is not None: + pos_emb_cos_chunk = pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)] + pos_emb_sin_chunk = pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)] - q_chunk = apply_rotary_pos_emb(q_chunk, 
pos_emb_cos_chunk, pos_emb_sin_chunk) - k_chunk = apply_rotary_pos_emb(k_chunk, pos_emb_cos_chunk, pos_emb_sin_chunk) + q_chunk = apply_rotary_pos_emb(q_chunk, pos_emb_cos_chunk, pos_emb_sin_chunk) + k_chunk = apply_rotary_pos_emb(k_chunk, pos_emb_cos_chunk, pos_emb_sin_chunk) compute_stream.wait_stream(offload_stream) compute_stream.synchronize() @@ -812,12 +817,17 @@ def backward(ctx, grad_output): compute_stream.synchronize() dk_seq_len = dk_accum.shape[1] - dq_accum = apply_rotary_pos_emb_backward(dq[kv_compute_chunk_idx].get_gpu_chunk().to(dtype), - ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], - ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) - dk_accum = apply_rotary_pos_emb_backward(dk_accum.to(dtype), - ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], - ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) + + if ctx.pos_emb_cos is not None: + dq_accum = apply_rotary_pos_emb_backward(dq[kv_compute_chunk_idx].get_gpu_chunk().to(dtype), + ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], + ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) + dk_accum = apply_rotary_pos_emb_backward(dk_accum.to(dtype), + ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], + ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) + else: + dq_accum = dq[kv_compute_chunk_idx].get_gpu_chunk().to(dtype) + dk_accum = dk_accum.to(dtype) dv_accum = dv_accum.to(dtype) dq_accum = single_all_to_all(dq_accum.contiguous(), gather_idx, scatter_idx, 0, spg) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 183cef4f0ec8..84d005436a9e 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -195,7 +195,7 @@ def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, ch #3D tensor : l, b, d dim = head_dim * num_heads input_tensor = torch.randn(d1, d0, dim, device=ds_engine.device) - spg = ds_engine.seq_parallel_group + spg = ds_engine.data_parallel_group qkv_linear_weight = Parameter(torch.empty(dim + 2 * dim, dim, device=ds_engine.device, dtype=torch.half)) From 3eb816dbb520441090e6575ec7aaed1d5afa684c Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Tue, 5 Nov 2024 15:01:24 -0500 Subject: [PATCH 25/54] modify unit test for fpdt --- tests/unit/sequence_parallelism/test_ulysses.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 84d005436a9e..2834e10676c2 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -223,5 +223,7 @@ def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, ch scores = scores.masked_fill(causal_mask, float('-inf')) attn_weights = F.softmax(scores, dim=-1) output = torch.matmul(attn_weights, v) - - assert torch.allclose(fpdt_output, output) + + if not torch.allclose(fpdt_output, output): + max_abs_diff = torch.max(torch.abs(tensor_a - tensor_b)) + print("Max absolute difference:", max_abs_diff.item()) From 2ae68dc1f8f9d71ad988c8f4dc3a25b1962b7ce2 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Tue, 5 Nov 2024 15:08:34 -0500 Subject: [PATCH 26/54] modify unit test for fpdt --- deepspeed/sequence/fpdt_layer.py | 10 +++++----- tests/unit/sequence_parallelism/test_ulysses.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/deepspeed/sequence/fpdt_layer.py 
b/deepspeed/sequence/fpdt_layer.py index fd925e7b1234..4c4f0ea2c6c4 100644 --- a/deepspeed/sequence/fpdt_layer.py +++ b/deepspeed/sequence/fpdt_layer.py @@ -817,14 +817,14 @@ def backward(ctx, grad_output): compute_stream.synchronize() dk_seq_len = dk_accum.shape[1] - + if ctx.pos_emb_cos is not None: dq_accum = apply_rotary_pos_emb_backward(dq[kv_compute_chunk_idx].get_gpu_chunk().to(dtype), - ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], - ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) + ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], + ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) dk_accum = apply_rotary_pos_emb_backward(dk_accum.to(dtype), - ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], - ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) + ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], + ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) else: dq_accum = dq[kv_compute_chunk_idx].get_gpu_chunk().to(dtype) dk_accum = dk_accum.to(dtype) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 2834e10676c2..7d40ca178554 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -194,7 +194,7 @@ def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, ch ) #3D tensor : l, b, d dim = head_dim * num_heads - input_tensor = torch.randn(d1, d0, dim, device=ds_engine.device) + input_tensor = torch.randn(d1, d0, dim, device=ds_engine.device, dtype=torch.half) spg = ds_engine.data_parallel_group qkv_linear_weight = Parameter(torch.empty(dim + 2 * dim, dim, device=ds_engine.device, dtype=torch.half)) @@ -223,7 +223,7 @@ def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, ch scores = scores.masked_fill(causal_mask, float('-inf')) attn_weights = F.softmax(scores, dim=-1) output = torch.matmul(attn_weights, v) - + if not torch.allclose(fpdt_output, output): max_abs_diff = torch.max(torch.abs(tensor_a - tensor_b)) print("Max absolute difference:", max_abs_diff.item()) From b1b2688073b544e04a6a0d9453ab418acefad137 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Tue, 5 Nov 2024 15:11:36 -0500 Subject: [PATCH 27/54] modify unit test for fpdt --- tests/unit/sequence_parallelism/test_ulysses.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 7d40ca178554..b2dd7d2edf46 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -162,11 +162,11 @@ def seq_batch_heads_hash(d0, d1, h, offset_d0=0, offset_d1=0, offset_h=0): outputs[i]), f"[{dist.get_rank()}]Outputs differ for sequence dim {seq_dims[i]}" -@pytest.mark.parametrize("d0", [1, 4]) #batch dimension -@pytest.mark.parametrize("d1", [2048, 4096]) #sequence dimension -@pytest.mark.parametrize("chunk_size", [512, 1024]) #size of chunk -@pytest.mark.parametrize("num_heads", [4, 8]) -@pytest.mark.parametrize("head_dim", [16, 32]) +@pytest.mark.parametrize("d0", [4]) #batch dimension +@pytest.mark.parametrize("d1", [2048]) #sequence dimension +@pytest.mark.parametrize("chunk_size", [512]) #size of chunk +@pytest.mark.parametrize("num_heads", [8]) +@pytest.mark.parametrize("head_dim", [32]) class TestFPDTAttention(DistributedTest): world_size = 1 From 67aa3dfbfd99d49465940a3ceaf0044f3097eccc Mon Sep 17 00:00:00 2001 From: Jinghan Yao 
Date: Tue, 5 Nov 2024 15:13:02 -0500 Subject: [PATCH 28/54] modify unit test for fpdt --- tests/unit/sequence_parallelism/test_ulysses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index b2dd7d2edf46..b6f3ec83d98c 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -225,5 +225,5 @@ def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, ch output = torch.matmul(attn_weights, v) if not torch.allclose(fpdt_output, output): - max_abs_diff = torch.max(torch.abs(tensor_a - tensor_b)) + max_abs_diff = torch.max(torch.abs(fpdt_output - output)) print("Max absolute difference:", max_abs_diff.item()) From 42461d27a16b9aab715f51f2e4b8ed28aa447e69 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Tue, 5 Nov 2024 15:21:29 -0500 Subject: [PATCH 29/54] modify unit test for fpdt --- deepspeed/sequence/fpdt_layer.py | 86 +++++++++++++++++++++----------- 1 file changed, 58 insertions(+), 28 deletions(-) diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py index 4c4f0ea2c6c4..5b9388336000 100644 --- a/deepspeed/sequence/fpdt_layer.py +++ b/deepspeed/sequence/fpdt_layer.py @@ -584,17 +584,29 @@ def forward(ctx: Any, for k_i in range(len(global_k)): causal_chunk = i == k_i with get_accelerator().stream(compute_stream): - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( - global_q[q_compute_chunk_idx].get_gpu_chunk(), - global_k[kv_compute_chunk_idx].get_gpu_chunk(), - global_v[kv_compute_chunk_idx].get_gpu_chunk(), - ctx.dropout_p, - ctx.softmax_scale, - causal=causal_chunk, - window_size=ctx.window_size, - softcap=0.0, - alibi_slopes=ctx.alibi_slopes, - return_softmax=False) + try: + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( + global_q[q_compute_chunk_idx].get_gpu_chunk(), + global_k[kv_compute_chunk_idx].get_gpu_chunk(), + global_v[kv_compute_chunk_idx].get_gpu_chunk(), + ctx.dropout_p, + ctx.softmax_scale, + causal=causal_chunk, + window_size=ctx.window_size, + softcap=0.0, + alibi_slopes=ctx.alibi_slopes, + return_softmax=False) + except TypeError: + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( + global_q[q_compute_chunk_idx].get_gpu_chunk(), + global_k[kv_compute_chunk_idx].get_gpu_chunk(), + global_v[kv_compute_chunk_idx].get_gpu_chunk(), + ctx.dropout_p, + ctx.softmax_scale, + causal=causal_chunk, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + return_softmax=False) cur_attn_output, cur_attn_lse = update_out_and_lse(cur_attn_output, cur_attn_lse, block_out, block_lse) @@ -743,23 +755,41 @@ def backward(ctx, grad_output): dv_this = torch.zeros(global_v[0].chunk_shape, dtype=dtype, device=device) with get_accelerator().stream(compute_stream): - _flash_attn_backward(grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk(), - global_q[q_compute_chunk_idx].get_gpu_chunk(), - global_k[kv_compute_chunk_idx].get_gpu_chunk(), - global_v[kv_compute_chunk_idx].get_gpu_chunk(), - attn_output[q_compute_chunk_idx].get_gpu_chunk(), - lse[q_compute_chunk_idx].get_gpu_chunk(), - dq_this, - dk_this, - dv_this, - dropout_p, - softmax_scale, - causal_chunk, - window_size, - softcap=0.0, - alibi_slopes=alibi_slopes, - deterministic=False, - rng_state=None) + try: + _flash_attn_backward(grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk(), + global_q[q_compute_chunk_idx].get_gpu_chunk(), + 
global_k[kv_compute_chunk_idx].get_gpu_chunk(), + global_v[kv_compute_chunk_idx].get_gpu_chunk(), + attn_output[q_compute_chunk_idx].get_gpu_chunk(), + lse[q_compute_chunk_idx].get_gpu_chunk(), + dq_this, + dk_this, + dv_this, + dropout_p, + softmax_scale, + causal_chunk, + window_size, + softcap=0.0, + alibi_slopes=alibi_slopes, + deterministic=False, + rng_state=None) + except TypeError: + _flash_attn_backward(grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk(), + global_q[q_compute_chunk_idx].get_gpu_chunk(), + global_k[kv_compute_chunk_idx].get_gpu_chunk(), + global_v[kv_compute_chunk_idx].get_gpu_chunk(), + attn_output[q_compute_chunk_idx].get_gpu_chunk(), + lse[q_compute_chunk_idx].get_gpu_chunk(), + dq_this, + dk_this, + dv_this, + dropout_p, + softmax_scale, + causal_chunk, + window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + rng_state=None) if i != (len(global_k) - 1): if q_i != (len(global_q) - 1): From d637d604cc389a203f69a0051576a2f427ae946b Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Tue, 5 Nov 2024 15:24:57 -0500 Subject: [PATCH 30/54] modify unit test for fpdt --- tests/unit/sequence_parallelism/test_ulysses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index b6f3ec83d98c..5feb56ffea7a 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -218,7 +218,7 @@ def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, ch scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(dim, dtype=torch.half)) - causal_mask = torch.triu(torch.ones(d1, d1), diagonal=1).bool() + causal_mask = torch.triu(torch.ones(d1, d1), diagonal=1).bool().cuda() causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) scores = scores.masked_fill(causal_mask, float('-inf')) attn_weights = F.softmax(scores, dim=-1) From 907c79d0559342f86757124d39514f3ca411affc Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Tue, 5 Nov 2024 15:32:04 -0500 Subject: [PATCH 31/54] modify unit test for fpdt --- tests/unit/sequence_parallelism/test_ulysses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 5feb56ffea7a..6bd89e392d6e 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -222,7 +222,7 @@ def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, ch causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) scores = scores.masked_fill(causal_mask, float('-inf')) attn_weights = F.softmax(scores, dim=-1) - output = torch.matmul(attn_weights, v) + output = torch.matmul(attn_weights, v).permute(0, 2, 1, 3) if not torch.allclose(fpdt_output, output): max_abs_diff = torch.max(torch.abs(fpdt_output - output)) From 8f5d039721e52fa20b0422b64dd3b4bbd5e4e645 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Tue, 5 Nov 2024 15:36:42 -0500 Subject: [PATCH 32/54] modify unit test for fpdt --- tests/unit/sequence_parallelism/test_ulysses.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 6bd89e392d6e..2bc895cd9080 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -218,12 +218,10 @@ def 
test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, ch scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(dim, dtype=torch.half)) - causal_mask = torch.triu(torch.ones(d1, d1), diagonal=1).bool().cuda() + causal_mask = torch.triu(torch.ones(d1, d1, device=ds_engine.device), diagonal=1).bool() causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) scores = scores.masked_fill(causal_mask, float('-inf')) attn_weights = F.softmax(scores, dim=-1) output = torch.matmul(attn_weights, v).permute(0, 2, 1, 3) - if not torch.allclose(fpdt_output, output): - max_abs_diff = torch.max(torch.abs(fpdt_output - output)) - print("Max absolute difference:", max_abs_diff.item()) + assert torch.allclose(fpdt_output, output), f"{torch.max(torch.abs(fpdt_output - output))}" From 02c2fbf6a559a6128379c444f67288f95d8283af Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Tue, 5 Nov 2024 15:42:00 -0500 Subject: [PATCH 33/54] modify unit test for fpdt --- tests/unit/sequence_parallelism/test_ulysses.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 2bc895cd9080..880a71d6b112 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -162,9 +162,9 @@ def seq_batch_heads_hash(d0, d1, h, offset_d0=0, offset_d1=0, offset_h=0): outputs[i]), f"[{dist.get_rank()}]Outputs differ for sequence dim {seq_dims[i]}" -@pytest.mark.parametrize("d0", [4]) #batch dimension -@pytest.mark.parametrize("d1", [2048]) #sequence dimension -@pytest.mark.parametrize("chunk_size", [512]) #size of chunk +@pytest.mark.parametrize("d0", [4, 1]) #batch dimension +@pytest.mark.parametrize("d1", [2048, 8192]) #sequence dimension +@pytest.mark.parametrize("chunk_size", [512, 2048]) #size of chunk @pytest.mark.parametrize("num_heads", [8]) @pytest.mark.parametrize("head_dim", [32]) class TestFPDTAttention(DistributedTest): From f570213ccdc0cb4fe71bbc353d2032f75fb38dfb Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Tue, 5 Nov 2024 15:49:42 -0500 Subject: [PATCH 34/54] modify unit test for fpdt --- deepspeed/sequence/fpdt_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py index 5b9388336000..578f7fcb88c8 100644 --- a/deepspeed/sequence/fpdt_layer.py +++ b/deepspeed/sequence/fpdt_layer.py @@ -434,7 +434,7 @@ def __init__(self, chunk: torch.Tensor, device=None, is_in_use=False): self.gpu_chunk = chunk if is_in_use else None def load_to_gpu(self): - assert self.gpu_chunk is None + # assert self.gpu_chunk is None if self.gpu_chunk is not None: pass else: From 5b8c419f6344ec9b1701a2128b5a49a304ce38ac Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Tue, 5 Nov 2024 15:59:20 -0500 Subject: [PATCH 35/54] add condition for using fpdt offloading --- deepspeed/sequence/fpdt_layer.py | 48 ++++++++++++------- .../unit/sequence_parallelism/test_ulysses.py | 2 +- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py index 578f7fcb88c8..612986535f5f 100644 --- a/deepspeed/sequence/fpdt_layer.py +++ b/deepspeed/sequence/fpdt_layer.py @@ -152,9 +152,13 @@ def forward(ctx: Any, do_save = layernorm_output.requires_grad - pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3) - ctx.pos_emb_cos = pos_emb_cos - ctx.pos_emb_sin = 
pos_emb_sin + if rotary_pos_emb is not None: + pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3) + ctx.pos_emb_cos = pos_emb_cos + ctx.pos_emb_sin = pos_emb_sin + else: + ctx.pos_emb_cos = None + ctx.pos_emb_sin = None with torch.no_grad(): per_gpu_seq_len = layernorm_output.shape[0] @@ -200,18 +204,20 @@ def forward(ctx: Any, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd q_chunk = single_all_to_all(q_chunk, scatter_idx, gather_idx, 0, spg) global_q_chunk_len = q_chunk.shape[1] - q_chunk = apply_rotary_pos_emb(q_chunk, - pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], - pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) + if rotary_pos_emb is not None: + q_chunk = apply_rotary_pos_emb(q_chunk, + pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], + pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) global_q.append(q_chunk) k_chunk = qkv_chunk[:, :, projection_size:projection_size + kv_projection_size].contiguous().reshape( qkv_chunk.shape[0], qkv_chunk.shape[1], -1, hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() # b, l, nh, hd k_chunk = single_all_to_all(k_chunk, scatter_idx, gather_idx, 0, spg) - k_chunk = apply_rotary_pos_emb(k_chunk, - pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], - pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) + if rotary_pos_emb is not None: + k_chunk = apply_rotary_pos_emb(k_chunk, + pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], + pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) global_k.append(k_chunk) v_chunk = qkv_chunk[:, :, projection_size + kv_projection_size:].contiguous().reshape( @@ -358,9 +364,13 @@ def backward(ctx, grad_output): dv[i].add_(dv_this.to(torch.float)) dk_seq_len = dk[i].shape[1] - dk[i] = apply_rotary_pos_emb_backward(dk[i].to(dtype), - ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], - ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) + + if ctx.pos_emb_cos is not None: + dk[i] = apply_rotary_pos_emb_backward(dk[i].to(dtype), + ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], + ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) + else: + dk[i] = dk[i].to(dtype) dv[i] = dv[i].to(dtype) dk[i] = single_all_to_all(dk[i].contiguous(), gather_idx, scatter_idx, 0, spg) dv[i] = single_all_to_all(dv[i].contiguous(), gather_idx, scatter_idx, 0, spg) @@ -390,10 +400,12 @@ def backward(ctx, grad_output): for i in range(num_chunks): dq_seq_len = dq[i].shape[1] - dq[i] = apply_rotary_pos_emb_backward(dq[i].to(dtype), - ctx.pos_emb_cos[:, dq_seq_len * i:dq_seq_len * (i + 1)], - ctx.pos_emb_sin[:, dq_seq_len * i:dq_seq_len * (i + 1)]) - + if ctx.pos_emb_cos is not None: + dq[i] = apply_rotary_pos_emb_backward(dq[i].to(dtype), + ctx.pos_emb_cos[:, dq_seq_len * i:dq_seq_len * (i + 1)], + ctx.pos_emb_sin[:, dq_seq_len * i:dq_seq_len * (i + 1)]) + else: + dq[i] = dq[i].to(dtype) dq[i] = single_all_to_all(dq[i].to(dtype).contiguous(), gather_idx, scatter_idx, 0, spg) input_chunk = layernorm_output[:input_chunk_size].reshape(-1, layernorm_output.shape[-1]) @@ -434,7 +446,7 @@ def __init__(self, chunk: torch.Tensor, device=None, is_in_use=False): self.gpu_chunk = chunk if is_in_use else None def load_to_gpu(self): - # assert self.gpu_chunk is None + assert self.gpu_chunk is None if self.gpu_chunk is not None: pass else: @@ -973,7 +985,7 @@ def forward(self, 
cpu_offloading=True) -> Tensor: self.num_chunks_attn = layernorm_output.shape[0] * dist.get_world_size(self.spg) // self.chunk_size - if not cpu_offloading: + if not cpu_offloading or self.num_chunks_attn == 1: output = _FPDTGPUAttentionImpl_.apply(layernorm_output, attention_mask, inference_params, rotary_pos_emb, self.spg, self.scatter_idx, self.gather_idx, self.hidden_size, self.projection_size, self.hidden_size_per_attention_head, diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 880a71d6b112..5a4b41f8636e 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -164,7 +164,7 @@ def seq_batch_heads_hash(d0, d1, h, offset_d0=0, offset_d1=0, offset_h=0): @pytest.mark.parametrize("d0", [4, 1]) #batch dimension @pytest.mark.parametrize("d1", [2048, 8192]) #sequence dimension -@pytest.mark.parametrize("chunk_size", [512, 2048]) #size of chunk +@pytest.mark.parametrize("chunk_size", [512, 1024]) #size of chunk @pytest.mark.parametrize("num_heads", [8]) @pytest.mark.parametrize("head_dim", [32]) class TestFPDTAttention(DistributedTest): From bd090c84a0a554f15cad5aafbc76f49272b82a4c Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Tue, 5 Nov 2024 16:00:10 -0500 Subject: [PATCH 36/54] add condition for using fpdt offloading --- deepspeed/sequence/fpdt_layer.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py index 612986535f5f..1f0cc3341c2e 100644 --- a/deepspeed/sequence/fpdt_layer.py +++ b/deepspeed/sequence/fpdt_layer.py @@ -206,8 +206,8 @@ def forward(ctx: Any, global_q_chunk_len = q_chunk.shape[1] if rotary_pos_emb is not None: q_chunk = apply_rotary_pos_emb(q_chunk, - pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], - pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) + pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], + pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) global_q.append(q_chunk) k_chunk = qkv_chunk[:, :, projection_size:projection_size + kv_projection_size].contiguous().reshape( @@ -216,8 +216,8 @@ def forward(ctx: Any, k_chunk = single_all_to_all(k_chunk, scatter_idx, gather_idx, 0, spg) if rotary_pos_emb is not None: k_chunk = apply_rotary_pos_emb(k_chunk, - pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], - pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) + pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], + pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) global_k.append(k_chunk) v_chunk = qkv_chunk[:, :, projection_size + kv_projection_size:].contiguous().reshape( @@ -364,11 +364,11 @@ def backward(ctx, grad_output): dv[i].add_(dv_this.to(torch.float)) dk_seq_len = dk[i].shape[1] - + if ctx.pos_emb_cos is not None: dk[i] = apply_rotary_pos_emb_backward(dk[i].to(dtype), - ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], - ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) + ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], + ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) else: dk[i] = dk[i].to(dtype) dv[i] = dv[i].to(dtype) @@ -402,8 +402,8 @@ def backward(ctx, grad_output): dq_seq_len = dq[i].shape[1] if ctx.pos_emb_cos is not None: dq[i] = apply_rotary_pos_emb_backward(dq[i].to(dtype), - ctx.pos_emb_cos[:, dq_seq_len * i:dq_seq_len * (i + 1)], - 
ctx.pos_emb_sin[:, dq_seq_len * i:dq_seq_len * (i + 1)]) + ctx.pos_emb_cos[:, dq_seq_len * i:dq_seq_len * (i + 1)], + ctx.pos_emb_sin[:, dq_seq_len * i:dq_seq_len * (i + 1)]) else: dq[i] = dq[i].to(dtype) dq[i] = single_all_to_all(dq[i].to(dtype).contiguous(), gather_idx, scatter_idx, 0, spg) From e48e85baea6389ac818a7de1af713fbd28b927fb Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Tue, 5 Nov 2024 16:11:08 -0500 Subject: [PATCH 37/54] add flash-attn version check --- deepspeed/sequence/fpdt_layer.py | 96 +++++++++++++++++++++----------- 1 file changed, 64 insertions(+), 32 deletions(-) diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py index 1f0cc3341c2e..5d574aa153b8 100644 --- a/deepspeed/sequence/fpdt_layer.py +++ b/deepspeed/sequence/fpdt_layer.py @@ -7,12 +7,14 @@ from typing import Optional, Any, Tuple from torch import Tensor - +from packaging import version import deepspeed.comm as dist from deepspeed.accelerator import get_accelerator try: + import flash_attn from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward + flash_attn_version = version.parse(flash_attn.__version__) except ImportError: _flash_attn_forward = None _flash_attn_backward = None @@ -228,16 +230,28 @@ def forward(ctx: Any, for k_i in range(len(global_k)): causal_chunk = i == k_i - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(global_q[i], - global_k[k_i], - global_v[k_i], - ctx.dropout_p, - ctx.softmax_scale, - causal=causal_chunk, - window_size=ctx.window_size, - softcap=0.0, - alibi_slopes=ctx.alibi_slopes, - return_softmax=False) + if flash_attn_version >= version.parse("2.6.0"): + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(global_q[i], + global_k[k_i], + global_v[k_i], + ctx.dropout_p, + ctx.softmax_scale, + causal=causal_chunk, + window_size=ctx.window_size, + softcap=0.0, + alibi_slopes=ctx.alibi_slopes, + return_softmax=False) + else: + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(global_q[i], + global_k[k_i], + global_v[k_i], + ctx.dropout_p, + ctx.softmax_scale, + causal=causal_chunk, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + return_softmax=False) + global_o[i], global_lse[i] = update_out_and_lse(global_o[i], global_lse[i], block_out, block_lse) global_o[i] = global_o[i].to(q_chunk.dtype) @@ -341,23 +355,41 @@ def backward(ctx, grad_output): dk_this = torch.zeros(global_k[0].shape, dtype=dtype, device=device) dv_this = torch.zeros(global_v[0].shape, dtype=dtype, device=device) - _flash_attn_backward(d_out, - q_chunk, - k_chunk, - v_chunk, - attn_output_chunk, - lse_chunk, - dq_this, - dk_this, - dv_this, - dropout_p, - softmax_scale, - causal_chunk, - window_size, - softcap=0.0, - alibi_slopes=alibi_slopes, - deterministic=False, - rng_state=None) + if flash_attn_version >= version.parse("2.6.0"): + _flash_attn_backward(d_out, + q_chunk, + k_chunk, + v_chunk, + attn_output_chunk, + lse_chunk, + dq_this, + dk_this, + dv_this, + dropout_p, + softmax_scale, + causal_chunk, + window_size, + softcap=0.0, + alibi_slopes=alibi_slopes, + deterministic=False, + rng_state=None) + else: + _flash_attn_backward(d_out, + q_chunk, + k_chunk, + v_chunk, + attn_output_chunk, + lse_chunk, + dq_this, + dk_this, + dv_this, + dropout_p, + softmax_scale, + causal_chunk, + window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + rng_state=None) dq[q_i].add_(dq_this.to(torch.float)) dk[i].add_(dk_this.to(torch.float)) @@ -596,7 +628,7 @@ def forward(ctx: Any, for k_i in 
range(len(global_k)): causal_chunk = i == k_i with get_accelerator().stream(compute_stream): - try: + if flash_attn_version >= version.parse("2.6.0"): block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( global_q[q_compute_chunk_idx].get_gpu_chunk(), global_k[kv_compute_chunk_idx].get_gpu_chunk(), @@ -608,7 +640,7 @@ def forward(ctx: Any, softcap=0.0, alibi_slopes=ctx.alibi_slopes, return_softmax=False) - except TypeError: + else: block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( global_q[q_compute_chunk_idx].get_gpu_chunk(), global_k[kv_compute_chunk_idx].get_gpu_chunk(), @@ -767,7 +799,7 @@ def backward(ctx, grad_output): dv_this = torch.zeros(global_v[0].chunk_shape, dtype=dtype, device=device) with get_accelerator().stream(compute_stream): - try: + if flash_attn_version >= version.parse("2.6.0"): _flash_attn_backward(grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk(), global_q[q_compute_chunk_idx].get_gpu_chunk(), global_k[kv_compute_chunk_idx].get_gpu_chunk(), @@ -785,7 +817,7 @@ def backward(ctx, grad_output): alibi_slopes=alibi_slopes, deterministic=False, rng_state=None) - except TypeError: + else: _flash_attn_backward(grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk(), global_q[q_compute_chunk_idx].get_gpu_chunk(), global_k[kv_compute_chunk_idx].get_gpu_chunk(), From ebaf56c9f5e1c55a4a42ffb1fd5795454c4bfb68 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 6 Nov 2024 17:34:51 +0000 Subject: [PATCH 38/54] add unit test directory as test trigger --- .github/workflows/nv-flash-attn.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/nv-flash-attn.yml b/.github/workflows/nv-flash-attn.yml index a7455ff4e5d8..8bff9d210e1e 100644 --- a/.github/workflows/nv-flash-attn.yml +++ b/.github/workflows/nv-flash-attn.yml @@ -5,6 +5,7 @@ on: pull_request: paths: - 'deepspeed/sequence/**' + - 'tests/unit/sequence_parallelism/**' - '.github/workflows/nv-flash-attn.yml' concurrency: group: ${{ github.workflow }}-${{ github.ref }} From 9e811b8da7fca89be4f2db5ba47c0a95f7a5457d Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 6 Nov 2024 17:42:36 +0000 Subject: [PATCH 39/54] add cron for test and reporting for nightly CI failures --- .github/workflows/nv-flash-attn.yml | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/.github/workflows/nv-flash-attn.yml b/.github/workflows/nv-flash-attn.yml index 8bff9d210e1e..310972323043 100644 --- a/.github/workflows/nv-flash-attn.yml +++ b/.github/workflows/nv-flash-attn.yml @@ -7,6 +7,9 @@ on: - 'deepspeed/sequence/**' - 'tests/unit/sequence_parallelism/**' - '.github/workflows/nv-flash-attn.yml' + schedule: + - cron: "0 0 * * *" + concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true @@ -51,3 +54,11 @@ jobs: unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests python -m pytest --color=yes --durations=0 --verbose -rF unit/sequence_parallelism/test_ulysses.py --torch_ver="2.3" --cuda_ver="12" + - name: Open GitHub issue if nightly CI fails + if: ${{ failure() && (github.event_name == 'schedule') }} + uses: JasonEtco/create-an-issue@v2 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + filename: .github/ISSUE_TEMPLATE/ci_failure_report.md + update_existing: true From a7522da82402ff78838ea671d24cafec0a938125 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Thu, 7 Nov 2024 15:48:19 -0500 Subject: [PATCH 40/54] add multiGPU fpdt unit test --- deepspeed/sequence/fpdt_layer.py | 9 ++-- 
.../unit/sequence_parallelism/test_ulysses.py | 45 ++++++++++++++----- 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py index 5d574aa153b8..4fab768ce63c 100644 --- a/deepspeed/sequence/fpdt_layer.py +++ b/deepspeed/sequence/fpdt_layer.py @@ -117,15 +117,16 @@ def generate(self): gather_indices = indices[:, 0] token_chunk_indices = indices[:, 1] indices = torch.cat([token_chunk_indices[gather_indices == i] for i in range(gather_chunk.shape[0])]) - load_balanced_loss_mask = self.loss_mask[:, indices] + load_balanced_loss_mask = self.loss_mask[:, indices] if self.loss_mask is not None else self.loss_mask indices = indices.reshape(-1, self.chunk_size)[self.num_chunk_per_gpu * self.sp_rank:self.num_chunk_per_gpu * (self.sp_rank + 1)].flatten().contiguous() load_balanced_tokens = self.tokens[:, indices] - load_balanced_labels = self.labels[:, indices] + load_balanced_labels = self.labels[:, indices] if self.labels is not None else self.labels - load_balanced_attention_mask = self.attention_mask if self.attention_mask is not None else None - load_balanced_position_ids = self.position_ids[:, indices] + load_balanced_attention_mask = self.attention_mask if self.attention_mask is not None else self.attention_mask + load_balanced_position_ids = self.position_ids[:, + indices] if self.position_ids is not None else self.position_ids return load_balanced_tokens, load_balanced_labels, load_balanced_loss_mask, load_balanced_attention_mask, load_balanced_position_ids diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 5a4b41f8636e..641c1958fa8f 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -6,13 +6,12 @@ import pytest import torch import torch.nn.functional as F -from torch.nn.parameter import Parameter import deepspeed.comm as dist from deepspeed import initialize from transformers import AutoModel from unit.common import DistributedTest from deepspeed.sequence.layer import _SeqAllToAll -from deepspeed.sequence.fpdt_layer import _FPDTGPUOffloadingAttentionImpl_ +from deepspeed.sequence.fpdt_layer import _FPDTGPUOffloadingAttentionImpl_, FPDT_InputConstruct from unit.util import skip_on_arch from unit.simple_model import * from deepspeed.utils import groups @@ -162,13 +161,13 @@ def seq_batch_heads_hash(d0, d1, h, offset_d0=0, offset_d1=0, offset_h=0): outputs[i]), f"[{dist.get_rank()}]Outputs differ for sequence dim {seq_dims[i]}" -@pytest.mark.parametrize("d0", [4, 1]) #batch dimension -@pytest.mark.parametrize("d1", [2048, 8192]) #sequence dimension -@pytest.mark.parametrize("chunk_size", [512, 1024]) #size of chunk +@pytest.mark.parametrize("d0", [4]) #batch dimension +@pytest.mark.parametrize("d1", [2048]) #sequence dimension +@pytest.mark.parametrize("chunk_size", [128]) #size of chunk @pytest.mark.parametrize("num_heads", [8]) @pytest.mark.parametrize("head_dim", [32]) class TestFPDTAttention(DistributedTest): - world_size = 1 + world_size = 4 def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, chunk_size: int, head_dim: int, num_heads: int) -> None: @@ -189,20 +188,39 @@ def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, ch config_params={ "train_batch_size": 8, "data_parallel_size": 1, - "sequence_parallel_size": 1 + "sequence_parallel_size": world_size }, ) #3D tensor : l, b, d dim = head_dim * num_heads + + seed = 42 + 
torch.manual_seed(seed) + get_accelerator().manual_seed_all(seed) + input_tensor = torch.randn(d1, d0, dim, device=ds_engine.device, dtype=torch.half) - spg = ds_engine.data_parallel_group + spg = ds_engine.seq_parallel_group - qkv_linear_weight = Parameter(torch.empty(dim + 2 * dim, dim, device=ds_engine.device, dtype=torch.half)) + fpdt_input_tensor = FPDT_InputConstruct(input_tensor.permute(1, 0, 2), None, None, None, None, None, + world_size, dist.get_rank()).generate().permute(1, 0, 2) - qkv_linear_bias = Parameter(torch.empty(dim + 2 * dim, device=ds_engine.device, dtype=torch.half)) + if rank == 0: + qkv_linear_weight = torch.nn.Parameter( + torch.empty(dim + 2 * dim, dim, device=dist.get_rank(), dtype=torch.half)) + torch.nn.init.normal_(qkv_linear_weight, mean=0.0, std=0.02) - num_chunks_attn = input_tensor.shape[0] * dist.get_world_size(spg) // chunk_size - fpdt_output = _FPDTGPUOffloadingAttentionImpl_.apply(input_tensor, None, None, None, spg, 2, 0, dim, dim, + qkv_linear_bias = torch.nn.Parameter(torch.empty(dim + 2 * dim, device=dist.get_rank(), dtype=torch.half)) + torch.nn.init.normal_(qkv_linear_bias, mean=0.0, std=0.02) + else: + qkv_linear_weight = torch.nn.Parameter( + torch.empty(dim + 2 * dim, dim, device=dist.get_rank(), dtype=torch.half)) + qkv_linear_bias = torch.nn.Parameter(torch.empty(dim + 2 * dim, device=dist.get_rank(), dtype=torch.half)) + + dist.broadcast(qkv_linear_weight, src=0, group=spg) + dist.broadcast(qkv_linear_bias, src=0, group=spg) + + num_chunks_attn = fpdt_input_tensor.shape[0] * dist.get_world_size(spg) // chunk_size + fpdt_output = _FPDTGPUOffloadingAttentionImpl_.apply(fpdt_input_tensor, None, None, None, spg, 2, 0, dim, dim, head_dim, dim, qkv_linear_weight, qkv_linear_bias, 0, num_chunks_attn, True) @@ -224,4 +242,7 @@ def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, ch attn_weights = F.softmax(scores, dim=-1) output = torch.matmul(attn_weights, v).permute(0, 2, 1, 3) + baseline_output_shuffled = FPDT_InputConstruct(output, None, None, None, None, None, world_size, + dist.get_rank()).generate() # b, l, n, d + assert torch.allclose(fpdt_output, output), f"{torch.max(torch.abs(fpdt_output - output))}" From 209adab411308afa58beca60ebe9d83f423fd5b2 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Thu, 7 Nov 2024 15:51:51 -0500 Subject: [PATCH 41/54] add multiGPU fpdt unit test --- tests/unit/sequence_parallelism/test_ulysses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 641c1958fa8f..4bccb2fa8ca1 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -204,7 +204,7 @@ def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, ch fpdt_input_tensor = FPDT_InputConstruct(input_tensor.permute(1, 0, 2), None, None, None, None, None, world_size, dist.get_rank()).generate().permute(1, 0, 2) - if rank == 0: + if dist.get_rank() == 0: qkv_linear_weight = torch.nn.Parameter( torch.empty(dim + 2 * dim, dim, device=dist.get_rank(), dtype=torch.half)) torch.nn.init.normal_(qkv_linear_weight, mean=0.0, std=0.02) From dbeea8a65d4cf4dba9411b0440b0e513186bfe03 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Thu, 7 Nov 2024 15:57:57 -0500 Subject: [PATCH 42/54] add multiGPU fpdt unit test --- tests/unit/sequence_parallelism/test_ulysses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 4bccb2fa8ca1..ee02db0a484c 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -167,7 +167,7 @@ def seq_batch_heads_hash(d0, d1, h, offset_d0=0, offset_d1=0, offset_h=0): @pytest.mark.parametrize("num_heads", [8]) @pytest.mark.parametrize("head_dim", [32]) class TestFPDTAttention(DistributedTest): - world_size = 4 + world_size = 2 def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, chunk_size: int, head_dim: int, num_heads: int) -> None: From 845e42d5fbc733935ff96333d84f2713a36e648b Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Thu, 7 Nov 2024 16:02:31 -0500 Subject: [PATCH 43/54] add multiGPU fpdt unit test --- tests/unit/sequence_parallelism/test_ulysses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index ee02db0a484c..236331741d5b 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -167,11 +167,11 @@ def seq_batch_heads_hash(d0, d1, h, offset_d0=0, offset_d1=0, offset_h=0): @pytest.mark.parametrize("num_heads", [8]) @pytest.mark.parametrize("head_dim", [32]) class TestFPDTAttention(DistributedTest): - world_size = 2 def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, chunk_size: int, head_dim: int, num_heads: int) -> None: skip_on_arch(min_arch=8) + world_size = 2 try: from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward From 8b2549c91118533120c76deaef2447ad40e564c1 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Thu, 7 Nov 2024 16:11:32 -0500 Subject: [PATCH 44/54] add multiGPU fpdt unit test --- tests/unit/sequence_parallelism/test_ulysses.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 236331741d5b..240909ea6f06 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -201,7 +201,12 @@ def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, ch input_tensor = torch.randn(d1, d0, dim, device=ds_engine.device, dtype=torch.half) spg = ds_engine.seq_parallel_group - fpdt_input_tensor = FPDT_InputConstruct(input_tensor.permute(1, 0, 2), None, None, None, None, None, + class args: + + def __init__(self): + self.ds_sequence_parallel_fpdt_chunk_size = chunk_size + + fpdt_input_tensor = FPDT_InputConstruct(input_tensor.permute(1, 0, 2), None, None, None, None, args(), world_size, dist.get_rank()).generate().permute(1, 0, 2) if dist.get_rank() == 0: From 058c973d8410acd348224bd67510fb1fab7640af Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Thu, 7 Nov 2024 16:13:57 -0500 Subject: [PATCH 45/54] add multiGPU fpdt unit test --- tests/unit/sequence_parallelism/test_ulysses.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 240909ea6f06..88fadadd7ce1 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -207,7 +207,7 @@ def __init__(self): self.ds_sequence_parallel_fpdt_chunk_size = chunk_size fpdt_input_tensor = FPDT_InputConstruct(input_tensor.permute(1, 0, 
2), None, None, None, None, args(), - world_size, dist.get_rank()).generate().permute(1, 0, 2) + world_size, dist.get_rank()).generate()[0].permute(1, 0, 2) if dist.get_rank() == 0: qkv_linear_weight = torch.nn.Parameter( @@ -248,6 +248,6 @@ def __init__(self): output = torch.matmul(attn_weights, v).permute(0, 2, 1, 3) baseline_output_shuffled = FPDT_InputConstruct(output, None, None, None, None, None, world_size, - dist.get_rank()).generate() # b, l, n, d + dist.get_rank()).generate()[0] # b, l, n, d assert torch.allclose(fpdt_output, output), f"{torch.max(torch.abs(fpdt_output - output))}" From 0dcc234d012c1dd88a5f6efa29a814870ff8ee0e Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Thu, 7 Nov 2024 16:16:00 -0500 Subject: [PATCH 46/54] add multiGPU fpdt unit test --- tests/unit/sequence_parallelism/test_ulysses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 88fadadd7ce1..aaaba638e6c7 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -247,7 +247,7 @@ def __init__(self): attn_weights = F.softmax(scores, dim=-1) output = torch.matmul(attn_weights, v).permute(0, 2, 1, 3) - baseline_output_shuffled = FPDT_InputConstruct(output, None, None, None, None, None, world_size, + baseline_output_shuffled = FPDT_InputConstruct(output, None, None, None, None, args(), world_size, dist.get_rank()).generate()[0] # b, l, n, d assert torch.allclose(fpdt_output, output), f"{torch.max(torch.abs(fpdt_output - output))}" From d1be5d356c52d9b5fcb877f93012c2ece0bbfa1e Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Thu, 7 Nov 2024 16:18:18 -0500 Subject: [PATCH 47/54] add multiGPU fpdt unit test --- tests/unit/sequence_parallelism/test_ulysses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index aaaba638e6c7..a7d746bdbb2f 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -250,4 +250,4 @@ def __init__(self): baseline_output_shuffled = FPDT_InputConstruct(output, None, None, None, None, args(), world_size, dist.get_rank()).generate()[0] # b, l, n, d - assert torch.allclose(fpdt_output, output), f"{torch.max(torch.abs(fpdt_output - output))}" + assert torch.allclose(fpdt_output, baseline_output_shuffled), f"{torch.max(torch.abs(fpdt_output - baseline_output_shuffled))}" From 3a0feba38dbc42495413e20076a551fe32f18b48 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Thu, 7 Nov 2024 16:26:31 -0500 Subject: [PATCH 48/54] add multiGPU fpdt unit test --- tests/unit/sequence_parallelism/test_ulysses.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index a7d746bdbb2f..701f9476bdbf 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -171,7 +171,7 @@ class TestFPDTAttention(DistributedTest): def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, chunk_size: int, head_dim: int, num_heads: int) -> None: skip_on_arch(min_arch=8) - world_size = 2 + world_size = 1 try: from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward @@ -199,7 +199,7 @@ def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, 
ch get_accelerator().manual_seed_all(seed) input_tensor = torch.randn(d1, d0, dim, device=ds_engine.device, dtype=torch.half) - spg = ds_engine.seq_parallel_group + spg = ds_engine.data_parallel_group class args: From 8c57812cc017e4f5d0172ebac338932bd6306df1 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Thu, 7 Nov 2024 16:37:07 -0500 Subject: [PATCH 49/54] add multiGPU fpdt unit test --- tests/unit/sequence_parallelism/test_ulysses.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 701f9476bdbf..224769643424 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -171,7 +171,7 @@ class TestFPDTAttention(DistributedTest): def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, chunk_size: int, head_dim: int, num_heads: int) -> None: skip_on_arch(min_arch=8) - world_size = 1 + world_size = 2 try: from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward @@ -199,7 +199,7 @@ def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, ch get_accelerator().manual_seed_all(seed) input_tensor = torch.randn(d1, d0, dim, device=ds_engine.device, dtype=torch.half) - spg = ds_engine.data_parallel_group + spg = ds_engine.seq_parallel_group class args: @@ -250,4 +250,4 @@ def __init__(self): baseline_output_shuffled = FPDT_InputConstruct(output, None, None, None, None, args(), world_size, dist.get_rank()).generate()[0] # b, l, n, d - assert torch.allclose(fpdt_output, baseline_output_shuffled), f"{torch.max(torch.abs(fpdt_output - baseline_output_shuffled))}" + assert torch.allclose(fpdt_output, baseline_output_shuffled), f"sp size: {dist.get_world_size(spg)}, input_tensor: {input_tensor.shape}, fpdt_input_tensor: {fpdt_input_tensor.shape}, fpdt_output: {fpdt_output.shape}, baseline_output_shuffled: {baseline_output_shuffled.shape},{torch.max(torch.abs(fpdt_output - baseline_output_shuffled))}" From 43decf6ffe0f5296afe172beedbde2fa82ebbba2 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Thu, 7 Nov 2024 19:37:10 -0500 Subject: [PATCH 50/54] add multiGPU fpdt unit test --- tests/unit/sequence_parallelism/test_ulysses.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 224769643424..f5576393d223 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -199,6 +199,7 @@ def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, ch get_accelerator().manual_seed_all(seed) input_tensor = torch.randn(d1, d0, dim, device=ds_engine.device, dtype=torch.half) + dist.broadcast(input_tensor, src=0, group=spg) spg = ds_engine.seq_parallel_group class args: @@ -250,4 +251,4 @@ def __init__(self): baseline_output_shuffled = FPDT_InputConstruct(output, None, None, None, None, args(), world_size, dist.get_rank()).generate()[0] # b, l, n, d - assert torch.allclose(fpdt_output, baseline_output_shuffled), f"sp size: {dist.get_world_size(spg)}, input_tensor: {input_tensor.shape}, fpdt_input_tensor: {fpdt_input_tensor.shape}, fpdt_output: {fpdt_output.shape}, baseline_output_shuffled: {baseline_output_shuffled.shape},{torch.max(torch.abs(fpdt_output - baseline_output_shuffled))}" + assert torch.allclose(fpdt_output, baseline_output_shuffled), f"rank {dist.get_rank()}, sp 
size: {dist.get_world_size(spg)}, input_tensor: {input_tensor.shape}, fpdt_input_tensor: {fpdt_input_tensor.shape}, fpdt_output: {fpdt_output.shape}, baseline_output_shuffled: {baseline_output_shuffled.shape},{torch.max(torch.abs(fpdt_output - baseline_output_shuffled))}" From d39585c6d0f2b6b428632800bacf4b6530292438 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Thu, 7 Nov 2024 19:40:27 -0500 Subject: [PATCH 51/54] add multiGPU fpdt unit test --- tests/unit/sequence_parallelism/test_ulysses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index f5576393d223..1b15b7754758 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -199,9 +199,9 @@ def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, ch get_accelerator().manual_seed_all(seed) input_tensor = torch.randn(d1, d0, dim, device=ds_engine.device, dtype=torch.half) - dist.broadcast(input_tensor, src=0, group=spg) spg = ds_engine.seq_parallel_group + dist.broadcast(input_tensor, src=0, group=spg) class args: def __init__(self): From 389b1a3b1f8a1bea4db21f4726d9cb71fa99c061 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Thu, 7 Nov 2024 20:02:50 -0500 Subject: [PATCH 52/54] add multiGPU fpdt unit test --- tests/unit/sequence_parallelism/test_ulysses.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 1b15b7754758..775c386314ee 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -198,7 +198,7 @@ def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, ch torch.manual_seed(seed) get_accelerator().manual_seed_all(seed) - input_tensor = torch.randn(d1, d0, dim, device=ds_engine.device, dtype=torch.half) + input_tensor = torch.randn(d1, d0, dim, device=ds_engine.device, dtype=torch.half) # l, b, d spg = ds_engine.seq_parallel_group dist.broadcast(input_tensor, src=0, group=spg) @@ -251,4 +251,4 @@ def __init__(self): baseline_output_shuffled = FPDT_InputConstruct(output, None, None, None, None, args(), world_size, dist.get_rank()).generate()[0] # b, l, n, d - assert torch.allclose(fpdt_output, baseline_output_shuffled), f"rank {dist.get_rank()}, sp size: {dist.get_world_size(spg)}, input_tensor: {input_tensor.shape}, fpdt_input_tensor: {fpdt_input_tensor.shape}, fpdt_output: {fpdt_output.shape}, baseline_output_shuffled: {baseline_output_shuffled.shape},{torch.max(torch.abs(fpdt_output - baseline_output_shuffled))}" + assert torch.allclose(fpdt_output, baseline_output_shuffled, rtol=0.01, atol=0.1), f"rank {dist.get_rank()}, sp size: {dist.get_world_size(spg)}, input_tensor: {input_tensor.shape}, fpdt_input_tensor: {fpdt_input_tensor.shape}, fpdt_output: {fpdt_output.shape}, baseline_output_shuffled: {baseline_output_shuffled.shape},{torch.max(torch.abs(fpdt_output - baseline_output_shuffled))}" From 958f3bf2e54f4a935c5f16af70f637bb608c14c7 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Thu, 7 Nov 2024 20:05:06 -0500 Subject: [PATCH 53/54] add multiGPU fpdt unit test --- tests/unit/sequence_parallelism/test_ulysses.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 775c386314ee..3e5ff1c34d94 100644 
--- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -161,11 +161,11 @@ def seq_batch_heads_hash(d0, d1, h, offset_d0=0, offset_d1=0, offset_h=0): outputs[i]), f"[{dist.get_rank()}]Outputs differ for sequence dim {seq_dims[i]}" -@pytest.mark.parametrize("d0", [4]) #batch dimension -@pytest.mark.parametrize("d1", [2048]) #sequence dimension -@pytest.mark.parametrize("chunk_size", [128]) #size of chunk -@pytest.mark.parametrize("num_heads", [8]) -@pytest.mark.parametrize("head_dim", [32]) +@pytest.mark.parametrize("d0", [4, 1]) #batch dimension +@pytest.mark.parametrize("d1", [2048, 8192]) #sequence dimension +@pytest.mark.parametrize("chunk_size", [128, 256]) #size of chunk +@pytest.mark.parametrize("num_heads", [8, 4]) +@pytest.mark.parametrize("head_dim", [32, 64]) class TestFPDTAttention(DistributedTest): def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, chunk_size: int, head_dim: int, From af025c5102106a292c5bf46e02cdf4495dd57cdb Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Thu, 7 Nov 2024 20:08:29 -0500 Subject: [PATCH 54/54] add multiGPU fpdt unit test --- tests/unit/sequence_parallelism/test_ulysses.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 3e5ff1c34d94..821847c44265 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -165,7 +165,7 @@ def seq_batch_heads_hash(d0, d1, h, offset_d0=0, offset_d1=0, offset_h=0): @pytest.mark.parametrize("d1", [2048, 8192]) #sequence dimension @pytest.mark.parametrize("chunk_size", [128, 256]) #size of chunk @pytest.mark.parametrize("num_heads", [8, 4]) -@pytest.mark.parametrize("head_dim", [32, 64]) +@pytest.mark.parametrize("head_dim", [32]) class TestFPDTAttention(DistributedTest): def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, chunk_size: int, head_dim: int, @@ -198,10 +198,11 @@ def test_FPDT_attention_offloading_output_consistency(self, d0: int, d1: int, ch torch.manual_seed(seed) get_accelerator().manual_seed_all(seed) - input_tensor = torch.randn(d1, d0, dim, device=ds_engine.device, dtype=torch.half) # l, b, d + input_tensor = torch.randn(d1, d0, dim, device=ds_engine.device, dtype=torch.half) # l, b, d spg = ds_engine.seq_parallel_group dist.broadcast(input_tensor, src=0, group=spg) + class args: def __init__(self): @@ -251,4 +252,6 @@ def __init__(self): baseline_output_shuffled = FPDT_InputConstruct(output, None, None, None, None, args(), world_size, dist.get_rank()).generate()[0] # b, l, n, d - assert torch.allclose(fpdt_output, baseline_output_shuffled, rtol=0.01, atol=0.1), f"rank {dist.get_rank()}, sp size: {dist.get_world_size(spg)}, input_tensor: {input_tensor.shape}, fpdt_input_tensor: {fpdt_input_tensor.shape}, fpdt_output: {fpdt_output.shape}, baseline_output_shuffled: {baseline_output_shuffled.shape},{torch.max(torch.abs(fpdt_output - baseline_output_shuffled))}" + assert torch.allclose( + fpdt_output, baseline_output_shuffled, rtol=0.01, atol=0.1 + ), f"rank {dist.get_rank()}, sp size: {dist.get_world_size(spg)}, input_tensor: {input_tensor.shape}, fpdt_input_tensor: {fpdt_input_tensor.shape}, fpdt_output: {fpdt_output.shape}, baseline_output_shuffled: {baseline_output_shuffled.shape},{torch.max(torch.abs(fpdt_output - baseline_output_shuffled))}"
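A note on the flash-attn version gating introduced in PATCH 37/54: the duplicated call sites differ only in whether softcap=0.0 is passed, so the same check can be expressed once by building the optional keyword arguments up front. The sketch below is illustrative only; it assumes, as the patch does, that softcap became available in flash-attn 2.6.0, and the helper name optional_flash_attn_kwargs is not part of the codebase.

from packaging import version

try:
    import flash_attn
    flash_attn_version = version.parse(flash_attn.__version__)
except ImportError:
    flash_attn = None
    flash_attn_version = None


def optional_flash_attn_kwargs():
    # Only newer flash-attn releases accept `softcap`; older ones raise TypeError
    # when it is passed, which is what the pre-PATCH-37 try/except relied on.
    if flash_attn_version is not None and flash_attn_version >= version.parse("2.6.0"):
        return {"softcap": 0.0}
    return {}

Call sites could then pass **optional_flash_attn_kwargs() to _flash_attn_forward/_flash_attn_backward instead of carrying two nearly identical branches per call.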
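PATCH 40/54 makes FPDT_InputConstruct.generate tolerate missing labels, loss masks, and position ids by guarding each indexing step with an "if ... is not None" ternary. The repeated ternaries can be captured by a tiny helper; this is only a sketch, and index_if_present is a hypothetical name, not an existing function.

def index_if_present(tensor, indices):
    # Hypothetical helper: gather along the sequence (second) dimension only
    # when the optional tensor was actually provided.
    return tensor[:, indices] if tensor is not None else None

# e.g. load_balanced_loss_mask = index_if_present(self.loss_mask, indices)
#      load_balanced_labels    = index_if_present(self.labels, indices)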
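The multi-GPU consistency test built up in PATCHes 40-54 keeps every rank on identical data by seeding both torch and the accelerator, broadcasting the input tensor over the sequence-parallel group, and broadcasting rank 0's initialized QKV weight and bias to the other ranks before comparing against the shuffled baseline with fp16-appropriate tolerances. A minimal sketch of that pattern, assuming spg is the sequence-parallel process group and the tensors are already allocated with matching shapes on every rank (the function name is illustrative only):

import torch
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator


def make_ranks_identical(tensors, spg, seed=42):
    # Seed both the CPU and accelerator RNGs so local randomness is reproducible,
    # then overwrite every rank's copy with rank 0's values so the test compares
    # attention numerics rather than initialization noise.
    torch.manual_seed(seed)
    get_accelerator().manual_seed_all(seed)
    for t in tensors:
        dist.broadcast(t, src=0, group=spg)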