From 183cabb6bd4755f0d7f7ae42c4411d891d1f3371 Mon Sep 17 00:00:00 2001
From: Alexpan <zhoudaoxian@foxmail.com>
Date: Wed, 16 Aug 2023 12:13:40 +0800
Subject: [PATCH 1/2] Update ddpm_multi.py

---
 ldm/models/diffusion/ddpm_multi.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/ldm/models/diffusion/ddpm_multi.py b/ldm/models/diffusion/ddpm_multi.py
index 516a65f..e587402 100644
--- a/ldm/models/diffusion/ddpm_multi.py
+++ b/ldm/models/diffusion/ddpm_multi.py
@@ -27,7 +27,9 @@
 import itertools
 from tqdm import tqdm
 from torchvision.utils import make_grid
-from pytorch_lightning.utilities.distributed import rank_zero_only
+# from pytorch_lightning.utilities.distributed import rank_zero_only
+from pytorch_lightning.utilities.rank_zero import rank_zero_only
+
 from omegaconf import ListConfig
 
 from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config

From e7fc6ef367152c2f30033ce850e88af218ef9976 Mon Sep 17 00:00:00 2001
From: Alexpan <zhoudaoxian@foxmail.com>
Date: Wed, 16 Aug 2023 12:15:29 +0800
Subject: [PATCH 2/2] Update cldm_unicontrol.py

---
 cldm/cldm_unicontrol.py | 121 ++++++++++++++++++++++++++++++++++++----
 1 file changed, 111 insertions(+), 10 deletions(-)

diff --git a/cldm/cldm_unicontrol.py b/cldm/cldm_unicontrol.py
index b776cc7..df28043 100644
--- a/cldm/cldm_unicontrol.py
+++ b/cldm/cldm_unicontrol.py
@@ -362,35 +362,121 @@ def forward(self, x, hint, timesteps, context, **kwargs):
         hint -> 4, 3, 512, 512
         context - > 4, 77, 768
         '''
+        # x=2,4,64.64
+        # hint=[2.6,512,512]
+        # context=2,77,768
+        
         BS = 1 # x.shape[0], one batch one task
         BS_Real = x.shape[0]
+        all_tasks = []
+        task_scale = []
+        split_switch = 0
+        all_task_id_emb = []
         if kwargs is not None:
             task_name = kwargs['task']['name']
             task_id = self.tasks_to_id[task_name]
             task_feature = kwargs['task']['feature']
             task_id_emb = self.task_id_hypernet(task_feature.squeeze(0))
+            if "all_tasks" in kwargs['task']:
+                all_tasks = kwargs['task']['all_tasks']
+                task_scale = kwargs['task']['task_scale']
+                if "split_switch" in kwargs['task']:
+                    split_switch = kwargs['task']['split_switch']
+                if "all_task_feature" in kwargs['task']:
+                    all_task_feature = kwargs['task']['all_task_feature']
+                    for sub_tast_feature in all_task_feature:
+                        all_task_id_emb.append(self.task_id_hypernet(sub_tast_feature.squeeze(0)))
+
+            # 1,1280
             
         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
         emb = self.time_embed(t_emb)
-        guided_hint = self.input_hint_block_list_moe[task_id](hint, emb, context)
-
-        guided_hint = modulated_conv2d(guided_hint, self.input_hint_block_zeroconv_0[0].weight, self.task_id_layernet_zeroconv_0(task_id_emb).repeat(BS_Real, 1).detach(), padding=1)
-        guided_hint += self.input_hint_block_zeroconv_0[0].bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+        # --- by ri chi --------
+        # guided_hint0 = self.input_hint_block_list_moe[1](hint[0], emb, context)
+        # guided_hint1 = self.input_hint_block_list_moe[3](hint[1], emb, context)
+        # guided_hint2 = self.input_hint_block_list_moe[4](hint[2], emb, context)
+        # # import pdb;pdb.set_trace()
+        # guided_hint0=self.deal_hint(context, BS_Real, task_id_emb, emb,guided_hint0)
+        # guided_hint1=self.deal_hint(context, BS_Real, task_id_emb, emb,guided_hint1)
+        # guided_hint2=self.deal_hint(context, BS_Real, task_id_emb, emb,guided_hint2)
+        # guided_hint=0.4*guided_hint0+0.2*guided_hint1+0.3*guided_hint2
+        #  ---- end by rich ----
+        # print("all_tasks:", all_tasks, "hint had ", len(hint), "hint[0].shape", hint[0].shape)
+        guided_hint_final_list = []
+        num_tasks_weight = 1.0 / (len(all_tasks)+0.00001)
+        # for task in all_tasks:
+        for sub_hint_id, task in enumerate(all_tasks):
+            task_id = self.tasks_to_id[task]  #  1 3
+            sub_guided_hint = self.input_hint_block_list_moe[task_id](hint[sub_hint_id], emb, context)
+            # guided_hint=self.deal_hint(context, BS_Real, task_id_emb, emb, sub_guided_hint)
+            guided_hint=self.deal_hint(context, BS_Real, all_task_id_emb[sub_hint_id], emb, sub_guided_hint)
+            num_tasks_weight = task_scale[sub_hint_id]
+            guided_hint = num_tasks_weight * guided_hint
+            guided_hint_final_list.append(guided_hint)
+            del guided_hint
         
-        guided_hint = self.input_hint_block_share(guided_hint, emb, context)
-
-        guided_hint = modulated_conv2d(guided_hint, self.input_hint_block_zeroconv_1[0].weight, self.task_id_layernet_zeroconv_1(task_id_emb).repeat(BS_Real, 1).detach(), padding=1)
-        guided_hint += self.input_hint_block_zeroconv_1[0].bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+        if len(all_tasks) > 1:
+            guided_hint_weight_sumed = guided_hint_final_list[0]
+            for i in range(1, len(guided_hint_final_list)):
+                guided_hint_weight_sumed = guided_hint_weight_sumed + guided_hint_final_list[i]
+            guided_hint = guided_hint_weight_sumed
+        else:
+            guided_hint0 = self.input_hint_block_list_moe[task_id](hint, emb, context)
+            guided_hint = self.deal_hint(context, BS_Real, task_id_emb, emb,guided_hint0)
+
+        # split_switch = 1  # for debug  ------------------------
+        all_outputs = []
+        if split_switch and len(all_tasks) > 1:  # 是否分次计算,然后合并----?----
+            for sub_hint_id, task in enumerate(all_tasks):
+                task_id = self.tasks_to_id[task]  #  1 3
+                sub_guided_hint = self.input_hint_block_list_moe[task_id](hint[sub_hint_id], emb, context)
+                guided_hint=self.deal_hint(context, BS_Real, all_task_id_emb[sub_hint_id], emb, sub_guided_hint)
+                num_tasks_weight = task_scale[sub_hint_id]
+                guided_hint = num_tasks_weight * guided_hint
+            # for guided_hint in guided_hint_final_list:
+                #  do the fellowing calculate
+                outs = []
+                h = x.type(self.dtype)
+                for module, zero_conv, task_hyperlayer in zip(self.input_blocks, self.zero_convs, self.task_id_layernet):
+                    if guided_hint is not None:
+                        h = module(h, emb, context)
+                        # print(h.shape)
+                        # print(guided_hint.shape)
+                        try:
+                            h += guided_hint
+                        except RuntimeError:
+                            # pdb.set_trace()
+                            continue
+                        guided_hint = None
+                    else:
+                        h = module(h, emb, context)
+                        
+                    # outs.append(modulated_conv2d(h, zero_conv[0].weight, task_hyperlayer(task_id_emb).repeat(BS_Real, 1).detach()) + zero_conv[0].bias.unsqueeze(0).unsqueeze(2).unsqueeze(3))
+                    outs.append(modulated_conv2d(h, zero_conv[0].weight, task_hyperlayer(all_task_id_emb[sub_hint_id]).repeat(BS_Real, 1).detach()) + zero_conv[0].bias.unsqueeze(0).unsqueeze(2).unsqueeze(3))
+                
+                h = self.middle_block(h, emb, context)
+                outs.append(self.middle_block_out(h, emb, context))
+                all_outputs.append(outs)
+            # final_outputs
+            final_outputs = all_outputs[0]
+            for sub_h in all_outputs[0:]:  # len(sub_h) 13  : [torch.Size([2, 320, 64, 96]), torch.Size([2, 1280, 8, 12])] * 13
+                for sub_index in range(len(sub_h)):
+                    final_outputs[sub_index] = torch.add(final_outputs[sub_index], sub_h[sub_index])
+                # final_outputs
+            return final_outputs
 
         outs = []
         h = x.type(self.dtype)
         for module, zero_conv, task_hyperlayer in zip(self.input_blocks, self.zero_convs, self.task_id_layernet):
             if guided_hint is not None:
                 h = module(h, emb, context)
+                # print(h.shape)
+                # print(guided_hint.shape)
                 try:
                     h += guided_hint
                 except RuntimeError:
-                    pdb.set_trace()
+                    # pdb.set_trace()
+                    continue
                 guided_hint = None
             else:
                 h = module(h, emb, context)
@@ -402,6 +488,16 @@ def forward(self, x, hint, timesteps, context, **kwargs):
 
         return outs
 
+    def deal_hint(self, context, BS_Real, task_id_emb, emb, guided_hint):
+        guided_hint = modulated_conv2d(guided_hint, self.input_hint_block_zeroconv_0[0].weight, self.task_id_layernet_zeroconv_0(task_id_emb).repeat(BS_Real, 1).detach(), padding=1)
+        guided_hint = guided_hint + self.input_hint_block_zeroconv_0[0].bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+        
+        guided_hint = self.input_hint_block_share(guided_hint, emb, context)
+
+        guided_hint = modulated_conv2d(guided_hint, self.input_hint_block_zeroconv_1[0].weight, self.task_id_layernet_zeroconv_1(task_id_emb).repeat(BS_Real, 1).detach(), padding=1)
+        guided_hint = guided_hint + self.input_hint_block_zeroconv_1[0].bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+        return guided_hint
+
 
 class ControlLDM(LatentDiffusion):
 
@@ -458,7 +554,12 @@ def apply_model(self, x_noisy, t, cond, *args, **kwargs):
         if cond['c_concat'] is None:
             eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
         else:
-            control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt, task=task_name)
+            # control = self.control_model(x=x_noisy, hint=[cond['c_concat'][0],cond['c_concat'][1],cond['c_concat'][2]], 
+                                        #  timesteps=t, context=cond_txt, task=task_name)
+            if len(cond['c_concat']) == 1:
+                control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt, task=task_name)
+            else:
+                control = self.control_model(x=x_noisy, hint=list(cond['c_concat']), timesteps=t, context=cond_txt, task=task_name)
             control = [c * scale for c, scale in zip(control, self.control_scales)]
             eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)