Skip to content

Commit

Permalink
chore(learnergy): Adds annotated typing.
Browse files Browse the repository at this point in the history
  • Loading branch information
gugarosa committed Apr 26, 2022
1 parent c681323 commit 6dc723b
Show file tree
Hide file tree
Showing 22 changed files with 906 additions and 785 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ repos:
rev: 5.10.1
hooks:
- id: isort
exclude: learnergy/models/bernoulli/__init__.py
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
Expand Down
48 changes: 29 additions & 19 deletions learnergy/core/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Dataset-related classes.
"""

from typing import Optional, Tuple

import numpy as np
import torch

Expand All @@ -13,13 +15,15 @@
class Dataset(torch.utils.data.Dataset):
"""A custom dataset class, inherited from PyTorch's dataset."""

def __init__(self, data, targets, transform=None):
def __init__(
self, data: np.array, targets: np.array, transform: Optional[callable] = None
) -> None:
"""Initialization method.
Args:
data (np.array): An n-dimensional array containing the data.
targets (np.array): An 1-dimensional array containing the data's labels.
transform (callable): Optional transform to be applied over a sample.
data: An n-dimensional array containing the data.
targets: An 1-dimensional array containing the data's labels.
transform: Optional transform to be applied over a sample.
"""

Expand All @@ -43,45 +47,46 @@ def __init__(self, data, targets, transform=None):
)

@property
def data(self):
"""np.array: An n-dimensional array containing the data."""
def data(self) -> np.array:
"""An n-dimensional array containing the data."""

return self._data

@data.setter
def data(self, data):

def data(self, data: np.array) -> None:
self._data = data

@property
def targets(self):
"""np.array: An 1-dimensional array containing the data's labels."""
def targets(self) -> np.array:
"""An 1-dimensional array containing the data's labels."""

return self._targets

@targets.setter
def targets(self, targets):

def targets(self, targets: np.array) -> None:
self._targets = targets

@property
def transform(self):
"""callable: Optional transform to be applied over a sample."""
def transform(self) -> callable:
"""Optional transform to be applied over a sample."""

return self._transform

@transform.setter
def transform(self, transform):
def transform(self, transform: callable) -> None:
if not (hasattr(transform, "__call__") or transform is None):
raise e.TypeError("`transform` should be a callable or None")

self._transform = transform

def __getitem__(self, idx):
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""A private method that will be the base for PyTorch's iterator getting a new sample.
Args:
idx (int): The idx of desired sample.
idx: The idx of desired sample.
Returns:
(Tuple[torch.Tensor, torch.Tensor]): Data and label tensors.
"""

Expand All @@ -93,7 +98,12 @@ def __getitem__(self, idx):

return x, y

def __len__(self):
"""A private method that will be the base for PyTorch's iterator getting dataset's length."""
def __len__(self) -> int:
"""A private method that will be the base for PyTorch's iterator getting dataset's length.
Returns:
(int): Length of dataset.
"""

return len(self.data)
21 changes: 11 additions & 10 deletions learnergy/core/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Standard model-related implementation.
"""

from typing import Any, Dict, Optional

import torch

import learnergy.utils.exception as e
Expand All @@ -17,11 +19,11 @@ class Model(torch.nn.Module):
"""

def __init__(self, use_gpu=False):
def __init__(self, use_gpu: Optional[bool] = False) -> None:
"""Initialization method.
Args:
use_gpu (bool): Whether GPU should be used or not.
use_gpu: Whether GPU should be used or not.
"""

Expand All @@ -43,30 +45,29 @@ def __init__(self, use_gpu=False):
logger.debug("Device: %s.", self.device)

@property
def device(self):
"""str: Indicates which device is being used for computation."""
def device(self) -> str:
"""Indicates which device is being used for computation."""

return self._device

@device.setter
def device(self, device):
def device(self, device: str) -> None:
if device not in ["cpu", "cuda"]:
raise e.TypeError("`device` should be `cpu` or `cuda`")

self._device = device

@property
def history(self):
"""dict: Dictionary containing historical values from the model."""
def history(self) -> Dict[str, Any]:
"""Dictionary containing historical values from the model."""

return self._history

@history.setter
def history(self, history):

def history(self, history: Dict[str, Any]) -> None:
self._history = history

def dump(self, **kwargs):
def dump(self, **kwargs) -> None:
"""Dumps any amount of keyword documents to lists in the history property."""

for k, v in kwargs.items():
Expand Down
9 changes: 5 additions & 4 deletions learnergy/math/metrics.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
"""Metrics-related mathematical functions.
"""

import torch
from skimage.metrics import structural_similarity as ssim

from learnergy.utils import logging

logger = logging.get_logger(__name__)


def calculate_ssim(v, x):
def calculate_ssim(v: torch.Tensor, x: torch.Tensor) -> float:
"""Calculates the structural similarity of images.
Args:
v (torch.Tensor): Reconstructed images.
x (torch.Tensor): Original images.
v: Reconstructed images.
x: Original images.
Returns:
The structural similarity between input images.
(float): Structural similarity between input images.
"""

Expand Down
8 changes: 5 additions & 3 deletions learnergy/math/scale.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
"""Scaling-related mathematical functions.
"""

import numpy as np

import learnergy.utils.constants as c


def unitary_scale(x):
def unitary_scale(x: np.array) -> np.array:
"""Scales an array between 0 and 1.
Args:
x (array): A numpy array to be scaled.
x: A numpy array to be scaled.
Returns:
The scaled array.
(np.array): Scaled array.
"""

Expand Down
2 changes: 1 addition & 1 deletion learnergy/models/bernoulli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""A package contaning bernoulli-based models (networks) for all common learnergy modules.
"""

from learnergy.models.bernoulli.rbm import RBM
from learnergy.models.bernoulli.conv_rbm import ConvRBM
from learnergy.models.bernoulli.discriminative_rbm import (
DiscriminativeRBM,
HybridDiscriminativeRBM,
)
from learnergy.models.bernoulli.dropout_rbm import DropConnectRBM, DropoutRBM
from learnergy.models.bernoulli.e_dropout_rbm import EDropoutRBM
from learnergy.models.bernoulli.rbm import RBM
Loading

0 comments on commit 6dc723b

Please sign in to comment.