Skip to content

Commit

Permalink
use relative imports everywhere (ML4GW#190)
Browse files Browse the repository at this point in the history
  • Loading branch information
ravioli1369 authored Jan 25, 2025
1 parent ded9196 commit 00e0f22
Show file tree
Hide file tree
Showing 25 changed files with 44 additions and 49 deletions.
2 changes: 1 addition & 1 deletion ml4gw/dataloading/chunked_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from ml4gw.types import WaveformTensor
from ..types import WaveformTensor


class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
Expand Down
2 changes: 1 addition & 1 deletion ml4gw/dataloading/hdf5_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import torch

from ml4gw.types import WaveformTensor
from ..types import WaveformTensor


class ContiguousHdf5Warning(Warning):
Expand Down
2 changes: 1 addition & 1 deletion ml4gw/dataloading/in_memory_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from jaxtyping import Float
from torch import Tensor

from ml4gw.utils.slicing import slice_kernels
from ..utils.slicing import slice_kernels


class InMemoryDataset(torch.utils.data.IterableDataset):
Expand Down
7 changes: 4 additions & 3 deletions ml4gw/gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
from jaxtyping import Float
from torch import Tensor

from ml4gw.constants import C
from ml4gw.types import (
from ml4gw.utils.interferometer import InterferometerGeometry

from .constants import C
from .types import (
BatchTensor,
NetworkDetectorTensors,
NetworkVertices,
Expand All @@ -26,7 +28,6 @@
VectorGeometry,
WaveformTensor,
)
from ml4gw.utils.interferometer import InterferometerGeometry


def outer(x: VectorGeometry, y: VectorGeometry) -> TensorGeometry:
Expand Down
2 changes: 1 addition & 1 deletion ml4gw/nn/autoencoder/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from torch import Tensor

from ml4gw.nn.autoencoder.skip_connection import SkipConnection
from .skip_connection import SkipConnection


class Autoencoder(torch.nn.Module):
Expand Down
6 changes: 3 additions & 3 deletions ml4gw/nn/autoencoder/convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import torch
from torch import Tensor

from ml4gw.nn.autoencoder.base import Autoencoder
from ml4gw.nn.autoencoder.skip_connection import SkipConnection
from ml4gw.nn.autoencoder.utils import match_size
from .base import Autoencoder
from .skip_connection import SkipConnection
from .utils import match_size

Module = Callable[[...], torch.nn.Module]

Expand Down
2 changes: 1 addition & 1 deletion ml4gw/nn/autoencoder/skip_connection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from torch import Tensor

from ml4gw.nn.autoencoder.utils import match_size
from .utils import match_size


class SkipConnection(torch.nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion ml4gw/nn/resnet/resnet_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch.nn as nn
from torch import Tensor

from ml4gw.nn.norm import GroupNorm1DGetter, NormLayer
from ..norm import GroupNorm1DGetter, NormLayer


def convN(
Expand Down
2 changes: 1 addition & 1 deletion ml4gw/nn/resnet/resnet_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch.nn as nn
from torch import Tensor

from ml4gw.nn.norm import GroupNorm2DGetter, NormLayer
from ..norm import GroupNorm2DGetter, NormLayer


def convN(
Expand Down
2 changes: 1 addition & 1 deletion ml4gw/nn/streaming/online_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jaxtyping import Float
from torch import Tensor

from ml4gw.utils.slicing import unfold_windows
from ...utils.slicing import unfold_windows


class OnlineAverager(torch.nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion ml4gw/nn/streaming/snapshotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jaxtyping import Float
from torch import Tensor

from ml4gw.utils.slicing import unfold_windows
from ...utils.slicing import unfold_windows


class Snapshotter(torch.nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion ml4gw/spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from jaxtyping import Float
from torch import Tensor

from ml4gw.types import (
from .types import (
FrequencySeries1to3d,
PSDTensor,
TimeSeries1to3d,
Expand Down
4 changes: 2 additions & 2 deletions ml4gw/transforms/pearson.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from jaxtyping import Float
from torch import Tensor

from ml4gw.types import TimeSeries1to3d
from ml4gw.utils.slicing import unfold_windows
from ..types import TimeSeries1to3d
from ..utils.slicing import unfold_windows


class ShiftedPearsonCorrelation(torch.nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions ml4gw/transforms/qtransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from jaxtyping import Float, Int
from torch import Tensor

from ml4gw.transforms.spline_interpolation import SplineInterpolate
from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d
from ..types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d
from .spline_interpolation import SplineInterpolate

"""
All based on https://github.com/gwpy/gwpy/blob/v3.0.8/gwpy/signal/qtransform.py
Expand Down
2 changes: 1 addition & 1 deletion ml4gw/transforms/scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jaxtyping import Float
from torch import Tensor

from ml4gw.transforms.transform import FittableTransform
from .transform import FittableTransform


class ChannelWiseScaler(FittableTransform):
Expand Down
6 changes: 3 additions & 3 deletions ml4gw/transforms/snr_rescaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import torch

from ml4gw.gw import compute_network_snr
from ml4gw.transforms.transform import FittableSpectralTransform
from ml4gw.types import BatchTensor, TimeSeries2d, WaveformTensor
from ..gw import compute_network_snr
from ..types import BatchTensor, TimeSeries2d, WaveformTensor
from .transform import FittableSpectralTransform


class SnrRescaler(FittableSpectralTransform):
Expand Down
4 changes: 2 additions & 2 deletions ml4gw/transforms/spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from jaxtyping import Float
from torch import Tensor

from ml4gw.spectral import fast_spectral_density, spectral_density
from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d
from ..spectral import fast_spectral_density, spectral_density
from ..types import FrequencySeries1to3d, TimeSeries1to3d


class SpectralDensity(torch.nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion ml4gw/transforms/spectrogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch import Tensor
from torchaudio.transforms import Spectrogram

from ml4gw.types import TimeSeries3d
from ..types import TimeSeries3d


class MultiResolutionSpectrogram(torch.nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions ml4gw/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import torch

from ml4gw.spectral import spectral_density
from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d
from ..spectral import spectral_density
from ..types import FrequencySeries1to3d, TimeSeries1to3d


class FittableTransform(torch.nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions ml4gw/transforms/waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from jaxtyping import Float
from torch import Tensor

from ml4gw import gw
from ml4gw.types import BatchTensor
from .. import gw
from ..types import BatchTensor


# TODO: should these live in ml4gw.waveforms submodule?
Expand Down
6 changes: 3 additions & 3 deletions ml4gw/transforms/whitening.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

import torch

from ml4gw import spectral
from ml4gw.transforms.transform import FittableSpectralTransform
from ml4gw.types import (
from .. import spectral
from ..types import (
FrequencySeries1d,
FrequencySeries1to3d,
TimeSeries1d,
TimeSeries3d,
)
from .transform import FittableSpectralTransform


class Whiten(torch.nn.Module):
Expand Down
7 changes: 1 addition & 6 deletions ml4gw/utils/slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@
from torch import Tensor
from torch.nn.functional import unfold

from ml4gw.types import (
TimeSeries1d,
TimeSeries1to3d,
TimeSeries2d,
TimeSeries3d,
)
from ..types import TimeSeries1d, TimeSeries1to3d, TimeSeries2d, TimeSeries3d

BatchTimeSeriesTensor = Union[Float[Tensor, "batch time"], TimeSeries3d]

Expand Down
5 changes: 2 additions & 3 deletions ml4gw/waveforms/cbc/phenom_d.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import torch
from jaxtyping import Float

from ml4gw.constants import MTSUN_SI, PI
from ml4gw.types import BatchTensor, FrequencySeries1d

from ...constants import MTSUN_SI, PI
from ...types import BatchTensor, FrequencySeries1d
from .phenom_d_data import QNMData_a, QNMData_fdamp, QNMData_fring
from .taylorf2 import TaylorF2

Expand Down
8 changes: 4 additions & 4 deletions ml4gw/waveforms/cbc/taylorf2.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import torch
from jaxtyping import Float

from ml4gw.constants import MPC_SEC, MTSUN_SI, PI
from ml4gw.constants import EulerGamma as GAMMA
from ml4gw.types import BatchTensor, FrequencySeries1d
from ml4gw.waveforms.conversion import chirp_mass_and_mass_ratio_to_components
from ...constants import MPC_SEC, MTSUN_SI, PI
from ...constants import EulerGamma as GAMMA
from ...types import BatchTensor, FrequencySeries1d
from ..conversion import chirp_mass_and_mass_ratio_to_components


class TaylorF2(torch.nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions ml4gw/waveforms/conversion.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch

from ml4gw.constants import MTSUN_SI, PI
from ml4gw.types import BatchTensor
from ..constants import MTSUN_SI, PI
from ..types import BatchTensor


def rotate_z(angle: BatchTensor, x, y, z):
Expand Down

0 comments on commit 00e0f22

Please sign in to comment.