Skip to content

Commit 8dccc93

Browse files
committed
remove all Union type hints
1 parent 07d9cc1 commit 8dccc93

File tree

4 files changed

+30
-36
lines changed

4 files changed

+30
-36
lines changed

src/mrtool/core/data.py

+22-24
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import warnings
1010
from dataclasses import dataclass, field
11-
from typing import Any, Union
11+
from typing import Any
1212

1313
import numpy as np
1414
import pandas as pd
@@ -208,9 +208,7 @@ def _assert_not_empty(self):
208208
if self.is_empty():
209209
raise ValueError("MRData object is empty.")
210210

211-
def is_cov_normalized(
212-
self, covs: Union[list[str], str, None] = None
213-
) -> bool:
211+
def is_cov_normalized(self, covs: list[str] | str | None = None) -> bool:
214212
"""Return true when covariates are normalized."""
215213
if covs is None:
216214
covs = list(self.covs.keys())
@@ -237,11 +235,11 @@ def reset(self):
237235
def load_df(
238236
self,
239237
data: pd.DataFrame,
240-
col_obs: Union[str, None] = None,
241-
col_obs_se: Union[str, None] = None,
242-
col_covs: Union[list[str], None] = None,
243-
col_study_id: Union[str, None] = None,
244-
col_data_id: Union[str, None] = None,
238+
col_obs: str | None = None,
239+
col_obs_se: str | None = None,
240+
col_covs: list[str] | None = None,
241+
col_study_id: str | None = None,
242+
col_data_id: str | None = None,
245243
):
246244
"""Load data from data frame."""
247245
self.reset()
@@ -273,10 +271,10 @@ def load_df(
273271
def load_xr(
274272
self,
275273
data,
276-
var_obs: Union[str, None] = None,
277-
var_obs_se: Union[str, None] = None,
278-
var_covs: Union[list[str], None] = None,
279-
coord_study_id: Union[str, None] = None,
274+
var_obs: str | None = None,
275+
var_obs_se: str | None = None,
276+
var_covs: list[str] | None = None,
277+
coord_study_id: str | None = None,
280278
):
281279
"""Load data from xarray."""
282280
self.reset()
@@ -314,11 +312,11 @@ def to_df(self) -> pd.DataFrame:
314312

315313
return df
316314

317-
def has_covs(self, covs: Union[list[str], str]) -> bool:
315+
def has_covs(self, covs: list[str] | str) -> bool:
318316
"""If the data has the provided covariates.
319317
320318
Args:
321-
covs (Union[list[str], str]):
319+
covs (list[str] | str):
322320
list of covariate names or one covariate name.
323321
324322
Returns:
@@ -330,11 +328,11 @@ def has_covs(self, covs: Union[list[str], str]) -> bool:
330328
else:
331329
return all([cov in self.covs for cov in covs])
332330

333-
def has_studies(self, studies: Union[list[Any], Any]) -> bool:
331+
def has_studies(self, studies: list[Any] | Any) -> bool:
334332
"""If the data has provided study_id
335333
336334
Args:
337-
studies Union[list[Any], Any]:
335+
studies list[Any] | Any:
338336
list of studies or one study.
339337
340338
Returns:
@@ -346,7 +344,7 @@ def has_studies(self, studies: Union[list[Any], Any]) -> bool:
346344
else:
347345
return all([study in self.studies for study in studies])
348346

349-
def _assert_has_covs(self, covs: Union[list[str], str]):
347+
def _assert_has_covs(self, covs: list[str] | str):
350348
"""Assert has covariates otherwise raise ValueError."""
351349
if not self.has_covs(covs):
352350
covs = to_list(covs)
@@ -355,7 +353,7 @@ def _assert_has_covs(self, covs: Union[list[str], str]):
355353
f"MRData object do not contain covariates: {missing_covs}."
356354
)
357355

358-
def _assert_has_studies(self, studies: Union[list[Any], Any]):
356+
def _assert_has_studies(self, studies: list[Any] | Any):
359357
"""Assert has studies otherwise raise ValueError."""
360358
if not self.has_studies(studies):
361359
studies = to_list(studies)
@@ -366,11 +364,11 @@ def _assert_has_studies(self, studies: Union[list[Any], Any]):
366364
f"MRData object do not contain studies: {missing_studies}."
367365
)
368366

369-
def get_covs(self, covs: Union[list[str], str]) -> np.ndarray:
367+
def get_covs(self, covs: list[str] | str) -> np.ndarray:
370368
"""Get covariate matrix.
371369
372370
Args:
373-
covs (Union[list[str], str]):
371+
covs (list[str] | str):
374372
list of covariate names or one covariate name.
375373
376374
Returns:
@@ -385,11 +383,11 @@ def get_covs(self, covs: Union[list[str], str]) -> np.ndarray:
385383
[self.covs[cov_names][:, None] for cov_names in covs]
386384
)
387385

388-
def get_study_data(self, studies: Union[list[Any], Any]) -> "MRData":
386+
def get_study_data(self, studies: list[Any] | Any) -> "MRData":
389387
"""Get study specific data.
390388
391389
Args:
392-
studies (Union[list[Any], Any]): list of studies or one study.
390+
studies (list[Any] | Any): list of studies or one study.
393391
394392
Returns
395393
MRData: Data object contains the study specific data.
@@ -399,7 +397,7 @@ def get_study_data(self, studies: Union[list[Any], Any]) -> "MRData":
399397
index = np.array([study in studies for study in self.study_id])
400398
return self._get_data(index)
401399

402-
def normalize_covs(self, covs: Union[list[str], str, None] = None):
400+
def normalize_covs(self, covs: list[str] | str | None = None):
403401
"""Normalize covariates by the largest absolute value for each covariate."""
404402
if covs is None:
405403
covs = list(self.covs.keys())

src/mrtool/core/model.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
"""
88

99
from copy import deepcopy
10-
from typing import Union
1110

1211
import numpy as np
1312
import pandas as pd
@@ -453,7 +452,7 @@ def __init__(
453452
data: MRData,
454453
ensemble_cov_model: CovModel,
455454
ensemble_knots: NDArray,
456-
cov_models: Union[list[CovModel], None] = None,
455+
cov_models: list[CovModel] | None = None,
457456
inlier_pct: float = 1.0,
458457
):
459458
"""Constructor of `MRBeRT`
@@ -462,7 +461,7 @@ def __init__(
462461
data (MRData): Data for meta-regression.
463462
ensemble_cov_model (CovModel):
464463
Covariates model which will be used with ensemble.
465-
cov_models (Union[list[CovModel], None], optional):
464+
cov_models (list[CovModel] | None, optional):
466465
Other covariate models, assume to be mutual exclusive with ensemble_cov_mdoel.
467466
inlier_pct (float): A float number between 0 and 1 indicate the percentage of inliers.
468467
"""

src/mrtool/core/utils.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
`utils` module of the `mrtool` package.
66
"""
77

8-
from typing import Any, Union
8+
from typing import Any
99

1010
import numpy as np
1111
import pandas as pd
@@ -294,7 +294,7 @@ def avg_integral(mat, spline=None, use_spline_intercept=False):
294294
def sample_knots(
295295
num_knots: int,
296296
knot_bounds: np.ndarray,
297-
min_dist: Union[float, np.ndarray],
297+
min_dist: float | np.ndarray,
298298
num_samples: int = 1,
299299
) -> np.ndarray:
300300
"""Sample knot vectors given a set of rules.
@@ -367,9 +367,7 @@ def _check_knot_bounds(num_knots: int, knot_bounds: np.ndarray) -> np.ndarray:
367367
return knot_bounds
368368

369369

370-
def _check_min_dist(
371-
num_knots: int, min_dist: Union[float, np.ndarray]
372-
) -> np.ndarray:
370+
def _check_min_dist(num_knots: int, min_dist: float | np.ndarray) -> np.ndarray:
373371
"""Check knot min_dist."""
374372
if np.isscalar(min_dist):
375373
min_dist = np.tile(min_dist, num_knots + 1)

src/mrtool/cov_selection/covfinder.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import warnings
88
from copy import deepcopy
9-
from typing import Union
109

1110
import numpy as np
1211

@@ -23,7 +22,7 @@ def __init__(
2322
self,
2423
data: MRData,
2524
covs: list[str],
26-
pre_selected_covs: Union[list[str], None] = None,
25+
pre_selected_covs: list[str] | None = None,
2726
normalized_covs: bool = True,
2827
num_samples: int = 1000,
2928
laplace_threshold: float = 1e-5,
@@ -34,7 +33,7 @@ def __init__(
3433
beta_gprior: dict[str, np.ndarray] = None,
3534
beta_gprior_std: float = 1.0,
3635
bias_zero: bool = False,
37-
use_re: Union[dict, None] = None,
36+
use_re: dict | None = None,
3837
):
3938
"""Covariate Finder.
4039
@@ -59,7 +58,7 @@ def __init__(
5958
beta_gprior_std (float, optional): Loose beta Gaussian prior standard deviation. Default to 1.
6059
bias_zero (bool, optional):
6160
If `True`, fit when specify the Gaussian prior it will be mean zero. Default to `False`.
62-
use_re (Union[dict, None], optional):
61+
use_re (dict | None, optional):
6362
A dictionary of use_re for each covariate. When `None` we have an uninformative prior
6463
for the random effects variance. Default to `None`.
6564
"""

0 commit comments

Comments
 (0)