From 183cabb6bd4755f0d7f7ae42c4411d891d1f3371 Mon Sep 17 00:00:00 2001
From: Alexpan <zhoudaoxian@foxmail.com>
Date: Wed, 16 Aug 2023 12:13:40 +0800
Subject: [PATCH 1/2] Update ddpm_multi.py

---
 ldm/models/diffusion/ddpm_multi.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/ldm/models/diffusion/ddpm_multi.py b/ldm/models/diffusion/ddpm_multi.py
index 516a65f..e587402 100644
--- a/ldm/models/diffusion/ddpm_multi.py
+++ b/ldm/models/diffusion/ddpm_multi.py
@@ -27,7 +27,9 @@ import itertools
 from tqdm import tqdm
 from torchvision.utils import make_grid
-from pytorch_lightning.utilities.distributed import rank_zero_only
+# from pytorch_lightning.utilities.distributed import rank_zero_only
+from pytorch_lightning.utilities.rank_zero import rank_zero_only
+
 from omegaconf import ListConfig
 
 from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config

From e7fc6ef367152c2f30033ce850e88af218ef9976 Mon Sep 17 00:00:00 2001
From: Alexpan <zhoudaoxian@foxmail.com>
Date: Wed, 16 Aug 2023 12:15:29 +0800
Subject: [PATCH 2/2] Update cldm_unicontrol.py

---
 cldm/cldm_unicontrol.py | 121 ++++++++++++++++++++++++++++++++++++----
 1 file changed, 111 insertions(+), 10 deletions(-)

diff --git a/cldm/cldm_unicontrol.py b/cldm/cldm_unicontrol.py
index b776cc7..df28043 100644
--- a/cldm/cldm_unicontrol.py
+++ b/cldm/cldm_unicontrol.py
@@ -362,35 +362,121 @@ def forward(self, x, hint, timesteps, context, **kwargs):
         hint -> 4, 3, 512, 512
         context - > 4, 77, 768
         '''
+        # x -> 2, 4, 64, 64
+        # hint -> [2, 6, 512, 512]
+        # context -> 2, 77, 768
+
         BS = 1 # x.shape[0], one batch one task
         BS_Real = x.shape[0]
+        all_tasks = []
+        task_scale = []
+        split_switch = 0
+        all_task_id_emb = []
         if kwargs is not None:
             task_name = kwargs['task']['name']
             task_id = self.tasks_to_id[task_name]
             task_feature = kwargs['task']['feature']
             task_id_emb = self.task_id_hypernet(task_feature.squeeze(0))
+            if "all_tasks" in kwargs['task']:
+                all_tasks = kwargs['task']['all_tasks']
+                task_scale = kwargs['task']['task_scale']
+            if "split_switch" in kwargs['task']:
+                split_switch = kwargs['task']['split_switch']
+            if "all_task_feature" in kwargs['task']:
+                all_task_feature = kwargs['task']['all_task_feature']
+                for sub_task_feature in all_task_feature:
+                    all_task_id_emb.append(self.task_id_hypernet(sub_task_feature.squeeze(0)))
+
+        # emb -> 1, 1280
         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
         emb = self.time_embed(t_emb)
 
-        guided_hint = self.input_hint_block_list_moe[task_id](hint, emb, context)
-
-        guided_hint = modulated_conv2d(guided_hint, self.input_hint_block_zeroconv_0[0].weight, self.task_id_layernet_zeroconv_0(task_id_emb).repeat(BS_Real, 1).detach(), padding=1)
-        guided_hint += self.input_hint_block_zeroconv_0[0].bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+        # --- by ri chi: earlier hard-coded three-condition experiment, kept for reference --------
+        # guided_hint0 = self.input_hint_block_list_moe[1](hint[0], emb, context)
+        # guided_hint1 = self.input_hint_block_list_moe[3](hint[1], emb, context)
+        # guided_hint2 = self.input_hint_block_list_moe[4](hint[2], emb, context)
+        # # import pdb;pdb.set_trace()
+        # guided_hint0 = self.deal_hint(context, BS_Real, task_id_emb, emb, guided_hint0)
+        # guided_hint1 = self.deal_hint(context, BS_Real, task_id_emb, emb, guided_hint1)
+        # guided_hint2 = self.deal_hint(context, BS_Real, task_id_emb, emb, guided_hint2)
+        # guided_hint = 0.4*guided_hint0 + 0.2*guided_hint1 + 0.3*guided_hint2
+        # ---- end by ri chi ----
+        # print("all_tasks:", all_tasks, "hint had ", len(hint), "hint[0].shape", hint[0].shape)
+        guided_hint_final_list = []
+        num_tasks_weight = 1.0 / (len(all_tasks) + 0.00001)  # default uniform weight; overridden by task_scale below
+        # for task in all_tasks:
+        for sub_hint_id, task in enumerate(all_tasks):
+            task_id = self.tasks_to_id[task]  # e.g. 1, 3
+            sub_guided_hint = self.input_hint_block_list_moe[task_id](hint[sub_hint_id], emb, context)
+            # guided_hint = self.deal_hint(context, BS_Real, task_id_emb, emb, sub_guided_hint)
+            guided_hint = self.deal_hint(context, BS_Real, all_task_id_emb[sub_hint_id], emb, sub_guided_hint)
+            num_tasks_weight = task_scale[sub_hint_id]
+            guided_hint = num_tasks_weight * guided_hint
+            guided_hint_final_list.append(guided_hint)
+            del guided_hint
-        guided_hint = self.input_hint_block_share(guided_hint, emb, context)
-
-        guided_hint = modulated_conv2d(guided_hint, self.input_hint_block_zeroconv_1[0].weight, self.task_id_layernet_zeroconv_1(task_id_emb).repeat(BS_Real, 1).detach(), padding=1)
-        guided_hint += self.input_hint_block_zeroconv_1[0].bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+        if len(all_tasks) > 1:
+            guided_hint_weight_summed = guided_hint_final_list[0]
+            for i in range(1, len(guided_hint_final_list)):
+                guided_hint_weight_summed = guided_hint_weight_summed + guided_hint_final_list[i]
+            guided_hint = guided_hint_weight_summed
+        else:
+            guided_hint0 = self.input_hint_block_list_moe[task_id](hint, emb, context)
+            guided_hint = self.deal_hint(context, BS_Real, task_id_emb, emb, guided_hint0)
+
+        # split_switch = 1 # for debug ------------------------
+        all_outputs = []
+        if split_switch and len(all_tasks) > 1:  # compute each condition separately, then merge the outputs
+            for sub_hint_id, task in enumerate(all_tasks):
+                task_id = self.tasks_to_id[task]  # e.g. 1, 3
+                sub_guided_hint = self.input_hint_block_list_moe[task_id](hint[sub_hint_id], emb, context)
+                guided_hint = self.deal_hint(context, BS_Real, all_task_id_emb[sub_hint_id], emb, sub_guided_hint)
+                num_tasks_weight = task_scale[sub_hint_id]
+                guided_hint = num_tasks_weight * guided_hint
+                # for guided_hint in guided_hint_final_list:
+                # do the following calculation once per condition
+                outs = []
+                h = x.type(self.dtype)
+                for module, zero_conv, task_hyperlayer in zip(self.input_blocks, self.zero_convs, self.task_id_layernet):
+                    if guided_hint is not None:
+                        h = module(h, emb, context)
+                        # print(h.shape)
+                        # print(guided_hint.shape)
+                        try:
+                            h += guided_hint
+                        except RuntimeError:
+                            # pdb.set_trace()
+                            continue  # shape mismatch: skip this block's output and retry the hint at the next block
+                        guided_hint = None
+                    else:
+                        h = module(h, emb, context)
+
+                    # outs.append(modulated_conv2d(h, zero_conv[0].weight, task_hyperlayer(task_id_emb).repeat(BS_Real, 1).detach()) + zero_conv[0].bias.unsqueeze(0).unsqueeze(2).unsqueeze(3))
+                    outs.append(modulated_conv2d(h, zero_conv[0].weight, task_hyperlayer(all_task_id_emb[sub_hint_id]).repeat(BS_Real, 1).detach()) + zero_conv[0].bias.unsqueeze(0).unsqueeze(2).unsqueeze(3))
+
+                h = self.middle_block(h, emb, context)
+                outs.append(self.middle_block_out(h, emb, context))
+                all_outputs.append(outs)
+            # sum the per-condition control outputs
+            final_outputs = all_outputs[0]
+            for sub_h in all_outputs[1:]:  # len(sub_h) == 13: torch.Size([2, 320, 64, 96]) ... torch.Size([2, 1280, 8, 12])
+                for sub_index in range(len(sub_h)):
+                    final_outputs[sub_index] = torch.add(final_outputs[sub_index], sub_h[sub_index])
+            # final_outputs: 13 control tensors (12 input-block outputs + 1 middle-block output)
+            return final_outputs
 
         outs = []
 
         h = x.type(self.dtype)
         for module, zero_conv, task_hyperlayer in zip(self.input_blocks, self.zero_convs, self.task_id_layernet):
             if guided_hint is not None:
                 h = module(h, emb, context)
+                # print(h.shape)
+                # print(guided_hint.shape)
                 try:
                     h += guided_hint
                 except RuntimeError:
-                    pdb.set_trace()
+                    # pdb.set_trace()
+                    continue  # shape mismatch: skip this block's output and retry the hint at the next block
                 guided_hint = None
             else:
                 h = module(h, emb, context)
@@ -402,6 +488,16 @@ def forward(self, x, hint, timesteps, context, **kwargs):
 
         return outs
 
+    def deal_hint(self, context, BS_Real, task_id_emb, emb, guided_hint):
+        guided_hint = modulated_conv2d(guided_hint, self.input_hint_block_zeroconv_0[0].weight, self.task_id_layernet_zeroconv_0(task_id_emb).repeat(BS_Real, 1).detach(), padding=1)
+        guided_hint = guided_hint + self.input_hint_block_zeroconv_0[0].bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+
+        guided_hint = self.input_hint_block_share(guided_hint, emb, context)
+
+        guided_hint = modulated_conv2d(guided_hint, self.input_hint_block_zeroconv_1[0].weight, self.task_id_layernet_zeroconv_1(task_id_emb).repeat(BS_Real, 1).detach(), padding=1)
+        guided_hint = guided_hint + self.input_hint_block_zeroconv_1[0].bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+        return guided_hint
+
 
 
 class ControlLDM(LatentDiffusion):
 
@@ -458,7 +554,12 @@ def apply_model(self, x_noisy, t, cond, *args, **kwargs):
         if cond['c_concat'] is None:
             eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
         else:
-            control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt, task=task_name)
+            # control = self.control_model(x=x_noisy, hint=[cond['c_concat'][0], cond['c_concat'][1], cond['c_concat'][2]],
+            #                              timesteps=t, context=cond_txt, task=task_name)
+            if len(cond['c_concat']) == 1:
+                control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt, task=task_name)
+            else:
+                control = self.control_model(x=x_noisy, hint=list(cond['c_concat']), timesteps=t, context=cond_txt, task=task_name)
             control = [c * scale for c, scale in zip(control, self.control_scales)]
             eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
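
Reviewer note (not part of the patch): the sketch below only illustrates, on the caller side, the data layout that the patched ControlNet.forward() above expects when several conditions are passed together. The dictionary keys (name, feature, all_tasks, task_scale, split_switch, all_task_feature) come from the code above; the helper name build_multi_task_kwargs, the task names, tensor shapes, and weights are made-up placeholders, and the task names must exist as keys of the model's tasks_to_id mapping.

import torch

def build_multi_task_kwargs(task_names, task_features, task_scales, split_switch=0):
    # Hypothetical helper: assembles the `task` kwarg consumed by the patched forward().
    return {
        "task": {
            "name": task_names[0],              # primary task, as in the single-condition path
            "feature": task_features[0],        # primary task feature -> task_id_emb
            "all_tasks": task_names,            # one entry per condition map in `hint`
            "task_scale": task_scales,          # per-condition weights applied before merging
            "split_switch": split_switch,       # truthy: run each condition separately and sum the control tensors
            "all_task_feature": task_features,  # one feature per condition -> all_task_id_emb
        }
    }

# Dummy example: two conditions, batch size 2 (all shapes are placeholders).
hints = [torch.randn(2, 3, 512, 512), torch.randn(2, 3, 512, 512)]  # list form selects the multi-condition path
features = [torch.randn(1, 768), torch.randn(1, 768)]               # placeholder inputs for task_id_hypernet
kwargs = build_multi_task_kwargs(["canny", "depth"], features, [0.6, 0.4], split_switch=1)
# control = control_model(x=x_noisy, hint=hints, timesteps=t, context=cond_txt, **kwargs)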