Skip to content

Commit

Permalink
Fix mask pass
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Jul 2, 2024
1 parent d5ba28c commit edd5841
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
6 changes: 4 additions & 2 deletions supirfactor_dynamical/datasets/stratified_file_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def load_file(
_H5ADFileLoader.load_layer(
file_handle,
_elayer,
obs_include_mask
obs_include_mask,
feature_mask=feature_mask
)
for _elayer in extra_layers
]
Expand Down Expand Up @@ -366,7 +367,8 @@ def __init__(
layer=file_data_layer,
extra_layers=yield_extra_layers,
append_obs=True,
obs_include_mask=obs_include_mask
obs_include_mask=obs_include_mask,
feature_mask=feature_mask
)
self.yields_tuple = len(self.loaded_data) > 1

Expand Down
46 changes: 46 additions & 0 deletions supirfactor_dynamical/tests/test_loader_h5ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,52 @@ def test_load_h5(self):
[3, 3, 3, 3]
)

def test_combine_class(self):

dataset = StratifySingleFileDataset(
self.filename,
['strat'],
yield_obs_cats=['strat'],
combine_categories={'D': 'C'},
random_state=10
)

self.assertEqual(
len(dataset.loaded_data),
2
)

self.assertEqual(
len(dataset.loaded_data[0]),
100
)

self.assertEqual(
[32, 32, 36],
list(map(len, dataset.stratification_group_indexes))
)

_classes = []

for i, (v, c) in enumerate(dataset):
self.assertEqual(
v.shape,
(4,)
)
self.assertEqual(
c.shape,
(4,)
)
_classes.append(c)

_classes = np.vstack(_classes).sum(0).astype(int)
self.assertEqual(i, 95)

npt.assert_equal(
_classes,
[32, 32, 29, 3]
)

def test_h5_mask(self):

dataset = StratifySingleFileDataset(
Expand Down

0 comments on commit edd5841

Please sign in to comment.