Skip to content

Commit 8b5d8b0

Browse files
committed
switch to use NDArray
1 parent 216297b commit 8b5d8b0

File tree

6 files changed

+64
-40
lines changed

6 files changed

+64
-40
lines changed

.pre-commit-config.yaml

+17
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,17 @@
1+
repos:
2+
- repo: https://github.com/astral-sh/ruff-pre-commit
3+
rev: v0.4.2
4+
hooks:
5+
- id: ruff
6+
args: [ --fix ]
7+
- id: ruff-format
8+
- repo: https://github.com/pre-commit/pre-commit-hooks
9+
rev: v4.6.0
10+
hooks:
11+
- id: trailing-whitespace
12+
- id: end-of-file-fixer
13+
- repo: https://github.com/pre-commit/mirrors-mypy
14+
rev: v1.10.0
15+
hooks:
16+
- id: mypy
17+
files: ^src

src/mrtool/core/cov_model.py

+5-4
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88

99
import numpy as np
1010
import xspline
11+
from numpy.typing import NDArray
1112

1213
from . import utils
1314
from .data import MRData
@@ -438,7 +439,7 @@ def has_data(self):
438439
return True
439440

440441
def create_spline(
441-
self, data: MRData, spline_knots: np.ndarray = None
442+
self, data: MRData, spline_knots: NDArray | None = None
442443
) -> xspline.XSpline:
443444
"""Create spline given current spline parameters.
444445
Parameters
@@ -525,7 +526,7 @@ def create_spline(
525526

526527
return spline
527528

528-
def create_design_mat(self, data) -> tuple[np.ndarray, np.ndarray]:
529+
def create_design_mat(self, data) -> tuple[NDArray, NDArray]:
529530
"""Create design matrix.
530531
Parameters
531532
----------
@@ -564,7 +565,7 @@ def create_z_mat(self, data):
564565
"Cannot use create_z_mat directly in CovModel class."
565566
)
566567

567-
def create_constraint_mat(self) -> tuple[np.ndarray, np.ndarray]:
568+
def create_constraint_mat(self) -> tuple[NDArray, NDArray]:
568569
"""Create constraint matrix.
569570
Returns:
570571
tuple{numpy.ndarray, numpy.ndarray}:
@@ -679,7 +680,7 @@ def create_constraint_mat(self) -> tuple[np.ndarray, np.ndarray]:
679680

680681
return c_mat, c_val
681682

682-
def create_regularization_mat(self) -> tuple[np.ndarray, np.ndarray]:
683+
def create_regularization_mat(self) -> tuple[NDArray, NDArray]:
683684
"""Create constraint matrix.
684685
Returns:
685686
tuple{numpy.ndarray, numpy.ndarray}:

src/mrtool/core/data.py

+11-10
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212

1313
import numpy as np
1414
import pandas as pd
15+
from numpy.typing import NDArray
1516

1617
from .utils import empty_array, expand_array, is_numeric_array, to_list
1718

@@ -20,11 +21,11 @@
2021
class MRData:
2122
"""Data for simple linear mixed effects model."""
2223

23-
obs: np.ndarray = field(default_factory=empty_array)
24-
obs_se: np.ndarray = field(default_factory=empty_array)
25-
covs: dict[str, np.ndarray] = field(default_factory=dict)
26-
study_id: np.ndarray = field(default_factory=empty_array)
27-
data_id: np.ndarray = field(default_factory=empty_array)
24+
obs: NDArray = field(default_factory=empty_array)
25+
obs_se: NDArray = field(default_factory=empty_array)
26+
covs: dict[str, NDArray] = field(default_factory=dict)
27+
study_id: NDArray = field(default_factory=empty_array)
28+
data_id: NDArray = field(default_factory=empty_array)
2829
cov_scales: dict[str, float] = field(init=False, default_factory=dict)
2930

3031
def __post_init__(self):
@@ -121,7 +122,7 @@ def _get_study_structure(self):
121122
)
122123
self._sort_by_study_id()
123124

124-
def _sort_data(self, index: np.ndarray):
125+
def _sort_data(self, index: NDArray):
125126
"""Sort the object.
126127
127128
Parameters
@@ -166,7 +167,7 @@ def _remove_nan_in_covs(self):
166167
index = index | cov_index
167168
self._remove_data(index)
168169

169-
def _remove_data(self, index: np.ndarray):
170+
def _remove_data(self, index: NDArray):
170171
"""Remove the data point by index.
171172
172173
Parameters
@@ -186,7 +187,7 @@ def _remove_data(self, index: np.ndarray):
186187
self.study_id = self.study_id[keep_index]
187188
self.data_id = self.data_id[keep_index]
188189

189-
def _get_data(self, index: np.ndarray) -> "MRData":
190+
def _get_data(self, index: NDArray) -> "MRData":
190191
"""Get the data point by index.
191192
192193
Parameters
@@ -383,7 +384,7 @@ def _assert_has_studies(self, studies: list[Any] | Any):
383384
f"MRData object do not contain studies: {missing_studies}."
384385
)
385386

386-
def get_covs(self, covs: list[str] | str) -> np.ndarray:
387+
def get_covs(self, covs: list[str] | str) -> NDArray:
387388
"""Get covariate matrix.
388389
389390
Parameters
@@ -393,7 +394,7 @@ def get_covs(self, covs: list[str] | str) -> np.ndarray:
393394
394395
Returns
395396
-------
396-
np.ndarray
397+
NDArray
397398
Covariates matrix, in the column fashion.
398399
399400
"""

src/mrtool/core/model.py

+11-7
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
"""
88

99
from copy import deepcopy
10+
from typing import Sequence
1011

1112
import numpy as np
1213
import pandas as pd
@@ -22,7 +23,10 @@ class MRBRT:
2223
"""MR-BRT Object"""
2324

2425
def __init__(
25-
self, data: MRData, cov_models: list[CovModel], inlier_pct: float = 1.0
26+
self,
27+
data: MRData,
28+
cov_models: Sequence[CovModel],
29+
inlier_pct: float = 1.0,
2630
):
2731
"""Constructor of MRBRT.
2832
@@ -80,12 +84,12 @@ def __init__(
8084
)
8185

8286
# place holder for the limetr objective
83-
self.lt = None
84-
self.beta_soln = None
85-
self.gamma_soln = None
86-
self.u_soln = None
87-
self.w_soln = None
88-
self.re_soln = None
87+
self.lt: LimeTr
88+
self.beta_soln: NDArray
89+
self.gamma_soln: NDArray
90+
self.u_soln: NDArray
91+
self.w_soln: NDArray
92+
self.re_soln: NDArray
8993

9094
def attach_data(self, data=None):
9195
"""Attach data to cov_model."""

src/mrtool/core/utils.py

+17-16
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99

1010
import numpy as np
1111
import pandas as pd
12+
from numpy.typing import NDArray
1213

1314

1415
def get_cols(df, cols):
@@ -124,7 +125,7 @@ def sizes_to_indices(sizes):
124125
125126
Returns
126127
-------
127-
list[np.ndarray]
128+
list[NDArray]
128129
list the indices.
129130
130131
"""
@@ -325,29 +326,29 @@ def avg_integral(mat, spline=None, use_spline_intercept=False):
325326
# random knots
326327
def sample_knots(
327328
num_knots: int,
328-
knot_bounds: np.ndarray,
329-
min_dist: float | np.ndarray,
329+
knot_bounds: NDArray,
330+
min_dist: float | NDArray,
330331
num_samples: int = 1,
331-
) -> np.ndarray:
332+
) -> NDArray:
332333
"""Sample knot vectors given a set of rules.
333334
334335
Parameters
335336
----------
336337
num_knots : int
337338
Number of interior knots.
338-
knot_bounds : np.ndarray, shape(2,) or shape(`num_knots`,2)
339+
knot_bounds : NDArray, shape(2,) or shape(`num_knots`,2)
339340
Lower and upper bounds for knots. If shape(2,), boundary knots
340341
placed at `knot_bounds[0]` and `knot_bounds[1]`. If
341342
shape(`num_knots`,2), boundary knots placed at
342343
`knot_bounds[0, 0]` and `knot_bounds[-1, 1]`.
343-
min_dist : float or np.ndarray, shape(`num_knots`+1,)
344+
min_dist : float or NDArray, shape(`num_knots`+1,)
344345
Minimum distances between knots.
345346
num_samples : int, optional
346347
Number of knot vectors to sample. Default is 1.
347348
348349
Returns
349350
-------
350-
np.ndarray, shape(`num_samples`,`num_knots`+2)
351+
NDArray, shape(`num_samples`,`num_knots`+2)
351352
Sampled knot vectors.
352353
353354
"""
@@ -380,7 +381,7 @@ def _check_nums(num_name: str, num_val: int) -> None:
380381
raise ValueError(f"{num_name} must be at least 1")
381382

382383

383-
def _check_knot_bounds(num_knots: int, knot_bounds: np.ndarray) -> np.ndarray:
384+
def _check_knot_bounds(num_knots: int, knot_bounds: NDArray) -> NDArray:
384385
"""Check knot_bounds."""
385386
try:
386387
knot_bounds = np.asarray(knot_bounds, dtype=float)
@@ -399,7 +400,7 @@ def _check_knot_bounds(num_knots: int, knot_bounds: np.ndarray) -> np.ndarray:
399400
return knot_bounds
400401

401402

402-
def _check_min_dist(num_knots: int, min_dist: float | np.ndarray) -> np.ndarray:
403+
def _check_min_dist(num_knots: int, min_dist: float | NDArray) -> NDArray:
403404
"""Check knot min_dist."""
404405
if np.isscalar(min_dist):
405406
min_dist = np.tile(min_dist, num_knots + 1)
@@ -415,8 +416,8 @@ def _check_min_dist(num_knots: int, min_dist: float | np.ndarray) -> np.ndarray:
415416

416417

417418
def _check_feasibility(
418-
num_knots: int, knot_bounds: np.ndarray, min_dist: np.ndarray
419-
) -> tuple[np.ndarray, np.ndarray]:
419+
num_knots: int, knot_bounds: NDArray, min_dist: NDArray
420+
) -> tuple[NDArray, NDArray]:
420421
"""Check knot feasibility and get left and right boundaries."""
421422
if np.sum(min_dist) > knot_bounds[-1, 1] - knot_bounds[0, 0]:
422423
raise ValueError("min_dist cannot exceed knot_bounds")
@@ -561,7 +562,7 @@ def to_list(obj: Any) -> list[Any]:
561562
return [obj]
562563

563564

564-
def is_numeric_array(array: np.ndarray) -> bool:
565+
def is_numeric_array(array: NDArray) -> bool:
565566
"""Check if an array is numeric.
566567
567568
Parameters
@@ -590,8 +591,8 @@ def is_numeric_array(array: np.ndarray) -> bool:
590591

591592

592593
def expand_array(
593-
array: np.ndarray, shape: tuple[int], value: Any, name: str
594-
) -> np.ndarray:
594+
array: NDArray, shape: tuple[int], value: Any, name: str
595+
) -> NDArray:
595596
"""Expand array when it is empty.
596597
597598
Parameters
@@ -608,7 +609,7 @@ def expand_array(
608609
609610
Returns
610611
-------
611-
np.ndarray
612+
NDArray
612613
Expanded array.
613614
614615
"""
@@ -630,7 +631,7 @@ def expand_array(
630631
def ravel_dict(x: dict) -> dict:
631632
"""Ravel dictionary."""
632633
assert all([isinstance(k, str) for k in x.keys()])
633-
assert all([isinstance(v, np.ndarray) for v in x.values()])
634+
assert all([isinstance(v, NDArray) for v in x.values()])
634635
new_x = {}
635636
for k, v in x.items():
636637
if v.size == 1:

src/mrtool/cov_selection/covfinder.py

+3-3
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ def __init__(
3030
power_step_size: float = 0.5,
3131
inlier_pct: float = 1.0,
3232
alpha: float = 0.05,
33-
beta_gprior: dict[str, np.ndarray] = None,
33+
beta_gprior: dict[str, np.ndarray] | None = None,
3434
beta_gprior_std: float = 1.0,
3535
bias_zero: bool = False,
3636
use_re: dict | None = None,
@@ -106,7 +106,7 @@ def __init__(
106106
self.power_step_size = power_step_size
107107
self.powers = np.arange(*self.power_range, self.power_step_size)
108108

109-
self.num_covs = len(pre_selected_covs) + len(covs)
109+
self.num_covs = len(self.all_covs)
110110
if len(covs) == 0:
111111
warnings.warn(
112112
"There is no covariates to select, will return the pre-selected covariates."
@@ -117,7 +117,7 @@ def create_model(
117117
self,
118118
covs: list[str],
119119
prior_type: str = "Laplace",
120-
laplace_std: float = None,
120+
laplace_std: float | None = None,
121121
) -> MRBRT:
122122
"""Create Gaussian or Laplace model.
123123

0 commit comments

Comments
 (0)