Skip to content

Commit

Permalink
fix: transform applied after resampling, refactor mono/stereo transfor…
Browse files Browse the repository at this point in the history
…m name
  • Loading branch information
flavioschneider committed Sep 9, 2022
1 parent aca4c1b commit c258266
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 14 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ random_crop = RandomCrop(size=22050*2) # Crop 2 seconds at 22050 Hz from a rando
from audio_data_pytorch import Resample
resample = Resample(source=48000, target=22050) # Resamples from 48kHz to 22kHz

from audio_data_pytorch import OverlapChannels
overlap = OverlapChannels() # Overlap channels by sum (C, N) -> (1, N)
from audio_data_pytorch import Mono
overlap = Mono() # Overlap channels by sum to get mono source (C, N) -> (1, N)

from audio_data_pytorch import Stereo
stereo = Stereo() # Duplicate channels (1, N) -> (2, N) or (2, N) -> (2, N)
Expand All @@ -138,7 +138,7 @@ transform = AllTransform(
random_crop_size: Optional[int] = None,
loudness: Optional[int] = None,
scale: Optional[float] = None,
overlap_channels: bool = False,
use_stereo: bool = False,
mono: bool = False,
stereo: bool = False,
)
```
5 changes: 3 additions & 2 deletions audio_data_pytorch/datasets/wav_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@ def __getitem__(
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
idx = idx.tolist() if torch.is_tensor(idx) else idx # type: ignore
waveform, sample_rate = torchaudio.load(self.wavs[idx])
if self.transforms:
waveform = self.transforms(waveform)

if self.sample_rate and sample_rate != self.sample_rate:
waveform = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=self.sample_rate
)(waveform)

if self.transforms:
waveform = self.transforms(waveform)

return waveform

def __len__(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion audio_data_pytorch/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .all import AllTransform
from .crop import Crop
from .loudness import Loudness
from .overlap_channels import OverlapChannels
from .mono import Mono
from .randomcrop import RandomCrop
from .resample import Resample
from .scale import Scale
Expand Down
10 changes: 5 additions & 5 deletions audio_data_pytorch/transforms/all.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ..utils import exists
from .crop import Crop
from .loudness import Loudness
from .overlap_channels import OverlapChannels
from .mono import Mono
from .randomcrop import RandomCrop
from .resample import Resample
from .scale import Scale
Expand All @@ -21,8 +21,8 @@ def __init__(
random_crop_size: Optional[int] = None,
loudness: Optional[int] = None,
scale: Optional[float] = None,
use_stereo: bool = False,
overlap_channels: bool = False,
stereo: bool = False,
mono: bool = False,
):
super().__init__()

Expand All @@ -38,8 +38,8 @@ def __init__(
else nn.Identity(),
RandomCrop(random_crop_size) if exists(random_crop_size) else nn.Identity(),
Crop(crop_size) if exists(crop_size) else nn.Identity(),
OverlapChannels() if overlap_channels else nn.Identity(),
Stereo() if use_stereo else nn.Identity(),
Mono() if mono else nn.Identity(),
Stereo() if stereo else nn.Identity(),
Loudness(sampling_rate=target_rate, target=loudness) # type: ignore
if exists(loudness)
else nn.Identity(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torch import Tensor, nn


class OverlapChannels(nn.Module):
class Mono(nn.Module):
"""Overlaps all channels into one"""

def forward(self, x: Tensor) -> Tensor:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-data-pytorch",
packages=find_packages(exclude=[]),
version="0.0.9",
version="0.0.10",
license="MIT",
description="Audio Data - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit c258266

Please sign in to comment.