From c258266ad50c488ea194c2dba61d7c06ad7557de Mon Sep 17 00:00:00 2001
From: Flavio Schneider
Date: Fri, 9 Sep 2022 12:11:22 +0200
Subject: [PATCH] fix: transform applied after resampling, refactor mono/stereo
 transform name

---
 README.md                                       |  8 ++++----
 audio_data_pytorch/datasets/wav_dataset.py      |  5 +++--
 audio_data_pytorch/transforms/__init__.py       |  2 +-
 audio_data_pytorch/transforms/all.py            | 10 +++++-----
 .../transforms/{overlap_channels.py => mono.py} |  2 +-
 setup.py                                        |  2 +-
 6 files changed, 15 insertions(+), 14 deletions(-)
 rename audio_data_pytorch/transforms/{overlap_channels.py => mono.py} (85%)

diff --git a/README.md b/README.md
index a62e319..be544e0 100644
--- a/README.md
+++ b/README.md
@@ -114,8 +114,8 @@ random_crop = RandomCrop(size=22050*2) # Crop 2 seconds at 22050 Hz from a rando
 from audio_data_pytorch import Resample
 resample = Resample(source=48000, target=22050), # Resamples from 48kHz to 22kHz
 
-from audio_data_pytorch import OverlapChannels
-overlap = OverlapChannels() # Overap channels by sum (C, N) -> (1, N)
+from audio_data_pytorch import Mono
+overlap = Mono() # Overlap channels by sum to get mono source (C, N) -> (1, N)
 
 from audio_data_pytorch import Stereo
 stereo = Stereo() # Duplicate channels (1, N) -> (2, N) or (2, N) -> (2, N)
@@ -138,7 +138,7 @@ transform = AllTransform(
     random_crop_size: Optional[int] = None,
     loudness: Optional[int] = None,
     scale: Optional[float] = None,
-    overlap_channels: bool = False,
-    use_stereo: bool = False,
+    mono: bool = False,
+    stereo: bool = False,
 )
 ```
diff --git a/audio_data_pytorch/datasets/wav_dataset.py b/audio_data_pytorch/datasets/wav_dataset.py
index db347bb..156542c 100644
--- a/audio_data_pytorch/datasets/wav_dataset.py
+++ b/audio_data_pytorch/datasets/wav_dataset.py
@@ -36,14 +36,15 @@ def __getitem__(
     ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
         idx = idx.tolist() if torch.is_tensor(idx) else idx  # type: ignore
         waveform, sample_rate = torchaudio.load(self.wavs[idx])
-        if self.transforms:
-            waveform = self.transforms(waveform)
 
         if self.sample_rate and sample_rate != self.sample_rate:
             waveform = torchaudio.transforms.Resample(
                 orig_freq=sample_rate, new_freq=self.sample_rate
             )(waveform)
 
+        if self.transforms:
+            waveform = self.transforms(waveform)
+
         return waveform
 
     def __len__(self) -> int:
diff --git a/audio_data_pytorch/transforms/__init__.py b/audio_data_pytorch/transforms/__init__.py
index ca74873..c93663e 100644
--- a/audio_data_pytorch/transforms/__init__.py
+++ b/audio_data_pytorch/transforms/__init__.py
@@ -1,7 +1,7 @@
 from .all import AllTransform
 from .crop import Crop
 from .loudness import Loudness
-from .overlap_channels import OverlapChannels
+from .mono import Mono
 from .randomcrop import RandomCrop
 from .resample import Resample
 from .scale import Scale
diff --git a/audio_data_pytorch/transforms/all.py b/audio_data_pytorch/transforms/all.py
index 6436e0c..30c3c39 100644
--- a/audio_data_pytorch/transforms/all.py
+++ b/audio_data_pytorch/transforms/all.py
@@ -5,7 +5,7 @@
 from ..utils import exists
 from .crop import Crop
 from .loudness import Loudness
-from .overlap_channels import OverlapChannels
+from .mono import Mono
 from .randomcrop import RandomCrop
 from .resample import Resample
 from .scale import Scale
@@ -21,8 +21,8 @@ def __init__(
         random_crop_size: Optional[int] = None,
         loudness: Optional[int] = None,
         scale: Optional[float] = None,
-        use_stereo: bool = False,
-        overlap_channels: bool = False,
+        stereo: bool = False,
+        mono: bool = False,
     ):
         super().__init__()
 
@@ -38,8 +38,8 @@ def __init__(
             else nn.Identity(),
             RandomCrop(random_crop_size) if exists(random_crop_size) else nn.Identity(),
             Crop(crop_size) if exists(crop_size) else nn.Identity(),
-            OverlapChannels() if overlap_channels else nn.Identity(),
-            Stereo() if use_stereo else nn.Identity(),
+            Mono() if mono else nn.Identity(),
+            Stereo() if stereo else nn.Identity(),
             Loudness(sampling_rate=target_rate, target=loudness)  # type: ignore
             if exists(loudness)
             else nn.Identity(),
diff --git a/audio_data_pytorch/transforms/overlap_channels.py b/audio_data_pytorch/transforms/mono.py
similarity index 85%
rename from audio_data_pytorch/transforms/overlap_channels.py
rename to audio_data_pytorch/transforms/mono.py
index 91fc053..f416fa4 100644
--- a/audio_data_pytorch/transforms/overlap_channels.py
+++ b/audio_data_pytorch/transforms/mono.py
@@ -2,7 +2,7 @@
 from torch import Tensor, nn
 
 
-class OverlapChannels(nn.Module):
+class Mono(nn.Module):
     """Overlaps all channels into one"""
 
     def forward(self, x: Tensor) -> Tensor:
diff --git a/setup.py b/setup.py
index 6ee3e69..d8870f6 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
     name="audio-data-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.9",
+    version="0.0.10",
     license="MIT",
     description="Audio Data - PyTorch",
     long_description_content_type="text/markdown",
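
A minimal usage sketch of the new ordering (the WAVDataset class name and its path/sample_rate/transforms constructor arguments are assumptions here; the diff above only shows __getitem__). Since resampling to sample_rate now happens before the transforms run, crop sizes are expressed in target-rate samples regardless of the source file's rate:

import torch.nn as nn
from audio_data_pytorch import Mono, RandomCrop
from audio_data_pytorch import WAVDataset  # class name assumed, not shown in this diff

# Transforms receive the already-resampled waveform, so this crop is
# 2 seconds at the target 22050 Hz even if the source file is 48 kHz.
transforms = nn.Sequential(
    Mono(),                      # sum channels: (C, N) -> (1, N)
    RandomCrop(size=22050 * 2),  # 2 seconds at 22050 Hz
)

dataset = WAVDataset(
    path="./wavs",           # hypothetical argument names
    sample_rate=22050,       # __getitem__ resamples to this rate first...
    transforms=transforms,   # ...and only then applies the transforms
)
waveform = dataset[0]  # shape (1, 22050 * 2) for a sufficiently long source file

The individual transforms are nn.Modules (see mono.py above), so nn.Sequential can chain them into the single callable that __getitem__ applies to the resampled waveform.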