Commit: Fix all typing errors and run mypy in pre-commit
michaelosthege authored and LarsHalle committed Jun 19, 2023
1 parent 5d1277d commit 34fb75a
Showing 11 changed files with 159 additions and 105 deletions.
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
@@ -19,3 +19,7 @@ repos:
rev: 22.3.0
hooks:
- id: black
+ - repo: https://github.com/pre-commit/mirrors-mypy
+ rev: v1.3.0
+ hooks:
+ - id: mypy
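With this hook in place, mypy runs on every commit; the full tree can be checked manually with pre-commit run mypy --all-files. One caveat (an assumption about the setup, not visible in this diff): mirrors-mypy runs mypy in an isolated environment, so stubs for third-party packages typically have to be declared under the hook's additional_dependencies.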
50 changes: 30 additions & 20 deletions bletl/core.py
@@ -8,7 +8,7 @@
import urllib.request
import warnings
from collections.abc import Iterable
- from typing import Optional, Union
+ from typing import Optional, Sequence, Union

import numpy
import pandas
@@ -92,14 +92,14 @@ def get_parser(filepath: Union[str, pathlib.Path]) -> BLDParser:
def _parse(
filepath: str,
drop_incomplete_cycles: bool,
- lot_number: int,
- temp: int,
- cal_0: float = None,
- cal_100: float = None,
- phi_min: float = None,
- phi_max: float = None,
- pH_0: float = None,
- dpH: float = None,
+ lot_number: Optional[int],
+ temp: Optional[int],
+ cal_0: Optional[float] = None,
+ cal_100: Optional[float] = None,
+ phi_min: Optional[float] = None,
+ phi_max: Optional[float] = None,
+ pH_0: Optional[float] = None,
+ dpH: Optional[float] = None,
) -> BLData:
"""Parses a raw BioLector CSV file into a BLData object.
@@ -138,29 +138,39 @@ def _parse(
When the file contents do not match with a known BioLector result file format.
"""
parser = get_parser(filepath)
- data = parser.parse(filepath, lot_number, temp, cal_0, cal_100, phi_min, phi_max, pH_0, dpH)
+ data = parser.parse(
+ filepath,
+ lot_number=lot_number,
+ temp=temp,
+ cal_0=cal_0,
+ cal_100=cal_100,
+ phi_min=phi_min,
+ phi_max=phi_max,
+ pH_0=pH_0,
+ dpH=dpH,
+ )

if (not data.measurements.empty) and drop_incomplete_cycles:
index_names, measurements = utils._unindex(data.measurements)
latest_full_cycle = utils._last_full_cycle(measurements)
measurements = measurements[measurements.cycle <= latest_full_cycle]
- data._measurements = utils._reindex(measurements, index_names)
+ data._measurements = utils._reindex(measurements, index_names)  # type: ignore

return data


def parse(
- filepaths,
+ filepaths: Union[str, Sequence[str]],
*,
drop_incomplete_cycles: bool = True,
- lot_number: int = None,
- temp: int = None,
- cal_0: float = None,
- cal_100: float = None,
- phi_min: float = None,
- phi_max: float = None,
- pH_0: float = None,
- dpH: float = None,
+ lot_number: Optional[int] = None,
+ temp: Optional[int] = None,
+ cal_0: Optional[float] = None,
+ cal_100: Optional[float] = None,
+ phi_min: Optional[float] = None,
+ phi_max: Optional[float] = None,
+ pH_0: Optional[float] = None,
+ dpH: Optional[float] = None,
) -> BLData:
"""Parses a raw BioLector CSV file into a BLData object and applies calibration.
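Most signature changes in this commit follow one pattern: an argument annotated with a plain type but defaulting to None. PEP 484 deprecated reading that spelling as an implicit Optional, and mypy rejects it by default since v0.990, which is what forced these edits. A minimal sketch of the error and the fix (hypothetical names, not from this repo):

from typing import Optional

def scale(factor: float = None):  # mypy: incompatible default "None" for argument of type "float"
    return factor

def scale_fixed(factor: Optional[float] = None):  # explicit Optional type-checks cleanly
    return factor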
79 changes: 43 additions & 36 deletions bletl/growth.py
@@ -1,5 +1,6 @@
import logging
- import typing
+ from typing import Dict, Optional, Sequence, Tuple, Union

import arviz
import calibr8
@@ -10,7 +11,7 @@
try:
import pytensor.tensor as pt
except ModuleNotFoundError:
- import aesara.tensor as pt
+ import aesara.tensor as pt  # type: ignore


_log = logging.getLogger(__file__)
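The # type: ignore on the fallback import is the usual escape hatch for this try/except import pattern: mypy flags the second import because it re-binds pt (and aesara may not ship type information), so the comment silences that single line without loosening checks elsewhere.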
@@ -22,13 +23,13 @@ class GrowthRateResult:
def __init__(
self,
*,
- t_data: numpy.ndarray,
- t_segments: numpy.ndarray,
- y: numpy.ndarray,
+ t_data: Union[Sequence[float], numpy.ndarray],
+ t_segments: Union[Sequence[float], numpy.ndarray],
+ y: Union[Sequence[float], numpy.ndarray],
calibration_model: calibr8.CalibrationModel,
- switchpoints: typing.Dict[float, str],
+ switchpoints: Dict[float, str],
pmodel: pm.Model,
- theta_map: dict,
+ theta_map: Dict[str, numpy.ndarray],
):
"""Creates a result object of a growth rate analysis.
@@ -47,9 +48,9 @@ def __init__(
theta_map : dict
the PyMC MAP estimate
"""
- self._t_data = t_data
- self._t_segments = t_segments
- self._y = y
+ self._t_data = numpy.asarray(t_data)
+ self._t_segments = numpy.asarray(t_segments)
+ self._y = numpy.asarray(y)
self._switchpoints = switchpoints
self.calibration_model = calibration_model
self._pmodel = pmodel
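Widening the parameter annotations to Union[Sequence[float], numpy.ndarray] while coercing with numpy.asarray keeps the public API flexible but the stored attribute a concrete ndarray. A minimal sketch of the pattern (hypothetical name):

from typing import Sequence, Union

import numpy

def normalize(t_data: Union[Sequence[float], numpy.ndarray]) -> numpy.ndarray:
    # asarray is a no-op for existing arrays and converts lists/tuples,
    # so downstream code can rely on ndarray attributes and slicing
    return numpy.asarray(t_data, dtype=float)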
@@ -73,17 +74,17 @@ def y(self) -> numpy.ndarray:
return self._y

@property
- def switchpoints(self) -> typing.Dict[float, str]:
+ def switchpoints(self) -> Dict[float, str]:
"""Dictionary (by time) of known and detected switchpoints."""
return self._switchpoints

@property
- def known_switchpoints(self) -> typing.Tuple[float]:
+ def known_switchpoints(self) -> Tuple[float, ...]:
"""Time values of previously known switchpoints in the model."""
return tuple(t for t, label in self.switchpoints.items() if label != "detected")

@property
- def detected_switchpoints(self) -> typing.Tuple[float]:
+ def detected_switchpoints(self) -> Tuple[float, ...]:
"""Time values of switchpoints that were autodetected from the fit."""
return tuple(t for t, label in self.switchpoints.items() if label == "detected")

@@ -93,12 +94,12 @@ def pmodel(self) -> pm.Model:
return self._pmodel

@property
- def theta_map(self) -> dict:
+ def theta_map(self) -> Dict[str, numpy.ndarray]:
"""MAP estimate of the model parameters."""
return self._theta_map

@property
- def idata(self) -> typing.Optional[arviz.InferenceData]:
+ def idata(self) -> Optional[arviz.InferenceData]:
"""ArviZ InferenceData object of the MCMC trace."""
return self._idata

@@ -113,18 +114,20 @@ def x_map(self) -> numpy.ndarray:
return self.theta_map["X"]

@property
- def mu_mcmc(self) -> typing.Optional[numpy.ndarray]:
+ def mu_mcmc(self) -> Optional[numpy.ndarray]:
"""Posterior samples of growth rates in segments between data points."""
if not self.idata:
return None
assert hasattr(self.idata, "posterior")
return self.idata.posterior.mu_t.stack(sample=("chain", "draw")).values.T

@property
- def x_mcmc(self) -> typing.Optional[numpy.ndarray]:
+ def x_mcmc(self) -> Optional[numpy.ndarray]:
"""Posterior samples of biomass curve."""
- if not self.idata:
+ if self.idata is None:
return None
- return self._idata.posterior["X"].stack(sample=("chain", "draw")).T
+ assert hasattr(self.idata, "posterior")
+ return self.idata.posterior["X"].stack(sample=("chain", "draw")).T

def sample(self, **kwargs) -> None:
"""Runs MCMC sampling with default settings on the growth model.
@@ -157,8 +160,8 @@ def _make_random_walk(
nu: float = 1,
length: int,
student_t: bool,
- initval: numpy.ndarray = None,
- dims: typing.Optional[str] = None,
+ initval: Optional[numpy.ndarray] = None,
+ dims: Optional[str] = None,
):
"""Create a random walk with either a Normal or Student-t distribution.
@@ -215,7 +218,11 @@ def _make_random_walk(


def _get_smoothed_mu(
- t: numpy.ndarray, y: numpy.ndarray, cm_cdw: calibr8.CalibrationModel, *, clip=0.5
+ t: Sequence[float],
+ y: Sequence[float],
+ cm_cdw: calibr8.CalibrationModel,
+ *,
+ clip: float = 0.5,
) -> numpy.ndarray:
"""Calculate a rough estimate of the specific growth rate from smoothed observations.
@@ -236,10 +243,10 @@
A vector of specific growth rates.
"""
# apply moving average to reduce backscatter noise
y = numpy.convolve(y, numpy.ones(5) / 5, "same")
yarr = numpy.convolve(y, numpy.ones(5) / 5, "same")

# convert to biomass
- X = cm_cdw.predict_independent(y)
+ X = cm_cdw.predict_independent(yarr)

# calculate growth rate
dX = numpy.diff(X)
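Renaming y to yarr here (and cdf_evals to cdf_evals_arr further down in detect_switchpoints) works around mypy pinning a variable to the type of its first binding: re-assigning a parameter declared Sequence[float] with the ndarray returned by numpy.convolve would be flagged as an incompatible assignment, so the converted value gets a fresh name. A sketch of the pattern (hypothetical name):

from typing import Sequence

import numpy

def smooth(y: Sequence[float]) -> numpy.ndarray:
    # a new name avoids re-binding the Sequence[float] parameter to an ndarray
    yarr = numpy.convolve(y, numpy.ones(5) / 5, "same")
    return yarr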
@@ -259,17 +266,17 @@


def fit_mu_t(
- t: typing.Sequence[float],
- y: typing.Sequence[float],
+ t: Sequence[float],
+ y: Sequence[float],
calibration_model: calibr8.CalibrationModel,
*,
- switchpoints: typing.Optional[typing.Union[typing.Sequence[float], typing.Dict[float, str]]] = None,
+ switchpoints: Optional[Union[Sequence[float], Dict[float, str]]] = None,
mcmc_samples: int = 0,
mu_prior: float = 0,
drift_scale: float,
nu: float = 5,
x0_prior: float = 0.25,
- student_t: typing.Optional[bool] = None,
+ student_t: Optional[bool] = None,
switchpoint_prob: float = 0.01,
replicate_id: str = "unnamed",
):
@@ -357,7 +364,7 @@ def fit_mu_t(
mu_segments = []
i_from = 0
for i, t_switch in enumerate(t_switchpoints_known):
- i_to = numpy.argmax(t > t_switch)
+ i_to = int(numpy.argmax(t > t_switch))
i_len = len(t[i_from:i_to])
name = f"mu_phase_{i}"
slc = slice(i_from, i_to)
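Wrapping numpy.argmax in int() is another recurring mypy appeasement: argmax returns a numpy scalar (intp), not a builtin int, so positions that expect a plain int fail to type-check without the cast. For example:

import numpy

t = numpy.array([0.0, 1.5, 3.0, 4.5])
i_to = int(numpy.argmax(t > 2.0))  # numpy.intp -> builtin int; here i_to == 2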
@@ -460,10 +467,10 @@ def fit_mu_t(

def detect_switchpoints(
switchpoint_prob: float,
- t_data: typing.Sequence[float],
+ t_data: Sequence[float],
pmodel: pm.Model,
- theta_map: typing.Dict[str, numpy.ndarray],
- ) -> typing.Dict[float, str]:
+ theta_map: Dict[str, numpy.ndarray],
+ ) -> Dict[float, str]:
"""Helper function to detect switchpoints from a fitted random walk.
Parameters
@@ -509,15 +516,15 @@ def detect_switchpoints(
# To get our <number of segments> length vector to align with the <number of points>,
# we prepend a 0.5 as a placeholder for the CDF of the initial point of the random walk.
cdf_evals += [0.5, *numpy.exp(logcdfs)]
- cdf_evals = numpy.array(cdf_evals)
- if len(cdf_evals) != len(t_data) - 1:
+ cdf_evals_arr = numpy.array(cdf_evals)
+ if len(cdf_evals_arr) != len(t_data) - 1:
raise Exception(
f"Failed to find all random walk segments. Found {len(cdf_evals)}, expected {len(t_data) - 1}."
f"Failed to find all random walk segments. Found {len(cdf_evals_arr)}, expected {len(t_data) - 1}."
)
# Filter for the elements that lie outside of the [0.005, 0.995] interval (if switchpoint_prob=0.01).
significance_mask = numpy.logical_or(
- cdf_evals < (switchpoint_prob / 2),
- cdf_evals > (1 - switchpoint_prob / 2),
+ cdf_evals_arr < (switchpoint_prob / 2),
+ cdf_evals_arr > (1 - switchpoint_prob / 2),
)
# Collect switchpoint information from points with significant CDF values.
# Here we don't need to filter known switchpoints, because these correspond to the first
30 changes: 16 additions & 14 deletions bletl/parsing/bl1.py
@@ -92,12 +92,12 @@ def calibrate_with_lot(self, data: BLData, lot_number: Optional[int] = None, tem
def calibrate_with_parameters(
self,
data: BLData,
- cal_0: float = None,
- cal_100: float = None,
- phi_min: float = None,
- phi_max: float = None,
- pH_0: float = None,
- dpH: float = None,
+ cal_0: Optional[float] = None,
+ cal_100: Optional[float] = None,
+ phi_min: Optional[float] = None,
+ phi_max: Optional[float] = None,
+ pH_0: Optional[float] = None,
+ dpH: Optional[float] = None,
):
def process_backscatter(raw_data_df, cycle_ref_df, global_ref):
"""
@@ -182,14 +182,14 @@ def process_DO(raw_data_df, cal_0, cal_100):
def parse(
self,
filepath,
- lot_number: int = None,
- temp: int = None,
- cal_0: float = None,
- cal_100: float = None,
- phi_min: float = None,
- phi_max: float = None,
- pH_0: float = None,
- dpH: float = None,
+ lot_number: Optional[int] = None,
+ temp: Optional[int] = None,
+ cal_0: Optional[float] = None,
+ cal_100: Optional[float] = None,
+ phi_min: Optional[float] = None,
+ phi_max: Optional[float] = None,
+ pH_0: Optional[float] = None,
+ dpH: Optional[float] = None,
):
headerlines, data = split_header_data(filepath)

@@ -476,6 +476,8 @@ def fetch_calibration_data(lot_number: int, temp: int):
Dictionary containing calibration data.
Can be readily used in calibration function.
"""
+ assert utils.__spec__ is not None
+ assert utils.__spec__.origin is not None
module_path = pathlib.Path(utils.__spec__.origin).parents[0]
calibration_file = pathlib.Path(module_path, "cache", "CalibrationLot.ini")
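The two new asserts exist because typeshed types a module's __spec__ as Optional[ModuleSpec] and ModuleSpec.origin as Optional[str]; without narrowing, handing utils.__spec__.origin to pathlib.Path is a type error, even though both are always set for a normally imported module. The same narrowing applies to modules looked up dynamically, e.g.:

import importlib.util
import pathlib

spec = importlib.util.find_spec("json")
# both the spec and its origin are Optional in the stubs, so narrow first
assert spec is not None and spec.origin is not None
module_dir = pathlib.Path(spec.origin).parent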
