diff --git a/src/arcade_collection/__main__.py b/src/arcade_collection/__main__.py index c552b21..e69de29 100644 --- a/src/arcade_collection/__main__.py +++ b/src/arcade_collection/__main__.py @@ -1,2 +0,0 @@ -if __name__ == "__main__": - print("hello world") diff --git a/src/arcade_collection/input/convert_to_cells_file.py b/src/arcade_collection/input/convert_to_cells_file.py index b3261d3..5c75155 100644 --- a/src/arcade_collection/input/convert_to_cells_file.py +++ b/src/arcade_collection/input/convert_to_cells_file.py @@ -10,6 +10,44 @@ def convert_to_cells_file( critical_height_distributions: dict[str, tuple[float, float]], state_thresholds: dict[str, float], ) -> list[dict]: + """ + Convert all samples to cell objects. + + For each cell id in samples, current volume and height are rescaled to + critical volume and critical height based on distribution means and standard + deviations. If reference volume and/or height exist for the cell id, those + values are used as the current values to be rescaled. Otherwise, current + volume is calculated from the number of voxel samples and current height is + calculated from the range of voxel coordinates along the z axis. + + Initial cell state and cell state phase for each cell are estimated based on + state thresholds, the current cell volume, and the critical cell volume. + + Cell object ids are reindexed starting with cell id 1. + + Parameters + ---------- + samples + Sample cell ids and coordinates. + reference + Reference values for volumes and heights. + volume_distributions + Map of volume means and standard deviations. + height_distributions + Map of height means and standard deviations. + critical_volume_distributions + Map of critical volume means and standard deviations. + critical_height_distributions + Map of critical height means and standard deviations. + state_thresholds + Critical volume fractions defining threshold between states. + + Returns + ------- + : + List of cell objects formatted for ARCADE. + """ + cells: list[dict] = [] samples_by_id = samples.groupby("id") @@ -42,30 +80,43 @@ def convert_to_cell( state_thresholds: dict[str, float], ) -> dict: """ - Convert samples to ARCADE .CELLS json format. + Convert samples to cell object. + + Current volume and height are rescaled to critical volume and critical + height based on distribution means and standard deviations. If reference + volume and/or height are provided (under the "DEFAULT" key), those values + are used as the current values to be rescaled. Otherwise, current volume is + calculated from the number of voxel samples and current height is calculated + from the range of voxel coordinates along the z axis. + + Initial cell state and cell state phase are estimated based on state + thresholds, the current cell volume, and the critical cell volume. Parameters ---------- cell_id Unique cell id. samples - Sample cell ids and coordinates. + Sample coordinates for a single object. reference - Reference data for conversion. - volume_distribution - Average and standard deviation of volume distributions. - height_distribution - Average and standard deviation of height distributions. - critical_volume_distribution - Average and standard deviation of critical volume distributions. - critical_height_distribution - Average and standard deviation of critical height distributions. + Reference data for cell. + volume_distributions + Map of volume means and standard deviations. + height_distributions + Map of height means and standard deviations. 
+ critical_volume_distributions + Map of critical volume means and standard deviations. + critical_height_distributions + Map of critical height means and standard deviations. + state_thresholds + Critical volume fractions defining threshold between states. Returns ------- : - Dictionary in ARCADE .CELLS json format. + Cell object formatted for ARCADE. """ + volume = len(samples) height = samples.z.max() - samples.z.min() @@ -95,7 +146,7 @@ def convert_to_cell( "criticals": [critical_volume, critical_height], } - if "region" in samples.columns: + if "region" in samples.columns and not samples["region"].isnull().all(): regions = [ convert_to_cell_region( region, @@ -122,6 +173,39 @@ def convert_to_cell_region( critical_volume_distributions: dict[str, tuple[float, float]], critical_height_distributions: dict[str, tuple[float, float]], ) -> dict: + """ + Convert region samples to cell region object. + + Current region volume and height are rescaled to critical volume and + critical height based on distribution means and standard deviations. If + reference region volume and/or height are provided, those values are used as + the current values to be rescaled. Otherwise, current region volume is + calculated from the number of voxel samples and current region height is + calculated from the range of voxel coordinates along the z axis. + + Parameters + ---------- + region + Region name. + region_samples + Sample coordinates for region of a single object. + reference + Reference data for cell region. + volume_distributions + Map of volume means and standard deviations. + height_distributions + Map of height means and standard deviations. + critical_volume_distributions + Map of critical volume means and standard deviations. + critical_height_distributions + Map of critical height means and standard deviations. + + Returns + ------- + : + Cell region object formatted for ARCADE. + """ + region_volume = len(region_samples) region_height = region_samples.z.max() - region_samples.z.min() @@ -152,12 +236,14 @@ def get_cell_state( """ Estimates cell state based on cell volume. - The threshold fractions dictionary defines the monotonic thresholds - between different cell states. - For a given volume v, critical volume V, and states X1, X2, ..., XN with - corresponding, monotonic threshold fractions f1, f2, ..., fN, a cell is - assigned state Xi such that [f(i - 1) * V] <= v < [fi * V]. + The threshold fractions dictionary defines the monotonic thresholds between + different cell states. For a given volume v, critical volume V, and states + X1, X2, ..., XN with corresponding, monotonic threshold fractions f1, f2, + ..., fN, a cell is assigned state Xi such that [f(i - 1) * V] <= v < [fi * + V]. + Cells with v < f1 * V are assigned state X1. + Cells with v > fN * V are assigned state XN. Parameters @@ -174,6 +260,7 @@ def get_cell_state( : Cell state. """ + thresholds = [fraction * critical_volume for fraction in threshold_fractions.values()] states = list(threshold_fractions.keys()) @@ -203,6 +290,7 @@ def convert_value_distribution( : Estimated critical value. """ + source_avg, source_std = source_distribution target_avg, target_std = target_distribution z_scored_value = (value - source_avg) / source_std @@ -226,6 +314,7 @@ def filter_cell_reference(cell_id: int, reference: pd.DataFrame) -> dict: : Reference data for given cell id. 
""" + cell_reference = reference[reference["ID"] == cell_id].squeeze() cell_reference = cell_reference.to_dict() if not cell_reference.empty else {} return cell_reference diff --git a/src/arcade_collection/input/convert_to_locations_file.py b/src/arcade_collection/input/convert_to_locations_file.py index f31d0b9..acff995 100644 --- a/src/arcade_collection/input/convert_to_locations_file.py +++ b/src/arcade_collection/input/convert_to_locations_file.py @@ -4,6 +4,20 @@ def convert_to_locations_file(samples: pd.DataFrame) -> list[dict]: + """ + Convert all samples to location objects. + + Parameters + ---------- + samples + Sample cell ids and coordinates. + + Returns + ------- + : + List of location objects formatted for ARCADE. + """ + locations: list[dict] = [] samples_by_id = samples.groupby("id") @@ -15,23 +29,24 @@ def convert_to_locations_file(samples: pd.DataFrame) -> list[dict]: def convert_to_location(cell_id: int, samples: pd.DataFrame) -> dict: """ - Convert samples to ARCADE .LOCATIONS json format. + Convert samples to location object. Parameters ---------- cell_id Unique cell id. samples - Sample cell ids and coordinates. + Sample coordinates for a single object. Returns ------- : - Dictionary in ARCADE .LOCATIONS json format. + Location object formatted for ARCADE. """ + center = get_center_voxel(samples) - if "region" in samples.columns: + if "region" in samples.columns and not samples["region"].isnull().all(): voxels = [ {"region": region, "voxels": get_location_voxels(samples, region)} for region in samples["region"].unique() @@ -44,6 +59,7 @@ def convert_to_location(cell_id: int, samples: pd.DataFrame) -> dict: "center": center, "location": voxels, } + return location @@ -61,6 +77,7 @@ def get_center_voxel(samples: pd.DataFrame) -> tuple[int, int, int]: : Center voxel. """ + center_x = int(samples["x"].mean()) center_y = int(samples["y"].mean()) center_z = int(samples["z"].mean()) @@ -86,6 +103,7 @@ def get_location_voxels( : List of voxel coordinates. """ + if region is not None: region_samples = samples[samples["region"] == region] voxels_x = region_samples["x"] diff --git a/src/arcade_collection/input/generate_setup_file.py b/src/arcade_collection/input/generate_setup_file.py index 09153a0..9f54810 100644 --- a/src/arcade_collection/input/generate_setup_file.py +++ b/src/arcade_collection/input/generate_setup_file.py @@ -4,14 +4,42 @@ import numpy as np import pandas as pd +DEFAULT_POPULATION_ID = "X" +"""Default population ID used in setup file.""" + def generate_setup_file( - samples: pd.DataFrame, margins: tuple[int, int, int], potts_terms: list[str] + samples: pd.DataFrame, margins: tuple[int, int, int], terms: list[str] ) -> str: + """ + Create ARCADE setup file from samples, margins, and CPM Hamiltonian terms. + + Initial number of cells is determined by number of unique ids in samples. + Regions are included if samples contains valid regions. + + Parameters + ---------- + samples + Sample cell ids and coordinates. + margins + Margin size in x, y, and z directions. + terms + List of Potts Hamiltonian terms for setup file. + + Returns + ------- + : + Contents of ARCADE setup file. 
+ """ + init = len(samples["id"].unique()) bounds = calculate_sample_bounds(samples, margins) - regions = samples["regions"].unique() if "regions" in samples else None - setup = make_setup_file(init, bounds, potts_terms, regions) + regions = ( + samples["region"].unique() + if "region" in samples.columns and not samples["region"].isnull().all() + else None + ) + setup = make_setup_file(init, bounds, terms, regions) return setup @@ -33,6 +61,7 @@ def calculate_sample_bounds( : Bounds in x, y, and z directions. """ + mins = (min(samples.x), min(samples.y), min(samples.z)) maxs = (max(samples.x), max(samples.y), max(samples.z)) @@ -64,8 +93,10 @@ def make_setup_file( Returns ------- + : Contents of ARCADE setup file. """ + root = ET.fromstring("") series = ET.SubElement( root, diff --git a/src/arcade_collection/input/group_template_conditions.py b/src/arcade_collection/input/group_template_conditions.py index 78c6219..3f68681 100644 --- a/src/arcade_collection/input/group_template_conditions.py +++ b/src/arcade_collection/input/group_template_conditions.py @@ -2,6 +2,22 @@ def group_template_conditions(conditions: list[dict], max_seeds: int) -> list[dict]: + """ + Create conditions groups obeying specified max seeds for each group. + + Parameters + ---------- + conditions + List of conditions, containing a unique "key" and "seed". + max_seeds + Maximum number of total seeds in each group. + + Returns + ------- + : + List of condition groups. + """ + grouped_conditions = group_seed_ranges(conditions, max_seeds) condition_sets = group_condition_sets(grouped_conditions, max_seeds) template_conditions = [{"conditions": condition_set} for condition_set in condition_sets] @@ -9,6 +25,22 @@ def group_template_conditions(conditions: list[dict], max_seeds: int) -> list[di def group_seed_ranges(conditions: list[dict], max_seeds: int) -> list[dict]: + """ + Group conditions by continuous seed ranges. + + Parameters + ---------- + conditions + List of conditions, containing a unique "key" and "seed". + max_seeds + Maximum number of seeds in a single range. + + Returns + ------- + : + List of conditions updated with "start_seed" and "end_seed" ranges. + """ + conditions.sort(key=lambda x: (x["key"], x["seed"])) grouped_conditions = [] @@ -33,6 +65,22 @@ def group_seed_ranges(conditions: list[dict], max_seeds: int) -> list[dict]: def find_seed_ranges(seeds: list[int], max_seeds: int) -> list[tuple[int, int]]: + """ + Find continuous seed ranges, with range no larger than specified max seeds. + + Parameters + ---------- + seeds + List of seeds. + max_seeds + Maximum number of seeds in a single range. + + Returns + ------- + : + List of seeds grouped into ranges. + """ + seeds.sort() ranges = [] @@ -52,6 +100,23 @@ def find_seed_ranges(seeds: list[int], max_seeds: int) -> list[tuple[int, int]]: def group_condition_sets(conditions: list[dict], max_seeds: int) -> list[list[dict]]: + """ + Group conditions, with total seeds no larger than specified max seeds. + + Parameters + ---------- + conditions + List of conditions, containing a unique "key" with "start_seed" and + "end_seed" ranges. + max_seeds + Maximum number of seeds in a single group. + + Returns + ------- + : + List of groups of conditions. 
+ """ + seed_count = 0 condition_set = [] condition_sets = [] diff --git a/src/arcade_collection/input/merge_region_samples.py b/src/arcade_collection/input/merge_region_samples.py index f230b0d..0d72950 100644 --- a/src/arcade_collection/input/merge_region_samples.py +++ b/src/arcade_collection/input/merge_region_samples.py @@ -7,14 +7,62 @@ def merge_region_samples( samples: dict[str, pd.DataFrame], margins: tuple[int, int, int] ) -> pd.DataFrame: + """ + Merge different region samples into single valid samples dataframe. + + The input samples are formatted as: + + .. code-block:: python + + { + "DEFAULT": (dataframe with columns = id, x, y, z), + "": (dataframe with columns = id, x, y, z), + "": (dataframe with columns = id, x, y, z), + ... + } + + The DEFAULT region is used as the superset of (x, y, z) samples; any sample + found only in a non-DEFAULT region are ignored. For a given id, there must + be at least one sample in each region. + + The output samples are formatted as: + + .. code-block:: markdown + + ┍━━━━━━┯━━━━━━━━━━┯━━━━━━━━━━┯━━━━━━━━━━┯━━━━━━━━━━┑ + │ id │ x │ y │ z │ region │ + ┝━━━━━━┿━━━━━━━━━━┿━━━━━━━━━━┿━━━━━━━━━━┿━━━━━━━━━━┥ + │ │ DEFAULT │ + │ │ + │ ... │ ... │ ... │ ... │ ... │ + │ │ + ┕━━━━━━┷━━━━━━━━━━┷━━━━━━━━━━┷━━━━━━━━━━┷━━━━━━━━━━┙ + + Samples that are found in the DEFAULT region, but not in any non-DEFAULT + region are marked as DEFAULT. Otherwise, the sample is marked with the + corresponding region. Region samples should be mutually exclusive. + + Parameters + ---------- + samples + Map of region names to region samples. + margins + Margin in the x, y, and z directions applied to sample locations. + + Returns + ------- + : + Dataframe of merged samples with applied margins. + """ + default_samples = samples["DEFAULT"] - all_samples = tranform_sample_coordinates(default_samples, margins) + all_samples = transform_sample_coordinates(default_samples, margins) regions = [key for key in samples.keys() if key != "DEFAULT"] all_region_samples = [] for region in regions: - region_samples = tranform_sample_coordinates(samples[region], margins, default_samples) + region_samples = transform_sample_coordinates(samples[region], margins, default_samples) region_samples["region"] = region all_region_samples.append(region_samples) @@ -29,7 +77,7 @@ def merge_region_samples( return valid_samples -def tranform_sample_coordinates( +def transform_sample_coordinates( samples: pd.DataFrame, margins: tuple[int, int, int], reference: Optional[pd.DataFrame] = None, @@ -51,6 +99,7 @@ def tranform_sample_coordinates( : Transformed sample cell ids and coordinates. """ + if reference is None: reference = samples @@ -84,6 +133,7 @@ def filter_valid_samples(samples: pd.DataFrame) -> pd.DataFrame: : Valid sample cell ids and coordinates. 
""" + if "region" in samples.columns: num_regions = len(samples.region.unique()) samples = samples.groupby("id").filter(lambda x: len(x.region.unique()) == num_regions) diff --git a/tests/arcade_collection/input/test_convert_to_cells_file.py b/tests/arcade_collection/input/test_convert_to_cells_file.py new file mode 100644 index 0000000..f029d2a --- /dev/null +++ b/tests/arcade_collection/input/test_convert_to_cells_file.py @@ -0,0 +1,442 @@ +import unittest + +import numpy as np +import pandas as pd + +from arcade_collection.input.convert_to_cells_file import ( + convert_to_cell, + convert_to_cell_region, + convert_to_cells_file, + convert_value_distribution, + filter_cell_reference, + get_cell_state, +) + +EPSILON = 1e-10 +DEFAULT_REGION_NAME = "DEFAULT" + + +def make_samples(cell_id, volume, height, region): + return pd.DataFrame( + { + "id": [cell_id] * volume, + "z": np.linspace(0, height, volume), + "region": [region] * volume, + } + ) + + +class TestConvertToCellsFile(unittest.TestCase): + def setUp(self): + self.volume_distributions = { + DEFAULT_REGION_NAME: (10, 10), + "REGION1": (10, 5), + "REGION2": (25, 10), + } + self.height_distributions = { + DEFAULT_REGION_NAME: (5, 5), + "REGION1": (4, 1), + "REGION2": (10, 5), + } + self.critical_volume_distributions = { + DEFAULT_REGION_NAME: (2, 2), + "REGION1": (2, 4), + "REGION2": (6, 1), + } + self.critical_height_distributions = { + DEFAULT_REGION_NAME: (10, 10), + "REGION1": (2, 3), + "REGION2": (8, 4), + } + self.state_thresholds = { + "STATE1_PHASE1": 0.2, + "STATE1_PHASE2": 1.5, + "STATE2_PHASE1": 2, + } + + self.reference = { + "volume": 10, + "height": 10, + "volume.REGION1": 10, + "height.REGION1": 4, + "volume.REGION2": 15, + "height.REGION2": 10, + } + + def test_convert_to_cells_file(self): + cell_ids = [10, 11, 12, 13] + volumes = [[20], [20], [15, 5], [40, 10]] + heights = [5, 5, 5, 20] + + samples = pd.concat( + [ + ( + make_samples(cell_id, volume, height, f"REGION{index + 1}") + if len(region_volumes) > 1 + else make_samples(cell_id, volume, height, None) + ) + for cell_id, height, region_volumes in zip(cell_ids, heights, volumes) + for index, volume in enumerate(region_volumes) + ] + ) + + reference = pd.DataFrame( + { + "ID": [11, 13], + "volume": [10, 10], + "height": [10, 10], + "volume.REGION1": [None, 10], + "height.REGION1": [None, 4], + "volume.REGION2": [None, 15], + "height.REGION2": [None, 10], + } + ) + + expected_cells = [ + { + "id": 1, + "parent": 0, + "pop": 1, + "age": 0, + "divisions": 0, + "state": "STATE2", + "phase": "STATE2_PHASE1", + "voxels": np.sum(volumes[0]), + "criticals": [4, 10], + }, + { + "id": 2, + "parent": 0, + "pop": 1, + "age": 0, + "divisions": 0, + "state": "STATE2", + "phase": "STATE2_PHASE1", + "voxels": np.sum(volumes[1]), + "criticals": [2, 20], + }, + { + "id": 3, + "parent": 0, + "pop": 1, + "age": 0, + "divisions": 0, + "state": "STATE2", + "phase": "STATE2_PHASE1", + "voxels": np.sum(volumes[2]), + "criticals": [4, 10], + "regions": [ + {"region": "REGION1", "voxels": volumes[2][0], "criticals": [6, 5]}, + {"region": "REGION2", "voxels": volumes[2][1], "criticals": [4, 4]}, + ], + }, + { + "id": 4, + "parent": 0, + "pop": 1, + "age": 0, + "divisions": 0, + "state": "STATE2", + "phase": "STATE2_PHASE1", + "voxels": np.sum(volumes[3]), + "criticals": [2, 20], + "regions": [ + {"region": "REGION1", "voxels": volumes[3][0], "criticals": [2, 2]}, + {"region": "REGION2", "voxels": volumes[3][1], "criticals": [5, 8]}, + ], + }, + ] + + cells = convert_to_cells_file( + samples, 
+ reference, + self.volume_distributions, + self.height_distributions, + self.critical_volume_distributions, + self.critical_height_distributions, + self.state_thresholds, + ) + + self.assertCountEqual(expected_cells, cells) + + def test_convert_to_cell_no_reference_no_region(self): + cell_id = 2 + volume = 20 + height = 5 + samples = make_samples(cell_id, volume, height, None) + reference = {} + + expected_cell = { + "id": cell_id, + "parent": 0, + "pop": 1, + "age": 0, + "divisions": 0, + "state": "STATE2", + "phase": "STATE2_PHASE1", + "voxels": volume, + "criticals": [4, 10], + } + + cell = convert_to_cell( + cell_id, + samples, + reference, + self.volume_distributions, + self.height_distributions, + self.critical_volume_distributions, + self.critical_height_distributions, + self.state_thresholds, + ) + + self.assertDictEqual(expected_cell, cell) + + def test_convert_to_cell_with_reference_no_region(self): + cell_id = 2 + volume = 20 + height = 5 + samples = make_samples(cell_id, volume, height, None) + reference = {"volume": 10, "height": 10} + + expected_cell = { + "id": cell_id, + "parent": 0, + "pop": 1, + "age": 0, + "divisions": 0, + "state": "STATE2", + "phase": "STATE2_PHASE1", + "voxels": volume, + "criticals": [2, 20], + } + + cell = convert_to_cell( + cell_id, + samples, + reference, + self.volume_distributions, + self.height_distributions, + self.critical_volume_distributions, + self.critical_height_distributions, + self.state_thresholds, + ) + + self.assertDictEqual(expected_cell, cell) + + def test_convert_to_cell_no_reference_with_region(self): + cell_id = 2 + volumes = [15, 5] + height = 5 + samples = pd.concat( + [ + make_samples(cell_id, volume, height, f"REGION{index + 1}") + for index, volume in enumerate(volumes) + ] + ) + reference = {} + + expected_cell = { + "id": cell_id, + "parent": 0, + "pop": 1, + "age": 0, + "divisions": 0, + "state": "STATE2", + "phase": "STATE2_PHASE1", + "voxels": np.sum(volumes), + "criticals": [4, 10], + "regions": [ + {"region": "REGION1", "voxels": volumes[0], "criticals": [6, 5]}, + {"region": "REGION2", "voxels": volumes[1], "criticals": [4, 4]}, + ], + } + + cell = convert_to_cell( + cell_id, + samples, + reference, + self.volume_distributions, + self.height_distributions, + self.critical_volume_distributions, + self.critical_height_distributions, + self.state_thresholds, + ) + + self.assertDictEqual(expected_cell, cell) + + def test_convert_to_cell_with_reference_with_region(self): + cell_id = 2 + volumes = [15, 5] + height = 5 + samples = pd.concat( + [ + make_samples(cell_id, volume, height, f"REGION{index + 1}") + for index, volume in enumerate(volumes) + ] + ) + reference = { + "volume": 10, + "height": 10, + "volume.REGION1": 10, + "height.REGION1": 4, + "volume.REGION2": 15, + "height.REGION2": 10, + } + + expected_cell = { + "id": cell_id, + "parent": 0, + "pop": 1, + "age": 0, + "divisions": 0, + "state": "STATE2", + "phase": "STATE2_PHASE1", + "voxels": np.sum(volumes), + "criticals": [2, 20], + "regions": [ + {"region": "REGION1", "voxels": volumes[0], "criticals": [2, 2]}, + {"region": "REGION2", "voxels": volumes[1], "criticals": [5, 8]}, + ], + } + + cell = convert_to_cell( + cell_id, + samples, + reference, + self.volume_distributions, + self.height_distributions, + self.critical_volume_distributions, + self.critical_height_distributions, + self.state_thresholds, + ) + + self.assertDictEqual(expected_cell, cell) + + def test_convert_to_cell_region_no_reference(self): + volume = 20 + height = 5 + region_samples = 
pd.DataFrame({"z": np.linspace(0, height, volume)}) + reference = {} + + expected_cell_region = { + "region": DEFAULT_REGION_NAME, + "voxels": volume, + "criticals": [4, 10], + } + + cell_region = convert_to_cell_region( + DEFAULT_REGION_NAME, + region_samples, + reference, + self.volume_distributions, + self.height_distributions, + self.critical_volume_distributions, + self.critical_height_distributions, + ) + + self.assertDictEqual(expected_cell_region, cell_region) + + def test_convert_to_cell_region_with_reference(self): + volume = 20 + height = 5 + region_samples = pd.DataFrame({"z": np.linspace(0, height, volume)}) + reference = {f"volume.{DEFAULT_REGION_NAME}": 10, f"height.{DEFAULT_REGION_NAME}": 10} + + expected_cell_region = { + "region": DEFAULT_REGION_NAME, + "voxels": volume, + "criticals": [2, 20], + } + + cell_region = convert_to_cell_region( + DEFAULT_REGION_NAME, + region_samples, + reference, + self.volume_distributions, + self.height_distributions, + self.critical_volume_distributions, + self.critical_height_distributions, + ) + + self.assertDictEqual(expected_cell_region, cell_region) + + def test_get_cell_state(self): + critical_volume = 10 + threshold_fractions = { + "STATE_A": 0.2, + "STATE_B": 1.5, + "STATE_C": 2, + } + + threshold_a = threshold_fractions["STATE_A"] * critical_volume + threshold_b = threshold_fractions["STATE_B"] * critical_volume + threshold_c = threshold_fractions["STATE_C"] * critical_volume + + parameters = [ + (threshold_a - EPSILON, "STATE_A"), # below A threshold + (threshold_a, "STATE_B"), # equal A threshold + ((threshold_a + threshold_b) / 2, "STATE_B"), # between A and B thresholds + (threshold_b, "STATE_C"), # equal B threshold + ((threshold_b + threshold_c) / 2, "STATE_C"), # between B and C thresholds + (threshold_c, "STATE_C"), # equal C threshold + (threshold_c + EPSILON, "STATE_C"), # above C threshold + ] + + for volume, expected_phase in parameters: + with self.subTest(volume=volume, expected_phase=expected_phase): + phase = get_cell_state(volume, critical_volume, threshold_fractions) + self.assertEqual(expected_phase, phase) + + def test_convert_value_distribution(self): + source_distribution = (10, 6) + target_distribution = (2, 0.6) + + parameters = [ + (10, 2), # means + (4, 1.4), # one standard deviation below + (16, 2.6), # one standard deviation above + (7, 1.7), # half standard deviation below + (13, 2.3), # half standard deviation above + ] + + for source_value, expected_target_value in parameters: + with self.subTest(source_value=source_value): + target_value = convert_value_distribution( + source_value, source_distribution, target_distribution + ) + self.assertEqual(expected_target_value, target_value) + + def test_filter_cell_reference_cell_exists(self): + cell_ids = [1, 2, 3] + feature_a = [10, 20, 30] + feature_b = ["a", "b", "c"] + cell_id = 2 + index = cell_ids.index(cell_id) + + reference = pd.DataFrame({"ID": cell_ids, "FEATURE_A": feature_a, "FEATURE_B": feature_b}) + + expected_cell_reference = { + "ID": cell_id, + "FEATURE_A": feature_a[index], + "FEATURE_B": feature_b[index], + } + + cell_reference = filter_cell_reference(cell_id, reference) + + self.assertDictEqual(expected_cell_reference, cell_reference) + + def test_filter_cell_reference_cell_does_not_exist(self): + cell_ids = [1, 2, 3] + feature_a = [10, 20, 30] + feature_b = ["a", "b", "c"] + cell_id = 4 + + reference = pd.DataFrame({"ID": cell_ids, "FEATURE_A": feature_a, "FEATURE_B": feature_b}) + + cell_reference = filter_cell_reference(cell_id, 
reference) + + self.assertDictEqual({}, cell_reference) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/arcade_collection/input/test_convert_to_locations_file.py b/tests/arcade_collection/input/test_convert_to_locations_file.py new file mode 100644 index 0000000..e690c9b --- /dev/null +++ b/tests/arcade_collection/input/test_convert_to_locations_file.py @@ -0,0 +1,195 @@ +import unittest + +import pandas as pd + +from arcade_collection.input.convert_to_locations_file import ( + convert_to_location, + convert_to_locations_file, + get_center_voxel, + get_location_voxels, +) + + +class TestConvertToLocationsFile(unittest.TestCase): + def test_convert_to_locations_file(self): + samples = pd.DataFrame( + { + "id": [10, 10, 10, 10, 10, 11, 11, 11, 11], + "x": [0, 1, 1, 2, 2, 30, 31, 31, 32], + "y": [3, 3, 4, 5, 5, 40, 42, 42, 44], + "z": [6, 6, 7, 7, 8, 50, 51, 52, 52], + "region": [None, None, None, None, None, "A", "B", "A", "B"], + } + ) + + expected_locations = [ + { + "id": 1, + "center": (1, 4, 6), + "location": [ + { + "region": "UNDEFINED", + "voxels": [ + (0, 3, 6), + (1, 3, 6), + (1, 4, 7), + (2, 5, 7), + (2, 5, 8), + ], + } + ], + }, + { + "id": 2, + "center": (31, 42, 51), + "location": [ + { + "region": "A", + "voxels": [ + (30, 40, 50), + (31, 42, 52), + ], + }, + { + "region": "B", + "voxels": [ + (31, 42, 51), + (32, 44, 52), + ], + }, + ], + }, + ] + + locations = convert_to_locations_file(samples) + + self.assertCountEqual(expected_locations, locations) + + def test_convert_to_location_no_region(self): + cell_id = 2 + samples = pd.DataFrame( + { + "x": [0, 1, 1, 2, 2], + "y": [3, 3, 4, 5, 5], + "z": [6, 6, 7, 7, 8], + } + ) + center = (1, 4, 6) + + expected_location = { + "id": cell_id, + "center": center, + "location": [ + { + "region": "UNDEFINED", + "voxels": [ + (0, 3, 6), + (1, 3, 6), + (1, 4, 7), + (2, 5, 7), + (2, 5, 8), + ], + } + ], + } + + location = convert_to_location(cell_id, samples) + + self.assertDictEqual(expected_location, location) + + def test_convert_to_location_with_region(self): + cell_id = 2 + samples = pd.DataFrame( + { + "x": [0, 1, 1, 2, 2], + "y": [3, 3, 4, 5, 5], + "z": [6, 6, 7, 7, 8], + "region": ["A", "B", "A", "B", "A"], + } + ) + center = (1, 4, 6) + + expected_location = { + "id": cell_id, + "center": center, + "location": [ + { + "region": "A", + "voxels": [ + (0, 3, 6), + (1, 4, 7), + (2, 5, 8), + ], + }, + { + "region": "B", + "voxels": [ + (1, 3, 6), + (2, 5, 7), + ], + }, + ], + } + + location = convert_to_location(cell_id, samples) + + self.assertDictEqual(expected_location, location) + + def test_get_center_voxel(self): + parameters = [ + ([10, 12], [3, 5], [2, 4], (11, 4, 3)), # exact + ([10, 11], [3, 4], [2, 3], (10, 3, 2)), # rounded + ] + + for x, y, z, expected_center in parameters: + with self.subTest(x=x, y=y, z=z): + samples = pd.DataFrame({"x": x, "y": y, "z": z}) + center = get_center_voxel(samples) + self.assertTupleEqual(expected_center, center) + + def test_get_location_voxels_no_region(self): + region = None + samples = pd.DataFrame( + { + "x": [0, 1, 1, 2, 2], + "y": [3, 3, 4, 5, 5], + "z": [6, 6, 7, 7, 8], + } + ) + + expected_voxels = [ + (0, 3, 6), + (1, 3, 6), + (1, 4, 7), + (2, 5, 7), + (2, 5, 8), + ] + + voxels = get_location_voxels(samples, region) + + self.assertCountEqual(expected_voxels, voxels) + + def test_get_location_voxels_with_region(self): + region = "A" + samples = pd.DataFrame( + { + "x": [0, 1, 1, 2, 2], + "y": [3, 3, 4, 5, 5], + "z": [6, 6, 7, 7, 8], + "region": ["A", "B", "A", 
"B", "A"], + } + ) + + expected_voxels = [ + (0, 3, 6), + (1, 4, 7), + (2, 5, 8), + ] + + voxels = get_location_voxels(samples, region) + + self.assertCountEqual(expected_voxels, voxels) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/arcade_collection/input/test_generate_setup_file.py b/tests/arcade_collection/input/test_generate_setup_file.py new file mode 100644 index 0000000..951488b --- /dev/null +++ b/tests/arcade_collection/input/test_generate_setup_file.py @@ -0,0 +1,147 @@ +import unittest + +import pandas as pd + +from arcade_collection.input.generate_setup_file import ( + DEFAULT_POPULATION_ID, + calculate_sample_bounds, + generate_setup_file, + make_setup_file, +) + + +class TestGenerateSetupFile(unittest.TestCase): + def setUp(self): + self.terms = ["term_a", "term_b", "term_c"] + self.margins = [10, 20, 30] + + self.setup_template_no_region = ( + "\n" + ' \n' + " \n" + ' \n' + ' \n' + ' \n' + " \n" + " \n" + " \n" + f' \n' + " \n" + " \n" + " \n" + "" + ) + + self.setup_template_with_region = ( + "\n" + ' \n' + " \n" + ' \n' + ' \n' + ' \n' + " \n" + " \n" + " \n" + f' \n' + f' \n' + f' \n' + " \n" + " \n" + " \n" + " \n" + "" + ) + + def test_generate_setup_file_no_regions(self): + samples = pd.DataFrame( + { + "id": [1, 2], + "x": [0, 2], + "y": [3, 7], + "z": [6, 7], + } + ) + + expected_setup = self.setup_template_no_region % (25, 47, 64, 2) + + setup = generate_setup_file(samples, self.margins, self.terms) + + self.assertEqual(expected_setup, setup) + + def test_generate_setup_file_invalid_regions(self): + samples = pd.DataFrame( + { + "id": [1, 2], + "x": [0, 2], + "y": [3, 7], + "z": [6, 7], + "region": [None, None], + } + ) + + expected_setup = self.setup_template_no_region % (25, 47, 64, 2) + + setup = generate_setup_file(samples, self.margins, self.terms) + + self.assertEqual(expected_setup, setup) + + def test_generate_setup_file_with_regions(self): + samples = pd.DataFrame( + { + "id": [1, 2], + "x": [0, 2], + "y": [3, 7], + "z": [6, 7], + "region": ["A", "B"], + } + ) + + expected_setup = self.setup_template_with_region % (25, 47, 64, 2, "A", "B") + + setup = generate_setup_file(samples, self.margins, self.terms) + + self.assertEqual(expected_setup, setup) + + def test_calculate_sample_bounds(self): + samples = pd.DataFrame( + { + "x": [0, 2], + "y": [3, 7], + "z": [6, 7], + } + ) + margins = [10, 20, 30] + + expected_bounds = (25, 47, 64) + + bounds = calculate_sample_bounds(samples, margins) + + self.assertTupleEqual(expected_bounds, bounds) + + def test_make_setup_file_no_regions(self): + init = 100 + bounds = (10, 20, 30) + regions = None + + expected_setup = self.setup_template_no_region % (*bounds, init) + + setup = make_setup_file(init, bounds, self.terms, regions) + + self.assertEqual(expected_setup, setup) + + def test_make_setup_file_with_regions(self): + init = 100 + bounds = (10, 20, 30) + regions = ["REGION_A", "REGION_B"] + + expected_setup = self.setup_template_with_region % (*bounds, init, regions[0], regions[1]) + + setup = make_setup_file(init, bounds, self.terms, regions) + + self.assertEqual(expected_setup, setup) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/arcade_collection/input/test_group_template_conditions.py b/tests/arcade_collection/input/test_group_template_conditions.py new file mode 100644 index 0000000..6600256 --- /dev/null +++ b/tests/arcade_collection/input/test_group_template_conditions.py @@ -0,0 +1,126 @@ +import unittest + +from arcade_collection.input.group_template_conditions 
import ( + find_seed_ranges, + group_condition_sets, + group_seed_ranges, + group_template_conditions, +) + + +class TestGroupTemplateConditions(unittest.TestCase): + def test_group_template_conditions(self): + max_seeds = 3 + conditions = [ + {"key": "A", "seed": 1}, + {"key": "A", "seed": 2}, + {"key": "A", "seed": 4}, + {"key": "B", "seed": 1}, + {"key": "B", "seed": 2}, + {"key": "B", "seed": 3}, + {"key": "B", "seed": 4}, + ] + + expected_groups = [ + { + "conditions": [ + {"key": "A", "start_seed": 1, "end_seed": 2}, + {"key": "A", "start_seed": 4, "end_seed": 4}, + ] + }, + { + "conditions": [ + {"key": "B", "start_seed": 1, "end_seed": 3}, + ] + }, + { + "conditions": [ + {"key": "B", "start_seed": 4, "end_seed": 4}, + ] + }, + ] + + groups = group_template_conditions(conditions, max_seeds) + + self.assertCountEqual(expected_groups, groups) + + def test_find_seed_ranges_continuous(self): + seeds = [0, 1, 2, 3] + parameters = [ + (2, [(0, 1), (2, 3)]), # below max, equal + (3, [(0, 2), (3, 3)]), # below max, unequal + (4, [(0, 3)]), # equal to max + (5, [(0, 3)]), # above max + ] + + for max_seeds, expected_groups in parameters: + with self.subTest(max_seeds=max_seeds): + groups = find_seed_ranges(seeds, max_seeds) + self.assertCountEqual(expected_groups, groups) + + def test_find_seed_ranges_discontinuous(self): + seeds = [0, 1, 2, 3, 5, 6, 7, 8] + parameters = [ + (2, [(0, 1), (2, 3), (5, 6), (7, 8)]), # below max, equal + (3, [(0, 2), (3, 3), (5, 7), (8, 8)]), # below max, unequal + (4, [(0, 3), (5, 8)]), # equal to max + (5, [(0, 3), (5, 8)]), # above max + ] + + for max_seeds, expected_groups in parameters: + with self.subTest(max_seeds=max_seeds): + groups = find_seed_ranges(seeds, max_seeds) + self.assertCountEqual(expected_groups, groups) + + def test_group_seed_ranges(self): + max_seeds = 2 + conditions = [ + {"key": "A", "seed": 1}, + {"key": "A", "seed": 2}, + {"key": "A", "seed": 3}, + {"key": "B", "seed": 1}, + {"key": "B", "seed": 2}, + {"key": "B", "seed": 3}, + {"key": "B", "seed": 4}, + ] + + expected_groups = [ + {"key": "A", "start_seed": 1, "end_seed": 2}, + {"key": "A", "start_seed": 3, "end_seed": 3}, + {"key": "B", "start_seed": 1, "end_seed": 2}, + {"key": "B", "start_seed": 3, "end_seed": 4}, + ] + + groups = group_seed_ranges(conditions, max_seeds) + + self.assertCountEqual(expected_groups, groups) + + def test_group_condition_sets(self): + max_seeds = 3 + conditions = [ + {"key": "A", "start_seed": 1, "end_seed": 2}, + {"key": "A", "start_seed": 3, "end_seed": 3}, + {"key": "B", "start_seed": 1, "end_seed": 2}, + {"key": "B", "start_seed": 3, "end_seed": 4}, + ] + + expected_groups = [ + [ + {"key": "A", "start_seed": 1, "end_seed": 2}, + {"key": "A", "start_seed": 3, "end_seed": 3}, + ], + [ + {"key": "B", "start_seed": 1, "end_seed": 2}, + ], + [ + {"key": "B", "start_seed": 3, "end_seed": 4}, + ], + ] + + groups = group_condition_sets(conditions, max_seeds) + + self.assertCountEqual(expected_groups, groups) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/arcade_collection/input/test_merge_region_samples.py b/tests/arcade_collection/input/test_merge_region_samples.py new file mode 100644 index 0000000..391a422 --- /dev/null +++ b/tests/arcade_collection/input/test_merge_region_samples.py @@ -0,0 +1,215 @@ +import unittest + +import pandas as pd + +from arcade_collection.input.merge_region_samples import ( + filter_valid_samples, + merge_region_samples, + transform_sample_coordinates, +) + + +class 
TestMergeRegionSamples(unittest.TestCase): + def test_merge_region_samples_no_regions(self): + samples = { + "DEFAULT": pd.DataFrame( + { + "id": [1, 1, 1, 2, 2], + "x": [0, 1, 1, 2, 2], + "y": [3, 3, 4, 5, 5], + "z": [6, 6, 7, 7, 8], + } + ), + } + margins = (10, 20, 30) + + expected_merged = pd.DataFrame( + { + "id": [1, 1, 1, 2, 2], + "x": [11, 12, 12, 13, 13], + "y": [21, 21, 22, 23, 23], + "z": [31, 31, 32, 32, 33], + } + ) + + merged = merge_region_samples(samples, margins) + + self.assertTrue(expected_merged.equals(merged)) + + def test_merge_region_samples_with_regions_no_fill(self): + samples = { + "DEFAULT": pd.DataFrame( + { + "id": [1, 1, 1, 2, 2], + "x": [0, 1, 1, 2, 2], + "y": [3, 3, 4, 5, 5], + "z": [6, 6, 7, 7, 8], + } + ), + "REGION_A": pd.DataFrame( + {"id": [1, 1, 2], "x": [0, 1, 2], "y": [3, 4, 5], "z": [6, 7, 7]} + ), + "REGION_B": pd.DataFrame({"id": [1, 2], "x": [1, 2], "y": [3, 5], "z": [6, 8]}), + } + margins = (10, 20, 30) + + expected_merged = pd.DataFrame( + { + "id": [1, 1, 1, 2, 2], + "x": [11, 12, 12, 13, 13], + "y": [21, 21, 22, 23, 23], + "z": [31, 31, 32, 32, 33], + "region": ["REGION_A", "REGION_B", "REGION_A", "REGION_A", "REGION_B"], + } + ) + + merged = merge_region_samples(samples, margins) + + self.assertTrue(expected_merged.equals(merged)) + + def test_merge_region_samples_with_regions_with_fill(self): + samples = { + "DEFAULT": pd.DataFrame( + { + "id": [1, 1, 1, 2, 2], + "x": [0, 1, 1, 2, 2], + "y": [3, 3, 4, 5, 5], + "z": [6, 6, 7, 7, 8], + } + ), + "REGION_A": pd.DataFrame( + {"id": [1, 1, 2], "x": [0, 1, 2], "y": [3, 4, 5], "z": [6, 7, 7]} + ), + } + margins = (10, 20, 30) + + expected_merged = pd.DataFrame( + { + "id": [1, 1, 1, 2, 2], + "x": [11, 12, 12, 13, 13], + "y": [21, 21, 22, 23, 23], + "z": [31, 31, 32, 32, 33], + "region": ["REGION_A", "DEFAULT", "REGION_A", "REGION_A", "DEFAULT"], + } + ) + + merged = merge_region_samples(samples, margins) + + self.assertTrue(expected_merged.equals(merged)) + + def test_transform_sample_coordinates_no_reference(self): + samples = pd.DataFrame( + { + "id": [1, 1, 1, 2, 2], + "x": [0, 1, 1, 2, 2], + "y": [3, 3, 4, 5, 5], + "z": [6, 6, 7, 7, 8], + } + ) + margins = (10, 20, 30) + reference = None + + expected_transformed = pd.DataFrame( + { + "id": [1, 1, 1, 2, 2], + "x": [11, 12, 12, 13, 13], + "y": [21, 21, 22, 23, 23], + "z": [31, 31, 32, 32, 33], + } + ) + + transformed = transform_sample_coordinates(samples, margins, reference) + + self.assertTrue(expected_transformed.equals(transformed)) + + def test_transform_sample_coordinates_with_reference(self): + samples = pd.DataFrame( + { + "id": [1, 1, 1, 2, 2], + "x": [0, 1, 1, 2, 2], + "y": [3, 3, 4, 5, 5], + "z": [6, 6, 7, 7, 8], + } + ) + margins = (10, 20, 30) + reference = pd.DataFrame( + { + "x": [0], + "y": [1], + "z": [2], + } + ) + + expected_transformed = pd.DataFrame( + { + "id": [1, 1, 1, 2, 2], + "x": [11, 12, 12, 13, 13], + "y": [23, 23, 24, 25, 25], + "z": [35, 35, 36, 36, 37], + } + ) + + transformed = transform_sample_coordinates(samples, margins, reference) + + self.assertTrue(expected_transformed.equals(transformed)) + + def test_filter_valid_samples_no_region_all_valid(self): + samples = pd.DataFrame( + { + "x": [0, 1, 1, 2, 2], + "y": [3, 3, 4, 5, 5], + "z": [6, 6, 7, 7, 8], + } + ) + + expected_filtered = samples.copy() + + filtered = filter_valid_samples(samples) + + self.assertTrue(expected_filtered.equals(filtered)) + + def test_filter_valid_samples_with_region_all_valid(self): + samples = pd.DataFrame( + { + "id": [1, 1, 1, 2, 
2], + "x": [0, 1, 1, 2, 2], + "y": [3, 3, 4, 5, 5], + "z": [6, 6, 7, 7, 8], + "region": ["A", "B", "B", "A", "B"], + } + ) + + expected_filtered = samples.copy() + + filtered = filter_valid_samples(samples) + + self.assertTrue(expected_filtered.equals(filtered)) + + def test_filter_valid_samples_sample_outside_region(self): + samples = pd.DataFrame( + { + "id": [1, 1, 1, 2, 2], + "x": [0, 1, 1, 2, 2], + "y": [3, 3, 4, 5, 5], + "z": [6, 6, 7, 7, 8], + "region": ["A", "B", "B", "A", "A"], + } + ) + + expected_filtered = pd.DataFrame( + { + "id": [1, 1, 1], + "x": [0, 1, 1], + "y": [3, 3, 4], + "z": [6, 6, 7], + "region": ["A", "B", "B"], + } + ) + + filtered = filter_valid_samples(samples) + + self.assertTrue(expected_filtered.equals(filtered)) + + +if __name__ == "__main__": + unittest.main()
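The critical values documented in convert_to_cell and convert_to_cell_region come from a plain z-score mapping between two distributions given as (mean, standard deviation) tuples, as exercised by test_convert_value_distribution. The sketch below is an illustrative standalone version of that rescaling, not the library's convert_value_distribution itself; the function name is made up, and the numbers are taken from the test fixtures above.

.. code-block:: python

    def rescale(value: float, source: tuple[float, float], target: tuple[float, float]) -> float:
        """Map a value from the source distribution onto the target distribution by z-score."""
        source_avg, source_std = source
        target_avg, target_std = target
        z = (value - source_avg) / source_std
        return z * target_std + target_avg

    # Worked example using the "DEFAULT" fixtures from TestConvertToCellsFile.setUp:
    # 20 voxels, measured volumes ~ (mean 10, std 10), critical volumes ~ (mean 2, std 2).
    assert rescale(20, (10, 10), (2, 2)) == 4.0  # matches "criticals": [4, ...] in the expected cells

    # Distribution pair from test_convert_value_distribution: one standard deviation above the mean.
    assert rescale(16, (10, 6), (2, 0.6)) == 2.6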
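The state assignment rule documented for get_cell_state ([f(i - 1) * V] <= v < [fi * V], clamped to the first and last states at the extremes) reduces to a linear scan over the threshold fractions. This is a sketch of that documented behavior, checked against the expectations in test_get_cell_state; it is not the library implementation, and assign_state is an illustrative name.

.. code-block:: python

    def assign_state(volume: float, critical_volume: float, threshold_fractions: dict[str, float]) -> str:
        """Return the first state whose threshold (fraction * critical volume) exceeds the volume."""
        states = list(threshold_fractions.keys())
        thresholds = [fraction * critical_volume for fraction in threshold_fractions.values()]
        for state, threshold in zip(states, thresholds):
            if volume < threshold:
                return state
        return states[-1]  # volumes at or above the last threshold clamp to the last state

    fractions = {"STATE_A": 0.2, "STATE_B": 1.5, "STATE_C": 2}
    assert assign_state(1, 10, fractions) == "STATE_A"   # below f1 * V
    assert assign_state(2, 10, fractions) == "STATE_B"   # exactly f1 * V falls into the next state
    assert assign_state(25, 10, fractions) == "STATE_C"  # above fN * V clamps to the last state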
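find_seed_ranges is documented as splitting a sorted list of seeds into continuous runs, with each run capped at max_seeds. Below is a minimal sketch of that behavior written to match the expectations in test_find_seed_ranges_continuous and test_find_seed_ranges_discontinuous; the helper name and loop structure are illustrative rather than the library code.

.. code-block:: python

    def seed_ranges(seeds: list[int], max_seeds: int) -> list[tuple[int, int]]:
        """Split seeds into (start, end) ranges that are continuous and span at most max_seeds."""
        ordered = sorted(seeds)
        ranges = []
        start = previous = ordered[0]
        for seed in ordered[1:]:
            # Close the current range when the sequence breaks or the range is full.
            if seed != previous + 1 or previous - start + 1 >= max_seeds:
                ranges.append((start, previous))
                start = seed
            previous = seed
        ranges.append((start, previous))
        return ranges

    assert seed_ranges([0, 1, 2, 3], 3) == [(0, 2), (3, 3)]
    assert seed_ranges([0, 1, 2, 3, 5, 6, 7, 8], 4) == [(0, 3), (5, 8)]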
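The coordinate transform exercised by the test_transform_sample_coordinates_* tests shifts each coordinate by the minimum of the reference samples plus the margin, with one extra voxel of padding; the + 1 offset is inferred from the expected values in those tests rather than stated in a docstring. A worked check of that arithmetic, with illustrative names:

.. code-block:: python

    def shift(coordinate: int, reference_minimum: int, margin: int) -> int:
        """Shift a coordinate so the smallest reference sample lands one voxel inside the margin."""
        return coordinate - reference_minimum + margin + 1

    assert shift(0, 0, 10) == 11  # x = 0, margin 10, reference minimum 0 (no-reference case)
    assert shift(3, 1, 20) == 23  # y = 3, margin 20, reference minimum 1
    assert shift(6, 2, 30) == 35  # z = 6, margin 30, reference minimum 2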