
The training of the model is very slow and the GPU power is very low. #9

Open · 1229282331 opened this issue Dec 30, 2024 · 3 comments

@1229282331

When I use this function as the loss function, my training is very slow and the GPU power draw is only 70 W out of 220 W.

import torch
from torch_pesq import PesqLoss

class loss_pesq(torch.nn.Module):
    def __init__(self, sample_rate):
        super(loss_pesq, self).__init__()
        self.pesq = PesqLoss(0.5, sample_rate=sample_rate)
        # square-root Hann window for the inverse STFT
        self.windows = torch.hann_window(512).pow(0.5).to("cuda")

    def forward(self, est, clean):
        """Inputs: complex spectrograms with real/imag parts stacked in the last dim."""
        # reconstruct waveforms from the real and imaginary STFT parts
        est = torch.istft(est[..., 0] + 1j*est[..., 1], n_fft=512, hop_length=256, win_length=512, window=self.windows)
        clean = torch.istft(clean[..., 0] + 1j*clean[..., 1], n_fft=512, hop_length=256, win_length=512, window=self.windows)

        # data_len = min(est.shape[-1], clean.shape[-1])
        # est = est[..., : data_len]
        # clean = clean[...,: data_len]

        return self.pesq(clean, est).mean()

[screenshot: GPU power draw during training]
How can I improve GPU computation efficiency and speed up training?

@nils-werner (Contributor)

Can you run your code through the profiler to see where the bottleneck lies?
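For reference, a minimal torch.profiler sketch, assuming the loss_pesq module from the first comment; the tensor shapes below are placeholders chosen to match n_fft=512 and the [2, 480000] waveforms mentioned later in the thread:

import torch
from torch.profiler import profile, ProfilerActivity

# Placeholder inputs: [batch, freq, frames, 2] spectrograms (257 bins for
# n_fft=512); loss_pesq is the module defined in the first comment.
loss_fn = loss_pesq(sample_rate=16000)
est_spec = torch.randn(2, 257, 1876, 2, device="cuda", requires_grad=True)
clean_spec = torch.randn(2, 257, 1876, 2, device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    loss_fn(est_spec, clean_spec).backward()

# show the operators that dominate GPU time
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))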

@1229282331 (Author)

@nils-werner Hi, I found that align_level() and preemphasize() appear to run inefficiently. In both, lfilter seems to be the culprit; it may be that lfilter incurs a significant amount of GPU kernel-launch overhead.
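One way to check this is to micro-benchmark lfilter in isolation. A minimal sketch with torch.utils.benchmark, using placeholder first-order filter coefficients (not the actual PESQ filters):

import torch
import torch.utils.benchmark as benchmark
from torchaudio.functional import lfilter

# shapes from the issue: a batch of 2 signals, 480000 samples each
x = torch.randn(2, 480000, device="cuda")
a = torch.tensor([1.0, -0.5], device="cuda")  # placeholder IIR denominator
b = torch.tensor([0.5, 0.5], device="cuda")   # placeholder IIR numerator

t = benchmark.Timer(
    stmt="lfilter(x, a, b, clamp=False)",
    globals={"lfilter": lfilter, "x": x, "a": a, "b": b},
)
print(t.timeit(20))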

@1229282331 (Author)

Now I am sure the problem lies with lfilter. For smaller batch sizes (like [2, 480000]), parallelism is poor and GPU computation efficiency is low. I might consider increasing the batch size while truncating the audio (e.g., [8, 160000]) to improve GPU efficiency. Additionally, I have tried simple modifications to align_level() and preemphasize(), the two functions that call lfilter:

def align_level(
    self, signal: TensorType["batch", "sample"]
) -> TensorType["batch", "sample"]:
    """Align power to 10**7 for band 325 to 3.25kHz

    Parameters
    ----------
    signal : TensorType["batch", "sample"]
        Input time signal with size [batch, sample]

    Returns
    -------
    TensorType["batch", "sample"]
        Tensor containing the scaled time signal
    """
    batch_size = signal.shape[0]
    # filtered_signal = lfilter(signal, self.power_filter[1], self.power_filter[0], clamp=False)
    ##################################################
    # fold each signal into 8 chunks so lfilter sees a batch 8x larger;
    # the filter state at chunk boundaries is discarded (small error)
    filtered_signal = lfilter(
        signal.reshape(batch_size * 8, -1),
        self.power_filter[1],
        self.power_filter[0],
        clamp=False,
    ).reshape(batch_size, -1)
    ##################################################

    # calculate power with weird bugs in reference implementation
    power = (
        (filtered_signal**2).sum(dim=1, keepdim=True)
        / (filtered_signal.shape[1] + 5120)
        / 1.04684
    )

    # align power
    signal = signal * (10**7 / power).sqrt()

    return signal

def preemphasize(
    self, signal: TensorType["batch", "sample"]
) -> TensorType["batch", "sample"]:
    """Pre-emphasize a signal

    This pre-emphasize filter is also applied in the reference implementation. The filter
    coefficients are taken from the reference.

    Parameters
    ----------
    signal : TensorType["batch", "sample"]
        Input time signal with size [batch, sample]

    Returns
    -------
    TensorType["batch", "sample"]
        Tensor containing the pre-emphasized signal
    """
    batch_size = signal.shape[0]
    emp = torch.linspace(0, 15, 16, device=signal.device)[1:] / 16.0
    signal[:, :15] *= emp
    signal[:, -15:] *= torch.flip(emp, dims=(0,))

    # signal = lfilter(signal, self.pre_filter[1], self.pre_filter[0], clamp=False)
    ##################################################
    # same folding trick as in align_level(): trade exact IIR state for parallelism
    signal = lfilter(
        signal.reshape(batch_size * 8, -1),
        self.pre_filter[1],
        self.pre_filter[0],
        clamp=False,
    ).reshape(batch_size, -1)
    ##################################################

    return signal

def raw(
    self, ref: TensorType["batch", "sample"], deg: TensorType["batch", "sample"]
) -> Tuple[TensorType["batch", "sample"], TensorType["batch", "sample"]]:
    """Calculate symmetric and asymmetric distances"""
    deg, ref = torch.atleast_2d(deg), torch.atleast_2d(ref)
    batch_size = deg.shape[0]

    # ......

    # ref, deg = self.align_level(ref), self.align_level(deg)
    # ref, deg = self.preemphasize(ref), self.preemphasize(deg)

    # concatenate ref and deg so both pass through lfilter in one call,
    # doubling the effective batch size
    ref_deg = self.align_level(torch.cat((ref, deg), dim=0))
    ref, deg = ref_deg[:batch_size, ...], ref_deg[batch_size:, ...]
    ref_deg = self.preemphasize(torch.cat((ref, deg), dim=0))
    ref, deg = ref_deg[:batch_size, ...], ref_deg[batch_size:, ...]

    # ......

Although the initial filter state is lost when the time series is crudely folded before the IIR filtering, the error is relatively small. Do you have any good ideas for improving lfilter parallelism through this kind of external workaround?
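For reference, the folding trick described above can be written as a small standalone helper. This is a minimal sketch assuming torchaudio's lfilter; the name folded_lfilter and the folds parameter are illustrative, not part of any library:

import torch
from torchaudio.functional import lfilter

def folded_lfilter(signal, a_coeffs, b_coeffs, folds=8):
    """Hypothetical helper: fold [batch, sample] into [batch*folds, sample//folds]
    before filtering so lfilter processes more rows in parallel. The IIR state
    at fold boundaries is discarded, which introduces a small error."""
    batch, samples = signal.shape
    assert samples % folds == 0, "sample length must be divisible by folds"
    folded = signal.reshape(batch * folds, samples // folds)
    filtered = lfilter(folded, a_coeffs, b_coeffs, clamp=False)
    return filtered.reshape(batch, samples)

A larger folds value increases parallelism but also increases the number of boundaries where the filter state is reset, so the approximation error grows with it.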
