From 7f40d592d69bd98d0577fdbe748164b112acec1f Mon Sep 17 00:00:00 2001
From: ChenQiaoling00 <qiaoling_chen@u.nus.edu>
Date: Mon, 18 Mar 2024 16:01:50 +0800
Subject: [PATCH 1/2] add gradient sharding
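
Reduce gradients bucket by bucket after every micro-step backward pass instead
of only on the last accumulation step: the ISP after_backward hook now calls a
new reduce_grad_by_bucket_after_backward() on the ZeRO optimizer, the scheduler
no longer sets skip_grad_reduce on intermediate accumulation steps, and when
the last bucket is reduced the param reduction state is reset for every fp16
param. The 7B_llama2 example config is adjusted accordingly (micro_num 4 -> 2,
zero1 -1 -> 4, tensor mode "mtp" -> "isp", weight size 1 -> 2).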

---
 configs/7B_llama2.py                          |  8 ++++----
 internlm/core/communication/isp.py            |  2 ++
 .../core/scheduler/no_pipeline_scheduler.py   |  8 ++++----
 .../solver/optimizer/hybrid_zero_optim.py     | 19 +++++++++++++++++--
 4 files changed, 27 insertions(+), 10 deletions(-)

diff --git a/configs/7B_llama2.py b/configs/7B_llama2.py
index 702429462..e9353f72d 100644
--- a/configs/7B_llama2.py
+++ b/configs/7B_llama2.py
@@ -45,7 +45,7 @@
 data = dict(
     seq_len=SEQ_LEN,
     # micro_num means the number of micro_batch contained in one gradient update
-    micro_num=4,
+    micro_num=2,
     # packed_length = micro_bsz * SEQ_LEN
     micro_bsz=1,
     # defaults to the value of micro_num
@@ -172,10 +172,10 @@
     3. memory_pool: bool, enable/disable memory pool, defaults to False.
 """
 parallel = dict(
-    zero1=dict(size=-1),
-    tensor=dict(size=1, mode="mtp"),
+    zero1=dict(size=4),
+    tensor=dict(size=1, mode="isp"),
     pipeline=dict(size=1, interleaved_overlap=True),
-    weight=dict(size=1, overlap=True, memory_pool=True),
+    weight=dict(size=2, overlap=True, memory_pool=True),
 )
 
 cudnn_deterministic = False
diff --git a/internlm/core/communication/isp.py b/internlm/core/communication/isp.py
index 8042f7763..16db096c9 100644
--- a/internlm/core/communication/isp.py
+++ b/internlm/core/communication/isp.py
@@ -559,6 +559,8 @@ def before_backward(self, scheduler, outputs, outputs_grad) -> None:
     def after_backward(self, scheduler, inputs_grad) -> None:
         # accumulate left gradients in last bucket after backward.
         self._zero_optim.accumulate_left_grads_after_backward()
+
+        self._zero_optim.reduce_grad_by_bucket_after_backward()
         # reset lazy memory pools for reduce scatter after every micro step.
         if self._isp_communicator and self._isp_communicator.enable_memory_pool:
             self._isp_communicator.memory_pool.reset_lazy_pools()
diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py
index 0cd8c1030..ebf887d07 100644
--- a/internlm/core/scheduler/no_pipeline_scheduler.py
+++ b/internlm/core/scheduler/no_pipeline_scheduler.py
@@ -195,10 +195,10 @@ def forward_backward_step(
 
         for _current_accum_step in range(self._grad_accum_size):
             if engine.optimizer is not None:
-                if _current_accum_step == self._grad_accum_size - 1:
-                    engine.optimizer.skip_grad_reduce = False
-                else:
-                    engine.optimizer.skip_grad_reduce = True
+                # gradients are now reduced bucket by bucket after every backward
+                # pass (see reduce_grad_by_bucket_after_backward), so grad reduce
+                # is no longer skipped on intermediate accumulation steps.
+                engine.optimizer.skip_grad_reduce = False
 
             _data, _label = self._load_accum_batch(data, label)
 
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index a31cadae8..6cdd2feea 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -376,6 +376,11 @@ def accumulate_left_grads_after_backward(self):
 
         for group_id in range(self.num_param_groups):
             self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id])
+
+    def reduce_grad_by_bucket_after_backward(self):
+        # reduce the gradients collected in each bucket once backward has finished.
+        for group_id in range(self.num_param_groups):
+            self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True)
 
     def belongs_to_current_rank(self, param) -> bool:
         """
@@ -481,8 +486,18 @@ def _reduce_grads_stored_in_bucket(self, current_bucket, reduce_rank=None, last_
                 raise RuntimeError(msg)
 
             # update the flag
-            self._param_store.set_param_reduction_state(param, True)
-
+
+            if last_bucket:
+                # once the last bucket has been reduced, reset the reduction state of
+                # every fp16 param so that the next micro step can reduce them again.
+                for _group_id in range(self.num_param_groups):
+                    for fp16_param in self._fp16_param_groups[_group_id]:
+                        self._param_store.set_param_reduction_state(fp16_param, False)
+                self._param_store.clear_grads_of_previous_reduced_params()
+
+            else:
+                self._param_store.set_param_reduction_state(param, True)
+
             if self.belongs_to_current_rank(param):
                 self._param_store.add_reduced_param_for_compute_norm(param, last_bucket)
             else:

From ff96325fde87db32b6baa90a9fef31aed701b55e Mon Sep 17 00:00:00 2001
From: ChenQiaoling00 <qiaoling_chen@u.nus.edu>
Date: Mon, 18 Mar 2024 16:48:11 +0800
Subject: [PATCH 2/2] keep grads of previously reduced params
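
Remove the extra blank line added to the ISP after_backward hook and stop
clearing the grads of previously reduced params when the last bucket is
reduced, so the already-reduced gradients are kept after each backward pass.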

---
 internlm/core/communication/isp.py             | 1 -
 internlm/solver/optimizer/hybrid_zero_optim.py | 2 +-
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/internlm/core/communication/isp.py b/internlm/core/communication/isp.py
index 16db096c9..369edd736 100644
--- a/internlm/core/communication/isp.py
+++ b/internlm/core/communication/isp.py
@@ -559,7 +559,6 @@ def before_backward(self, scheduler, outputs, outputs_grad) -> None:
     def after_backward(self, scheduler, inputs_grad) -> None:
         # accumulate left gradients in last bucket after backward.
         self._zero_optim.accumulate_left_grads_after_backward()
-
         self._zero_optim.reduce_grad_by_bucket_after_backward()
         # reset lazy memory pools for reduce scatter after every micro step.
         if self._isp_communicator and self._isp_communicator.enable_memory_pool:
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 6cdd2feea..c30a48bb1 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -493,7 +493,7 @@ def _reduce_grads_stored_in_bucket(self, current_bucket, reduce_rank=None, last_
                 for _group_id in range(self.num_param_groups):
                     for fp16_param in self._fp16_param_groups[_group_id]:
                         self._param_store.set_param_reduction_state(fp16_param, False)
-                self._param_store.clear_grads_of_previous_reduced_params()
+                # keep the grads of previously reduced params; do not clear them here.

             else:
                 self._param_store.set_param_reduction_state(param, True)