From edd5841a3d3d902233d21ee8e9e1426b9f452c9a Mon Sep 17 00:00:00 2001 From: asistradition Date: Tue, 2 Jul 2024 11:02:41 -0400 Subject: [PATCH] Fix mask pass --- .../datasets/stratified_file_dataset.py | 6 ++- .../tests/test_loader_h5ad.py | 46 +++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/supirfactor_dynamical/datasets/stratified_file_dataset.py b/supirfactor_dynamical/datasets/stratified_file_dataset.py index 41ed408..5e8d310 100644 --- a/supirfactor_dynamical/datasets/stratified_file_dataset.py +++ b/supirfactor_dynamical/datasets/stratified_file_dataset.py @@ -67,7 +67,8 @@ def load_file( _H5ADFileLoader.load_layer( file_handle, _elayer, - obs_include_mask + obs_include_mask, + feature_mask=feature_mask ) for _elayer in extra_layers ] @@ -366,7 +367,8 @@ def __init__( layer=file_data_layer, extra_layers=yield_extra_layers, append_obs=True, - obs_include_mask=obs_include_mask + obs_include_mask=obs_include_mask, + feature_mask=feature_mask ) self.yields_tuple = len(self.loaded_data) > 1 diff --git a/supirfactor_dynamical/tests/test_loader_h5ad.py b/supirfactor_dynamical/tests/test_loader_h5ad.py index 6fff91a..c3b70bc 100644 --- a/supirfactor_dynamical/tests/test_loader_h5ad.py +++ b/supirfactor_dynamical/tests/test_loader_h5ad.py @@ -789,6 +789,52 @@ def test_load_h5(self): [3, 3, 3, 3] ) + def test_combine_class(self): + + dataset = StratifySingleFileDataset( + self.filename, + ['strat'], + yield_obs_cats=['strat'], + combine_categories={'D': 'C'}, + random_state=10 + ) + + self.assertEqual( + len(dataset.loaded_data), + 2 + ) + + self.assertEqual( + len(dataset.loaded_data[0]), + 100 + ) + + self.assertEqual( + [32, 32, 36], + list(map(len, dataset.stratification_group_indexes)) + ) + + _classes = [] + + for i, (v, c) in enumerate(dataset): + self.assertEqual( + v.shape, + (4,) + ) + self.assertEqual( + c.shape, + (4,) + ) + _classes.append(c) + + _classes = np.vstack(_classes).sum(0).astype(int) + self.assertEqual(i, 95) + + npt.assert_equal( + _classes, + [32, 32, 29, 3] + ) + def test_h5_mask(self): dataset = StratifySingleFileDataset(