From d15ef3679aa572453fc1d44f66947de71fa85e72 Mon Sep 17 00:00:00 2001
From: SWivid
Date: Mon, 21 Oct 2024 17:55:58 +0800
Subject: [PATCH] fix address #191

---
 model/backbones/dit.py   | 2 +-
 model/backbones/unett.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/model/backbones/dit.py b/model/backbones/dit.py
index 9ff53513..b8e6dc3f 100644
--- a/model/backbones/dit.py
+++ b/model/backbones/dit.py
@@ -45,9 +45,9 @@ def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
         self.extra_modeling = False
 
     def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
-        batch, text_len = text.shape[0], text.shape[1]
         text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
         text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
+        batch, text_len = text.shape[0], text.shape[1]
         text = F.pad(text, (0, seq_len - text_len), value=0)
 
         if drop_text:  # cfg for text
diff --git a/model/backbones/unett.py b/model/backbones/unett.py
index c4ce2c64..ac1d3d35 100644
--- a/model/backbones/unett.py
+++ b/model/backbones/unett.py
@@ -48,9 +48,9 @@ def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
         self.extra_modeling = False
 
     def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
-        batch, text_len = text.shape[0], text.shape[1]
         text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
         text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
+        batch, text_len = text.shape[0], text.shape[1]
         text = F.pad(text, (0, seq_len - text_len), value=0)
 
         if drop_text:  # cfg for text