Skip to content
This repository was archived by the owner on Nov 29, 2023. It is now read-only.

[wip] pydocstyle #104

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions .github/workflows/linters.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
name: Lint Python

on: [push, pull_request]

jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8, 3.9]

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pydocstyle
- name: Docstyle linting
run: |
pydocstyle --convention=google --add-ignore=D200,D210,D212,D415
1 change: 1 addition & 0 deletions satflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
"""satflow package"""
from .version import __version__
1 change: 1 addition & 0 deletions satflow/baseline/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Evaluation of baseline models"""
11 changes: 11 additions & 0 deletions satflow/baseline/optical_flow.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Evaluation of baseline models"""
import cv2
from satflow.data.datasets import OpticalFlowDataset, SatFlowDataset
import webdataset as wds
Expand All @@ -7,6 +8,7 @@


def load_config(config_file):
"""Load a config file from disk"""
with open(config_file, "r") as cfg:
return yaml.load(cfg, Loader=yaml.FullLoader)["config"]

Expand All @@ -21,6 +23,15 @@ def load_config(config_file):


def warp_flow(img, flow):
"""
Get the previous image by inverting the optical flow and applying it to the current image

Args:
img: the current image
flow: the optical flow

Returns: the resulting image
"""
h, w = flow.shape[:2]
flow = -flow
flow[:, :, 0] += np.arange(w)
Expand Down
1 change: 1 addition & 0 deletions satflow/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Core utility functions"""
41 changes: 29 additions & 12 deletions satflow/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Core utility functions"""
import logging
import typing as Dict
from nowcasting_dataset.config.load import load_yaml_configuration
Expand All @@ -6,12 +7,23 @@


def load_config(file_path: str) -> Dict:
"""Load yaml config file from file path"""
with open(file_path, "r") as f:
config = yaml.load(f)
return config


def make_logger(name: str, level=logging.DEBUG) -> logging.Logger:
"""
Get a named logger at a specified level

Args:
name: name of the logger
level: level of the logger. Default is logging.DEBUG

Returns:
The logger
"""
logger = logging.getLogger(name)
logger.setLevel(level=level)
return logger
Expand All @@ -30,7 +42,6 @@ def make_logger(name: str, level=logging.DEBUG) -> logging.Logger:

def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
"""Initializes multi-GPU-friendly python logger."""

logger = logging.getLogger(name)
logger.setLevel(level)

Expand All @@ -43,19 +54,21 @@ def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:


def extras(config: DictConfig) -> None:
"""A couple of optional utilities, controlled by main config file:
- disabling warnings
- easier access to debug mode
- forcing debug friendly configuration
- forcing multi-gpu friendly configuration
- Ensure correct number of timesteps/etc for all of them
"""
A couple of optional utilities, controlled by main config file

Utilities include
- disabling warnings
- easier access to debug mode
- forcing debug friendly configuration
- forcing multi-gpu friendly configuration
- Ensure correct number of timesteps/etc for all of them

Modifies DictConfig in place.

Args:
config (DictConfig): Configuration composed by Hydra.
"""

log = get_logger()

# enable adding new keys to config
Expand Down Expand Up @@ -150,10 +163,9 @@ def print_config(
Args:
config (DictConfig): Configuration composed by Hydra.
fields (Sequence[str], optional): Determines which main fields from config will
be printed and in what order.
be printed and in what order.
resolve (bool, optional): Whether to resolve reference fields of DictConfig.
"""

style = "dim"
tree = rich.tree.Tree(":gear: CONFIG", style=style, guide_style=style)

Expand All @@ -180,12 +192,17 @@ def log_hyperparameters(
model: pl.LightningModule,
trainer: pl.Trainer,
) -> None:
"""This method controls which parameters from Hydra config are saved by Lightning loggers.
"""
This method controls which parameters from Hydra config are saved by Lightning loggers.

Additionaly saves:
- number of trainable model parameters
"""

Args:
config (DictConfig): Configuration composed by Hydra.
model (pl.LightningModule): the model with parameters to save
trainer (pl.Trainer): the trainer with hyperparams
"""
hparams = {}

# choose which parts of hydra config will be saved to loggers
Expand Down
1 change: 1 addition & 0 deletions satflow/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""satflow data"""
29 changes: 26 additions & 3 deletions satflow/data/datamodules.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""A DataModule that encapsulates the steps to process the data"""
import os
from nowcasting_dataset.dataset.datasets import worker_init_fn
from nowcasting_dataset.config.load import load_yaml_configuration
Expand Down Expand Up @@ -27,6 +28,8 @@

class SatFlowDataModule(LightningDataModule):
"""
A SatFlow DataModule

Example of LightningDataModule for NETCDF dataset.
A DataModule implements 5 key methods:
- prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
Expand Down Expand Up @@ -67,7 +70,21 @@ def __init__(
forecast_minutes: Optional[int] = None,
):
"""
fake_data: random data is created and used instead. This is useful for testing
Initialize a satflow DataModule

Args:
temp_path: a file path to store temporary data
n_train_data: default is 24900
n_val_data: default is 1000
cloud: name of cloud provider. Default is "aws".
num_workers: default is 8
pin_memory: default is true
configuration_filename: a file path
fake_data: random data is created and used instead. This is useful for testing.
Default is false.
required_keys: tuple or list of keys required in the example for it to be considered usable
history_minutes: how many past minutes of data to use, if subsetting the batch. Default is None.
forecast_minutes: how many future minutes of data to use, if reducing the amount of forecast time. Default is None.
"""
super().__init__()

Expand Down Expand Up @@ -95,6 +112,7 @@ def __init__(
)

def train_dataloader(self):
"""A data loader for the training data"""
if self.fake_data:
train_dataset = FakeDataset(
history_minutes=self.history_minutes, forecast_minutes=self.forecast_minutes
Expand All @@ -114,6 +132,7 @@ def train_dataloader(self):
return torch.utils.data.DataLoader(train_dataset, **self.dataloader_config)

def val_dataloader(self):
"""A data loader for the validation data"""
if self.fake_data:
val_dataset = FakeDataset(
history_minutes=self.history_minutes, forecast_minutes=self.forecast_minutes
Expand All @@ -133,6 +152,7 @@ def val_dataloader(self):
return torch.utils.data.DataLoader(val_dataset, **self.dataloader_config)

def test_dataloader(self):
"""A data loader for the testing data"""
if self.fake_data:
test_dataset = FakeDataset(
history_minutes=self.history_minutes, forecast_minutes=self.forecast_minutes
Expand All @@ -154,7 +174,7 @@ def test_dataloader(self):


class FakeDataset(torch.utils.data.Dataset):
"""Fake dataset."""
"""Fake dataset with random data, useful for testing."""

def __init__(
self,
Expand All @@ -166,6 +186,7 @@ def __init__(
history_minutes=30,
forecast_minutes=30,
):
"""Initialize a fake dataset"""
self.batch_size = batch_size
if history_minutes is None or forecast_minutes is None:
history_minutes = 30 # Half an hour
Expand All @@ -179,13 +200,15 @@ def __init__(
self.length = length

def __len__(self):
"""Length of dataset"""
return self.length

def per_worker_init(self, worker_id: int):
"""Not implemented"""
pass

def __getitem__(self, idx):

"""Get data at the index"""
x = {
SATELLITE_DATA: torch.randn(
self.batch_size, self.seq_length, self.width, self.height, self.number_sat_channels
Expand Down
23 changes: 14 additions & 9 deletions satflow/data/datasets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""SatFlowDataset"""
from typing import Tuple, Union, List, Optional

import numpy as np
Expand All @@ -18,9 +19,7 @@


class SatFlowDataset(NetCDFDataset):
"""Loads data saved by the `prepare_ml_training_data.py` script.
Adapted from predict_pv_yield
"""
"""Loads data saved by the `prepare_ml_training_data.py` script. Adapted from predict_pv_yield"""

def __init__(
self,
Expand All @@ -45,13 +44,18 @@ def __init__(
combine_inputs: bool = False,
):
"""
Initialize SatFlowDataSet

Args:
n_batches: Number of batches available on disk.
src_path: The full path (including 'gs://') to the data on
Google Cloud storage.
tmp_path: The full path to the local temporary directory
(on a local filesystem).
batch_size: Batch size, if requested, will subset data along batch dimension
n_batches: Number of batches available on disk.
src_path: The full path (including 'gs://') to the data on Google Cloud storage.
tmp_path: The full path to the local temporary directory (on a local filesystem).
configuration: configuration values
cloud: name of cloud provider. Default is "gcp".
required_keys: Tuple or list of keys required in the example for it to be considered usable
history_minutes: How many past minutes of data to use, if subsetting the batch. Default is 30.
forecast_minutes: How many future minutes of data to use, if reducing the amount of forecast time. Default is 60.
combine_inputs: Default is False.
"""
super().__init__(
n_batches,
Expand All @@ -69,6 +73,7 @@ def __init__(
self.current_timestep_index = (history_minutes // 5) + 1

def __getitem__(self, batch_idx: int):
"""Get data at the index"""
batch = super().__getitem__(batch_idx)

# Need to partition out past and future sat images here, along with the rest of the data
Expand Down
1 change: 1 addition & 0 deletions satflow/data/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Data utility functions"""
Loading