8
8
9
9
import warnings
10
10
from dataclasses import dataclass , field
11
- from typing import Any , Union
11
+ from typing import Any
12
12
13
13
import numpy as np
14
14
import pandas as pd
@@ -208,9 +208,7 @@ def _assert_not_empty(self):
208
208
if self .is_empty ():
209
209
raise ValueError ("MRData object is empty." )
210
210
211
- def is_cov_normalized (
212
- self , covs : Union [list [str ], str , None ] = None
213
- ) -> bool :
211
+ def is_cov_normalized (self , covs : list [str ] | str | None = None ) -> bool :
214
212
"""Return true when covariates are normalized."""
215
213
if covs is None :
216
214
covs = list (self .covs .keys ())
@@ -237,11 +235,11 @@ def reset(self):
237
235
def load_df (
238
236
self ,
239
237
data : pd .DataFrame ,
240
- col_obs : Union [ str , None ] = None ,
241
- col_obs_se : Union [ str , None ] = None ,
242
- col_covs : Union [ list [str ], None ] = None ,
243
- col_study_id : Union [ str , None ] = None ,
244
- col_data_id : Union [ str , None ] = None ,
238
+ col_obs : str | None = None ,
239
+ col_obs_se : str | None = None ,
240
+ col_covs : list [str ] | None = None ,
241
+ col_study_id : str | None = None ,
242
+ col_data_id : str | None = None ,
245
243
):
246
244
"""Load data from data frame."""
247
245
self .reset ()
@@ -273,10 +271,10 @@ def load_df(
273
271
def load_xr (
274
272
self ,
275
273
data ,
276
- var_obs : Union [ str , None ] = None ,
277
- var_obs_se : Union [ str , None ] = None ,
278
- var_covs : Union [ list [str ], None ] = None ,
279
- coord_study_id : Union [ str , None ] = None ,
274
+ var_obs : str | None = None ,
275
+ var_obs_se : str | None = None ,
276
+ var_covs : list [str ] | None = None ,
277
+ coord_study_id : str | None = None ,
280
278
):
281
279
"""Load data from xarray."""
282
280
self .reset ()
@@ -314,11 +312,11 @@ def to_df(self) -> pd.DataFrame:
314
312
315
313
return df
316
314
317
- def has_covs (self , covs : Union [ list [str ], str ] ) -> bool :
315
+ def has_covs (self , covs : list [str ] | str ) -> bool :
318
316
"""If the data has the provided covariates.
319
317
320
318
Args:
321
- covs (Union[ list[str], str] ):
319
+ covs (list[str] | str):
322
320
list of covariate names or one covariate name.
323
321
324
322
Returns:
@@ -330,11 +328,11 @@ def has_covs(self, covs: Union[list[str], str]) -> bool:
330
328
else :
331
329
return all ([cov in self .covs for cov in covs ])
332
330
333
- def has_studies (self , studies : Union [ list [Any ], Any ] ) -> bool :
331
+ def has_studies (self , studies : list [Any ] | Any ) -> bool :
334
332
"""If the data has provided study_id
335
333
336
334
Args:
337
- studies Union[ list[Any], Any] :
335
+ studies list[Any] | Any:
338
336
list of studies or one study.
339
337
340
338
Returns:
@@ -346,7 +344,7 @@ def has_studies(self, studies: Union[list[Any], Any]) -> bool:
346
344
else :
347
345
return all ([study in self .studies for study in studies ])
348
346
349
- def _assert_has_covs (self , covs : Union [ list [str ], str ] ):
347
+ def _assert_has_covs (self , covs : list [str ] | str ):
350
348
"""Assert has covariates otherwise raise ValueError."""
351
349
if not self .has_covs (covs ):
352
350
covs = to_list (covs )
@@ -355,7 +353,7 @@ def _assert_has_covs(self, covs: Union[list[str], str]):
355
353
f"MRData object do not contain covariates: { missing_covs } ."
356
354
)
357
355
358
- def _assert_has_studies (self , studies : Union [ list [Any ], Any ] ):
356
+ def _assert_has_studies (self , studies : list [Any ] | Any ):
359
357
"""Assert has studies otherwise raise ValueError."""
360
358
if not self .has_studies (studies ):
361
359
studies = to_list (studies )
@@ -366,11 +364,11 @@ def _assert_has_studies(self, studies: Union[list[Any], Any]):
366
364
f"MRData object do not contain studies: { missing_studies } ."
367
365
)
368
366
369
- def get_covs (self , covs : Union [ list [str ], str ] ) -> np .ndarray :
367
+ def get_covs (self , covs : list [str ] | str ) -> np .ndarray :
370
368
"""Get covariate matrix.
371
369
372
370
Args:
373
- covs (Union[ list[str], str] ):
371
+ covs (list[str] | str):
374
372
list of covariate names or one covariate name.
375
373
376
374
Returns:
@@ -385,11 +383,11 @@ def get_covs(self, covs: Union[list[str], str]) -> np.ndarray:
385
383
[self .covs [cov_names ][:, None ] for cov_names in covs ]
386
384
)
387
385
388
- def get_study_data (self , studies : Union [ list [Any ], Any ] ) -> "MRData" :
386
+ def get_study_data (self , studies : list [Any ] | Any ) -> "MRData" :
389
387
"""Get study specific data.
390
388
391
389
Args:
392
- studies (Union[ list[Any], Any] ): list of studies or one study.
390
+ studies (list[Any] | Any): list of studies or one study.
393
391
394
392
Returns
395
393
MRData: Data object contains the study specific data.
@@ -399,7 +397,7 @@ def get_study_data(self, studies: Union[list[Any], Any]) -> "MRData":
399
397
index = np .array ([study in studies for study in self .study_id ])
400
398
return self ._get_data (index )
401
399
402
- def normalize_covs (self , covs : Union [ list [str ], str , None ] = None ):
400
+ def normalize_covs (self , covs : list [str ] | str | None = None ):
403
401
"""Normalize covariates by the largest absolute value for each covariate."""
404
402
if covs is None :
405
403
covs = list (self .covs .keys ())
0 commit comments