diff --git a/configs/7B_llama2.py b/configs/7B_llama2.py
index 70242946..e9353f72 100644
--- a/configs/7B_llama2.py
+++ b/configs/7B_llama2.py
@@ -45,7 +45,7 @@
 data = dict(
     seq_len=SEQ_LEN,
     # micro_num means the number of micro_batch contained in one gradient update
-    micro_num=4,
+    micro_num=2,
     # packed_length = micro_bsz * SEQ_LEN
     micro_bsz=1,
     # defaults to the value of micro_num
@@ -172,10 +172,10 @@
 3. memory_pool: bool, enable/disable memory pool, defaults to False.
 """
 parallel = dict(
-    zero1=dict(size=-1),
-    tensor=dict(size=1, mode="mtp"),
+    zero1=dict(size=4),
+    tensor=dict(size=1, mode="isp"),
     pipeline=dict(size=1, interleaved_overlap=True),
-    weight=dict(size=1, overlap=True, memory_pool=True),
+    weight=dict(size=2, overlap=True, memory_pool=True),
 )
 
 cudnn_deterministic = False
diff --git a/internlm/core/communication/isp.py b/internlm/core/communication/isp.py
index 8042f776..369edd73 100644
--- a/internlm/core/communication/isp.py
+++ b/internlm/core/communication/isp.py
@@ -559,6 +559,7 @@ def before_backward(self, scheduler, outputs, outputs_grad) -> None:
     def after_backward(self, scheduler, inputs_grad) -> None:
         # accumulate left gradients in last bucket after backward.
         self._zero_optim.accumulate_left_grads_after_backward()
+        self._zero_optim.reduce_grad_by_bucket_after_backward()
         # reset lazy memory pools for reduce scatter after every micro step.
         if self._isp_communicator and self._isp_communicator.enable_memory_pool:
             self._isp_communicator.memory_pool.reset_lazy_pools()
diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py
index 0cd8c103..ebf887d0 100644
--- a/internlm/core/scheduler/no_pipeline_scheduler.py
+++ b/internlm/core/scheduler/no_pipeline_scheduler.py
@@ -195,10 +195,10 @@ def forward_backward_step(
         for _current_accum_step in range(self._grad_accum_size):
             if engine.optimizer is not None:
-                if _current_accum_step == self._grad_accum_size - 1:
-                    engine.optimizer.skip_grad_reduce = False
-                else:
-                    engine.optimizer.skip_grad_reduce = True
+                # if _current_accum_step == self._grad_accum_size - 1:
+                engine.optimizer.skip_grad_reduce = False
+                # else:
+                # engine.optimizer.skip_grad_reduce = True
 
             _data, _label = self._load_accum_batch(data, label)
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index a31cadae..c30a48bb 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -376,6 +376,11 @@ def accumulate_left_grads_after_backward(self):
         for group_id in range(self.num_param_groups):
             self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id])
+
+    def reduce_grad_by_bucket_after_backward(self):
+        for group_id in range(self.num_param_groups):
+            self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True)
+
     def belongs_to_current_rank(self, param) -> bool:
         """
@@ -481,8 +486,18 @@ def _reduce_grads_stored_in_bucket(self, current_bucket, reduce_rank=None, last_
                 raise RuntimeError(msg)
 
             # update the flag
-            self._param_store.set_param_reduction_state(param, True)
-
+
+            if last_bucket==True:
+                # self._param_store.clear_grads_of_previous_reduced_params()
+                # self._param_store.set_param_reduction_state(param, False)
+                for group_id, param_group in enumerate(self.optim.param_groups):
+                    for param in self._fp16_param_groups[group_id]:
+                        self._param_store.set_param_reduction_state(param, False)
+                # self._param_store.clear_grads_of_previous_reduced_params()
+
+            else:
+                self._param_store.set_param_reduction_state(param, True)
+
             if self.belongs_to_current_rank(param):
                 self._param_store.add_reduced_param_for_compute_norm(param, last_bucket)
             else:
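
Note on the combined effect of this patch: gradient reduction now happens on every micro step rather than only on the last accumulation step. The scheduler leaves skip_grad_reduce set to False, the ISP after_backward hook flushes the remaining bucket through the new reduce_grad_by_bucket_after_backward(), and the last_bucket branch resets every parameter's reduction state so the next micro step can reduce the same parameters without tripping the duplicate-reduction check. The code below is a minimal, self-contained sketch of that flow under simplified assumptions; ToyParamStore, ToyBucket and ToyOptimizer are hypothetical stand-ins, not the InternLM classes.

# Sketch only: simulates "reduce on every micro step, then reset reduction flags",
# which is the behaviour the diff above switches to. Not InternLM code.

class ToyParamStore:
    def __init__(self):
        self._is_reduced = {}

    def set_param_reduction_state(self, param, state):
        self._is_reduced[param] = state

    def is_param_reduced(self, param):
        return self._is_reduced.get(param, False)


class ToyBucket:
    def __init__(self):
        self.params = []

    def add(self, param):
        self.params.append(param)

    def flush(self):
        params, self.params = self.params, []
        return params


class ToyOptimizer:
    def __init__(self, params):
        self.params = params
        self.skip_grad_reduce = False
        self.param_store = ToyParamStore()
        self.bucket = ToyBucket()

    def backward(self):
        # gradients become ready; they are queued for reduction only when
        # skip_grad_reduce is False (always the case after this patch)
        for p in self.params:
            if not self.skip_grad_reduce:
                self.bucket.add(p)

    def reduce_grad_by_bucket_after_backward(self):
        # rough analogue of the new method in the patch: reduce whatever is
        # left in the bucket, then clear the per-parameter reduction flags
        # (the last_bucket branch) so the next micro step may reduce again
        for p in self.bucket.flush():
            assert not self.param_store.is_param_reduced(p), "duplicate reduction"
            self.param_store.set_param_reduction_state(p, True)
        for p in self.params:
            self.param_store.set_param_reduction_state(p, False)


def forward_backward_step(optimizer, grad_accum_size):
    for _ in range(grad_accum_size):
        # with the scheduler change, reduction is never skipped
        optimizer.skip_grad_reduce = False
        optimizer.backward()
        optimizer.reduce_grad_by_bucket_after_backward()


if __name__ == "__main__":
    forward_backward_step(ToyOptimizer(params=["w1", "w2"]), grad_accum_size=2)
    print("each micro step reduced its gradients without a duplicate-reduction error")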