diff --git a/src/arcade_collection/__main__.py b/src/arcade_collection/__main__.py
index c552b21..e69de29 100644
--- a/src/arcade_collection/__main__.py
+++ b/src/arcade_collection/__main__.py
@@ -1,2 +0,0 @@
-if __name__ == "__main__":
- print("hello world")
diff --git a/src/arcade_collection/input/convert_to_cells_file.py b/src/arcade_collection/input/convert_to_cells_file.py
index b3261d3..5c75155 100644
--- a/src/arcade_collection/input/convert_to_cells_file.py
+++ b/src/arcade_collection/input/convert_to_cells_file.py
@@ -10,6 +10,44 @@ def convert_to_cells_file(
critical_height_distributions: dict[str, tuple[float, float]],
state_thresholds: dict[str, float],
) -> list[dict]:
+ """
+ Convert all samples to cell objects.
+
+ For each cell id in samples, current volume and height are rescaled to
+ critical volume and critical height based on distribution means and standard
+ deviations. If reference volume and/or height exist for the cell id, those
+ values are used as the current values to be rescaled. Otherwise, current
+ volume is calculated from the number of voxel samples and current height is
+ calculated from the range of voxel coordinates along the z axis.
+
+ Initial cell state and cell state phase for each cell are estimated based on
+ state thresholds, the current cell volume, and the critical cell volume.
+
+ Cell object ids are reindexed starting with cell id 1.
+
+ Parameters
+ ----------
+ samples
+ Sample cell ids and coordinates.
+ reference
+ Reference values for volumes and heights.
+ volume_distributions
+ Map of volume means and standard deviations.
+ height_distributions
+ Map of height means and standard deviations.
+ critical_volume_distributions
+ Map of critical volume means and standard deviations.
+ critical_height_distributions
+ Map of critical height means and standard deviations.
+ state_thresholds
+ Critical volume fractions defining thresholds between states.
+
+ Returns
+ -------
+ :
+ List of cell objects formatted for ARCADE.
+ """
+
cells: list[dict] = []
samples_by_id = samples.groupby("id")
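
A minimal sketch of the call described above, with illustrative sample values and the single-region ("DEFAULT") distributions used in the new unit tests:

    import pandas as pd

    from arcade_collection.input.convert_to_cells_file import convert_to_cells_file

    samples = pd.DataFrame(
        {
            "id": [10, 10, 10, 11, 11],
            "x": [0, 1, 1, 2, 2],
            "y": [3, 3, 4, 5, 5],
            "z": [6, 6, 7, 7, 8],
        }
    )
    reference = pd.DataFrame({"ID": [11], "volume": [10], "height": [10]})

    cells = convert_to_cells_file(
        samples,
        reference,
        volume_distributions={"DEFAULT": (10, 10)},
        height_distributions={"DEFAULT": (5, 5)},
        critical_volume_distributions={"DEFAULT": (2, 2)},
        critical_height_distributions={"DEFAULT": (10, 10)},
        state_thresholds={"STATE1_PHASE1": 0.2, "STATE1_PHASE2": 1.5},
    )
    # Cell ids 10 and 11 are reindexed to 1 and 2 in the returned cell objects.
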
@@ -42,30 +80,43 @@ def convert_to_cell(
state_thresholds: dict[str, float],
) -> dict:
"""
- Convert samples to ARCADE .CELLS json format.
+ Convert samples to cell object.
+
+ Current volume and height are rescaled to critical volume and critical
+ height based on distribution means and standard deviations. If reference
+ volume and/or height are provided (under the "DEFAULT" key), those values
+ are used as the current values to be rescaled. Otherwise, current volume is
+ calculated from the number of voxel samples and current height is calculated
+ from the range of voxel coordinates along the z axis.
+
+ Initial cell state and cell state phase are estimated based on state
+ thresholds, the current cell volume, and the critical cell volume.
Parameters
----------
cell_id
Unique cell id.
samples
- Sample cell ids and coordinates.
+ Sample coordinates for a single object.
reference
- Reference data for conversion.
- volume_distribution
- Average and standard deviation of volume distributions.
- height_distribution
- Average and standard deviation of height distributions.
- critical_volume_distribution
- Average and standard deviation of critical volume distributions.
- critical_height_distribution
- Average and standard deviation of critical height distributions.
+ Reference data for cell.
+ volume_distributions
+ Map of volume means and standard deviations.
+ height_distributions
+ Map of height means and standard deviations.
+ critical_volume_distributions
+ Map of critical volume means and standard deviations.
+ critical_height_distributions
+ Map of critical height means and standard deviations.
+ state_thresholds
+ Critical volume fractions defining thresholds between states.
Returns
-------
:
- Dictionary in ARCADE .CELLS json format.
+ Cell object formatted for ARCADE.
"""
+
volume = len(samples)
height = samples.z.max() - samples.z.min()
@@ -95,7 +146,7 @@ def convert_to_cell(
"criticals": [critical_volume, critical_height],
}
- if "region" in samples.columns:
+ if "region" in samples.columns and not samples["region"].isnull().all():
regions = [
convert_to_cell_region(
region,
@@ -122,6 +173,39 @@ def convert_to_cell_region(
critical_volume_distributions: dict[str, tuple[float, float]],
critical_height_distributions: dict[str, tuple[float, float]],
) -> dict:
+ """
+ Convert region samples to cell region object.
+
+ Current region volume and height are rescaled to critical volume and
+ critical height based on distribution means and standard deviations. If
+ reference region volume and/or height are provided, those values are used as
+ the current values to be rescaled. Otherwise, current region volume is
+ calculated from the number of voxel samples and current region height is
+ calculated from the range of voxel coordinates along the z axis.
+
+ Parameters
+ ----------
+ region
+ Region name.
+ region_samples
+ Sample coordinates for region of a single object.
+ reference
+ Reference data for cell region.
+ volume_distributions
+ Map of volume means and standard deviations.
+ height_distributions
+ Map of height means and standard deviations.
+ critical_volume_distributions
+ Map of critical volume means and standard deviations.
+ critical_height_distributions
+ Map of critical height means and standard deviations.
+
+ Returns
+ -------
+ :
+ Cell region object formatted for ARCADE.
+ """
+
region_volume = len(region_samples)
region_height = region_samples.z.max() - region_samples.z.min()
@@ -152,12 +236,14 @@ def get_cell_state(
"""
Estimates cell state based on cell volume.
- The threshold fractions dictionary defines the monotonic thresholds
- between different cell states.
- For a given volume v, critical volume V, and states X1, X2, ..., XN with
- corresponding, monotonic threshold fractions f1, f2, ..., fN, a cell is
- assigned state Xi such that [f(i - 1) * V] <= v < [fi * V].
+ The threshold fractions dictionary defines the monotonic thresholds between
+ different cell states. For a given volume v, critical volume V, and states
+ X1, X2, ..., XN with corresponding, monotonic threshold fractions f1, f2,
+ ..., fN, a cell is assigned state Xi such that [f(i - 1) * V] <= v < [fi *
+ V].
+
Cells with v < f1 * V are assigned state X1.
+
Cells with v > fN * V are assigned state XN.
Parameters
@@ -174,6 +260,7 @@ def get_cell_state(
:
Cell state.
"""
+
thresholds = [fraction * critical_volume for fraction in threshold_fractions.values()]
states = list(threshold_fractions.keys())
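
To make the threshold semantics concrete, a small worked example mirroring the values in the new unit tests (state names are illustrative):

    from arcade_collection.input.convert_to_cells_file import get_cell_state

    critical_volume = 10
    threshold_fractions = {"STATE_A": 0.2, "STATE_B": 1.5, "STATE_C": 2}
    # Thresholds become [2, 15, 20] (fraction * critical volume).

    get_cell_state(1.9, critical_volume, threshold_fractions)   # "STATE_A" (v < f1 * V)
    get_cell_state(2.0, critical_volume, threshold_fractions)   # "STATE_B" (f1 * V <= v < f2 * V)
    get_cell_state(15.0, critical_volume, threshold_fractions)  # "STATE_C" (f2 * V <= v < f3 * V)
    get_cell_state(25.0, critical_volume, threshold_fractions)  # "STATE_C" (v above the last threshold)
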
@@ -203,6 +290,7 @@ def convert_value_distribution(
:
Estimated critical value.
"""
+
source_avg, source_std = source_distribution
target_avg, target_std = target_distribution
z_scored_value = (value - source_avg) / source_std
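
As a concrete check of this rescaling, using the source and target distributions exercised in the new test:

    from arcade_collection.input.convert_to_cells_file import convert_value_distribution

    convert_value_distribution(16, (10, 6), (2, 0.6))
    # z-score of 16 under N(10, 6) is (16 - 10) / 6 = 1,
    # so the estimated critical value is 2 + 1 * 0.6 = 2.6
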
@@ -226,6 +314,7 @@ def filter_cell_reference(cell_id: int, reference: pd.DataFrame) -> dict:
:
Reference data for given cell id.
"""
+
cell_reference = reference[reference["ID"] == cell_id].squeeze()
cell_reference = cell_reference.to_dict() if not cell_reference.empty else {}
return cell_reference
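
For reference, a single cell object returned by convert_to_cell has the shape asserted in test_convert_to_cell_no_reference_no_region:

    cell = {
        "id": 2,
        "parent": 0,
        "pop": 1,
        "age": 0,
        "divisions": 0,
        "state": "STATE2",
        "phase": "STATE2_PHASE1",
        "voxels": 20,
        "criticals": [4, 10],
        # "regions": [...] is added only when valid region samples are present.
    }
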
diff --git a/src/arcade_collection/input/convert_to_locations_file.py b/src/arcade_collection/input/convert_to_locations_file.py
index f31d0b9..acff995 100644
--- a/src/arcade_collection/input/convert_to_locations_file.py
+++ b/src/arcade_collection/input/convert_to_locations_file.py
@@ -4,6 +4,20 @@
def convert_to_locations_file(samples: pd.DataFrame) -> list[dict]:
+ """
+ Convert all samples to location objects.
+
+ Parameters
+ ----------
+ samples
+ Sample cell ids and coordinates.
+
+ Returns
+ -------
+ :
+ List of location objects formatted for ARCADE.
+ """
+
locations: list[dict] = []
samples_by_id = samples.groupby("id")
@@ -15,23 +29,24 @@ def convert_to_locations_file(samples: pd.DataFrame) -> list[dict]:
def convert_to_location(cell_id: int, samples: pd.DataFrame) -> dict:
"""
- Convert samples to ARCADE .LOCATIONS json format.
+ Convert samples to location object.
Parameters
----------
cell_id
Unique cell id.
samples
- Sample cell ids and coordinates.
+ Sample coordinates for a single object.
Returns
-------
:
- Dictionary in ARCADE .LOCATIONS json format.
+ Location object formatted for ARCADE.
"""
+
center = get_center_voxel(samples)
- if "region" in samples.columns:
+ if "region" in samples.columns and not samples["region"].isnull().all():
voxels = [
{"region": region, "voxels": get_location_voxels(samples, region)}
for region in samples["region"].unique()
@@ -44,6 +59,7 @@ def convert_to_location(cell_id: int, samples: pd.DataFrame) -> dict:
"center": center,
"location": voxels,
}
+
return location
@@ -61,6 +77,7 @@ def get_center_voxel(samples: pd.DataFrame) -> tuple[int, int, int]:
:
Center voxel.
"""
+
center_x = int(samples["x"].mean())
center_y = int(samples["y"].mean())
center_z = int(samples["z"].mean())
@@ -86,6 +103,7 @@ def get_location_voxels(
:
List of voxel coordinates.
"""
+
if region is not None:
region_samples = samples[samples["region"] == region]
voxels_x = region_samples["x"]
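
A short sketch of the location object format, matching test_convert_to_location_no_region:

    import pandas as pd

    from arcade_collection.input.convert_to_locations_file import convert_to_location

    samples = pd.DataFrame(
        {
            "x": [0, 1, 1, 2, 2],
            "y": [3, 3, 4, 5, 5],
            "z": [6, 6, 7, 7, 8],
        }
    )
    convert_to_location(2, samples)
    # {
    #     "id": 2,
    #     "center": (1, 4, 6),
    #     "location": [
    #         {"region": "UNDEFINED", "voxels": [(0, 3, 6), (1, 3, 6), (1, 4, 7), (2, 5, 7), (2, 5, 8)]}
    #     ],
    # }
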
diff --git a/src/arcade_collection/input/generate_setup_file.py b/src/arcade_collection/input/generate_setup_file.py
index 09153a0..9f54810 100644
--- a/src/arcade_collection/input/generate_setup_file.py
+++ b/src/arcade_collection/input/generate_setup_file.py
@@ -4,14 +4,42 @@
import numpy as np
import pandas as pd
+DEFAULT_POPULATION_ID = "X"
+"""Default population ID used in setup file."""
+
def generate_setup_file(
- samples: pd.DataFrame, margins: tuple[int, int, int], potts_terms: list[str]
+ samples: pd.DataFrame, margins: tuple[int, int, int], terms: list[str]
) -> str:
+ """
+ Create ARCADE setup file from samples, margins, and CPM Hamiltonian terms.
+
+ The initial number of cells is determined by the number of unique ids in samples.
+ Regions are included if samples contains valid regions.
+
+ Parameters
+ ----------
+ samples
+ Sample cell ids and coordinates.
+ margins
+ Margin size in x, y, and z directions.
+ terms
+ List of Potts Hamiltonian terms for setup file.
+
+ Returns
+ -------
+ :
+ Contents of ARCADE setup file.
+ """
+
init = len(samples["id"].unique())
bounds = calculate_sample_bounds(samples, margins)
- regions = samples["regions"].unique() if "regions" in samples else None
- setup = make_setup_file(init, bounds, potts_terms, regions)
+ regions = (
+ samples["region"].unique()
+ if "region" in samples.columns and not samples["region"].isnull().all()
+ else None
+ )
+ setup = make_setup_file(init, bounds, terms, regions)
return setup
@@ -33,6 +61,7 @@ def calculate_sample_bounds(
:
Bounds in x, y, and z directions.
"""
+
mins = (min(samples.x), min(samples.y), min(samples.z))
maxs = (max(samples.x), max(samples.y), max(samples.z))
@@ -64,8 +93,10 @@ def make_setup_file(
Returns
-------
+ :
Contents of ARCADE setup file.
"""
+
root = ET.fromstring("<set></set>")
series = ET.SubElement(
root,
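
A short sketch of the margin handling and setup generation, using the values asserted in the new tests (the exact XML contents come from make_setup_file):

    import pandas as pd

    from arcade_collection.input.generate_setup_file import (
        calculate_sample_bounds,
        generate_setup_file,
    )

    samples = pd.DataFrame({"id": [1, 2], "x": [0, 2], "y": [3, 7], "z": [6, 7]})
    margins = (10, 20, 30)

    calculate_sample_bounds(samples, margins)  # (25, 47, 64), per the new test
    setup = generate_setup_file(samples, margins, ["term_a", "term_b", "term_c"])
    # `setup` holds the ARCADE setup file contents for two initial cells and no regions.
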
diff --git a/src/arcade_collection/input/group_template_conditions.py b/src/arcade_collection/input/group_template_conditions.py
index 78c6219..3f68681 100644
--- a/src/arcade_collection/input/group_template_conditions.py
+++ b/src/arcade_collection/input/group_template_conditions.py
@@ -2,6 +2,22 @@
def group_template_conditions(conditions: list[dict], max_seeds: int) -> list[dict]:
+ """
+ Create condition groups obeying the specified max seeds for each group.
+
+ Parameters
+ ----------
+ conditions
+ List of conditions, each containing a unique "key" and a "seed".
+ max_seeds
+ Maximum number of total seeds in each group.
+
+ Returns
+ -------
+ :
+ List of condition groups.
+ """
+
grouped_conditions = group_seed_ranges(conditions, max_seeds)
condition_sets = group_condition_sets(grouped_conditions, max_seeds)
template_conditions = [{"conditions": condition_set} for condition_set in condition_sets]
@@ -9,6 +25,22 @@ def group_template_conditions(conditions: list[dict], max_seeds: int) -> list[di
def group_seed_ranges(conditions: list[dict], max_seeds: int) -> list[dict]:
+ """
+ Group conditions by continuous seed ranges.
+
+ Parameters
+ ----------
+ conditions
+ List of conditions, each containing a unique "key" and a "seed".
+ max_seeds
+ Maximum number of seeds in a single range.
+
+ Returns
+ -------
+ :
+ List of conditions updated with "start_seed" and "end_seed" ranges.
+ """
+
conditions.sort(key=lambda x: (x["key"], x["seed"]))
grouped_conditions = []
@@ -33,6 +65,22 @@ def group_seed_ranges(conditions: list[dict], max_seeds: int) -> list[dict]:
def find_seed_ranges(seeds: list[int], max_seeds: int) -> list[tuple[int, int]]:
+ """
+ Find continuous seed ranges, with each range no larger than the specified max seeds.
+
+ Parameters
+ ----------
+ seeds
+ List of seeds.
+ max_seeds
+ Maximum number of seeds in a single range.
+
+ Returns
+ -------
+ :
+ List of seeds grouped into ranges.
+ """
+
seeds.sort()
ranges = []
@@ -52,6 +100,23 @@ def find_seed_ranges(seeds: list[int], max_seeds: int) -> list[tuple[int, int]]:
def group_condition_sets(conditions: list[dict], max_seeds: int) -> list[list[dict]]:
+ """
+ Group conditions, with total seeds in each group no larger than the specified max seeds.
+
+ Parameters
+ ----------
+ conditions
+ List of conditions, each containing a unique "key" with "start_seed" and
+ "end_seed" ranges.
+ max_seeds
+ Maximum number of seeds in a single group.
+
+ Returns
+ -------
+ :
+ List of groups of conditions.
+ """
+
seed_count = 0
condition_set = []
condition_sets = []
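
To illustrate the grouping behavior, using the conditions and the max_seeds = 3 case from the new unit test:

    from arcade_collection.input.group_template_conditions import group_template_conditions

    conditions = [
        {"key": "A", "seed": 1},
        {"key": "A", "seed": 2},
        {"key": "A", "seed": 4},
        {"key": "B", "seed": 1},
        {"key": "B", "seed": 2},
        {"key": "B", "seed": 3},
        {"key": "B", "seed": 4},
    ]

    group_template_conditions(conditions, max_seeds=3)
    # [
    #     {"conditions": [{"key": "A", "start_seed": 1, "end_seed": 2},
    #                     {"key": "A", "start_seed": 4, "end_seed": 4}]},
    #     {"conditions": [{"key": "B", "start_seed": 1, "end_seed": 3}]},
    #     {"conditions": [{"key": "B", "start_seed": 4, "end_seed": 4}]},
    # ]
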
diff --git a/src/arcade_collection/input/merge_region_samples.py b/src/arcade_collection/input/merge_region_samples.py
index f230b0d..0d72950 100644
--- a/src/arcade_collection/input/merge_region_samples.py
+++ b/src/arcade_collection/input/merge_region_samples.py
@@ -7,14 +7,62 @@
def merge_region_samples(
samples: dict[str, pd.DataFrame], margins: tuple[int, int, int]
) -> pd.DataFrame:
+ """
+ Merge different region samples into a single valid samples dataframe.
+
+ The input samples are formatted as:
+
+ .. code-block:: python
+
+ {
+ "DEFAULT": (dataframe with columns = id, x, y, z),
+ "": (dataframe with columns = id, x, y, z),
+ "": (dataframe with columns = id, x, y, z),
+ ...
+ }
+
+ The DEFAULT region is used as the superset of (x, y, z) samples; any samples
+ found only in a non-DEFAULT region are ignored. For a given id, there must
+ be at least one sample in each region.
+
+ The output samples are formatted as:
+
+ .. code-block:: markdown
+
+ ┍━━━━━━━┯━━━━━━━┯━━━━━━━┯━━━━━━━┯━━━━━━━━━━┑
+ │  id   │   x   │   y   │   z   │  region  │
+ ┝━━━━━━━┿━━━━━━━┿━━━━━━━┿━━━━━━━┿━━━━━━━━━━┥
+ │ <int> │ <int> │ <int> │ <int> │ DEFAULT  │
+ │ <int> │ <int> │ <int> │ <int> │ <REGION> │
+ │  ...  │  ...  │  ...  │  ...  │   ...    │
+ │ <int> │ <int> │ <int> │ <int> │ <REGION> │
+ ┕━━━━━━━┷━━━━━━━┷━━━━━━━┷━━━━━━━┷━━━━━━━━━━┙
+
+ Samples found in the DEFAULT region, but not in any non-DEFAULT region, are
+ marked as DEFAULT. Otherwise, the sample is marked with the corresponding
+ region. Region samples should be mutually exclusive.
+
+ Parameters
+ ----------
+ samples
+ Map of region names to region samples.
+ margins
+ Margin in the x, y, and z directions applied to sample locations.
+
+ Returns
+ -------
+ :
+ Dataframe of merged samples with applied margins.
+ """
+
default_samples = samples["DEFAULT"]
- all_samples = tranform_sample_coordinates(default_samples, margins)
+ all_samples = transform_sample_coordinates(default_samples, margins)
regions = [key for key in samples.keys() if key != "DEFAULT"]
all_region_samples = []
for region in regions:
- region_samples = tranform_sample_coordinates(samples[region], margins, default_samples)
+ region_samples = transform_sample_coordinates(samples[region], margins, default_samples)
region_samples["region"] = region
all_region_samples.append(region_samples)
@@ -29,7 +77,7 @@ def merge_region_samples(
return valid_samples
-def tranform_sample_coordinates(
+def transform_sample_coordinates(
samples: pd.DataFrame,
margins: tuple[int, int, int],
reference: Optional[pd.DataFrame] = None,
@@ -51,6 +99,7 @@ def tranform_sample_coordinates(
:
Transformed sample cell ids and coordinates.
"""
+
if reference is None:
reference = samples
@@ -84,6 +133,7 @@ def filter_valid_samples(samples: pd.DataFrame) -> pd.DataFrame:
:
Valid sample cell ids and coordinates.
"""
+
if "region" in samples.columns:
num_regions = len(samples.region.unique())
samples = samples.groupby("id").filter(lambda x: len(x.region.unique()) == num_regions)
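
A short sketch of the merge, mirroring test_merge_region_samples_with_regions_with_fill:

    import pandas as pd

    from arcade_collection.input.merge_region_samples import merge_region_samples

    samples = {
        "DEFAULT": pd.DataFrame(
            {
                "id": [1, 1, 1, 2, 2],
                "x": [0, 1, 1, 2, 2],
                "y": [3, 3, 4, 5, 5],
                "z": [6, 6, 7, 7, 8],
            }
        ),
        "REGION_A": pd.DataFrame(
            {"id": [1, 1, 2], "x": [0, 1, 2], "y": [3, 4, 5], "z": [6, 7, 7]}
        ),
    }

    merged = merge_region_samples(samples, (10, 20, 30))
    # Coordinates are offset by the margins (x = 0 becomes 11, y = 3 becomes 21,
    # z = 6 becomes 31) and samples not covered by REGION_A are labeled DEFAULT.
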
diff --git a/tests/arcade_collection/input/test_convert_to_cells_file.py b/tests/arcade_collection/input/test_convert_to_cells_file.py
new file mode 100644
index 0000000..f029d2a
--- /dev/null
+++ b/tests/arcade_collection/input/test_convert_to_cells_file.py
@@ -0,0 +1,442 @@
+import unittest
+
+import numpy as np
+import pandas as pd
+
+from arcade_collection.input.convert_to_cells_file import (
+ convert_to_cell,
+ convert_to_cell_region,
+ convert_to_cells_file,
+ convert_value_distribution,
+ filter_cell_reference,
+ get_cell_state,
+)
+
+EPSILON = 1e-10
+DEFAULT_REGION_NAME = "DEFAULT"
+
+
+def make_samples(cell_id, volume, height, region):
+ return pd.DataFrame(
+ {
+ "id": [cell_id] * volume,
+ "z": np.linspace(0, height, volume),
+ "region": [region] * volume,
+ }
+ )
+
+
+class TestConvertToCellsFile(unittest.TestCase):
+ def setUp(self):
+ self.volume_distributions = {
+ DEFAULT_REGION_NAME: (10, 10),
+ "REGION1": (10, 5),
+ "REGION2": (25, 10),
+ }
+ self.height_distributions = {
+ DEFAULT_REGION_NAME: (5, 5),
+ "REGION1": (4, 1),
+ "REGION2": (10, 5),
+ }
+ self.critical_volume_distributions = {
+ DEFAULT_REGION_NAME: (2, 2),
+ "REGION1": (2, 4),
+ "REGION2": (6, 1),
+ }
+ self.critical_height_distributions = {
+ DEFAULT_REGION_NAME: (10, 10),
+ "REGION1": (2, 3),
+ "REGION2": (8, 4),
+ }
+ self.state_thresholds = {
+ "STATE1_PHASE1": 0.2,
+ "STATE1_PHASE2": 1.5,
+ "STATE2_PHASE1": 2,
+ }
+
+ self.reference = {
+ "volume": 10,
+ "height": 10,
+ "volume.REGION1": 10,
+ "height.REGION1": 4,
+ "volume.REGION2": 15,
+ "height.REGION2": 10,
+ }
+
+ def test_convert_to_cells_file(self):
+ cell_ids = [10, 11, 12, 13]
+ volumes = [[20], [20], [15, 5], [40, 10]]
+ heights = [5, 5, 5, 20]
+
+ samples = pd.concat(
+ [
+ (
+ make_samples(cell_id, volume, height, f"REGION{index + 1}")
+ if len(region_volumes) > 1
+ else make_samples(cell_id, volume, height, None)
+ )
+ for cell_id, height, region_volumes in zip(cell_ids, heights, volumes)
+ for index, volume in enumerate(region_volumes)
+ ]
+ )
+
+ reference = pd.DataFrame(
+ {
+ "ID": [11, 13],
+ "volume": [10, 10],
+ "height": [10, 10],
+ "volume.REGION1": [None, 10],
+ "height.REGION1": [None, 4],
+ "volume.REGION2": [None, 15],
+ "height.REGION2": [None, 10],
+ }
+ )
+
+ expected_cells = [
+ {
+ "id": 1,
+ "parent": 0,
+ "pop": 1,
+ "age": 0,
+ "divisions": 0,
+ "state": "STATE2",
+ "phase": "STATE2_PHASE1",
+ "voxels": np.sum(volumes[0]),
+ "criticals": [4, 10],
+ },
+ {
+ "id": 2,
+ "parent": 0,
+ "pop": 1,
+ "age": 0,
+ "divisions": 0,
+ "state": "STATE2",
+ "phase": "STATE2_PHASE1",
+ "voxels": np.sum(volumes[1]),
+ "criticals": [2, 20],
+ },
+ {
+ "id": 3,
+ "parent": 0,
+ "pop": 1,
+ "age": 0,
+ "divisions": 0,
+ "state": "STATE2",
+ "phase": "STATE2_PHASE1",
+ "voxels": np.sum(volumes[2]),
+ "criticals": [4, 10],
+ "regions": [
+ {"region": "REGION1", "voxels": volumes[2][0], "criticals": [6, 5]},
+ {"region": "REGION2", "voxels": volumes[2][1], "criticals": [4, 4]},
+ ],
+ },
+ {
+ "id": 4,
+ "parent": 0,
+ "pop": 1,
+ "age": 0,
+ "divisions": 0,
+ "state": "STATE2",
+ "phase": "STATE2_PHASE1",
+ "voxels": np.sum(volumes[3]),
+ "criticals": [2, 20],
+ "regions": [
+ {"region": "REGION1", "voxels": volumes[3][0], "criticals": [2, 2]},
+ {"region": "REGION2", "voxels": volumes[3][1], "criticals": [5, 8]},
+ ],
+ },
+ ]
+
+ cells = convert_to_cells_file(
+ samples,
+ reference,
+ self.volume_distributions,
+ self.height_distributions,
+ self.critical_volume_distributions,
+ self.critical_height_distributions,
+ self.state_thresholds,
+ )
+
+ self.assertCountEqual(expected_cells, cells)
+
+ def test_convert_to_cell_no_reference_no_region(self):
+ cell_id = 2
+ volume = 20
+ height = 5
+ samples = make_samples(cell_id, volume, height, None)
+ reference = {}
+
+ expected_cell = {
+ "id": cell_id,
+ "parent": 0,
+ "pop": 1,
+ "age": 0,
+ "divisions": 0,
+ "state": "STATE2",
+ "phase": "STATE2_PHASE1",
+ "voxels": volume,
+ "criticals": [4, 10],
+ }
+
+ cell = convert_to_cell(
+ cell_id,
+ samples,
+ reference,
+ self.volume_distributions,
+ self.height_distributions,
+ self.critical_volume_distributions,
+ self.critical_height_distributions,
+ self.state_thresholds,
+ )
+
+ self.assertDictEqual(expected_cell, cell)
+
+ def test_convert_to_cell_with_reference_no_region(self):
+ cell_id = 2
+ volume = 20
+ height = 5
+ samples = make_samples(cell_id, volume, height, None)
+ reference = {"volume": 10, "height": 10}
+
+ expected_cell = {
+ "id": cell_id,
+ "parent": 0,
+ "pop": 1,
+ "age": 0,
+ "divisions": 0,
+ "state": "STATE2",
+ "phase": "STATE2_PHASE1",
+ "voxels": volume,
+ "criticals": [2, 20],
+ }
+
+ cell = convert_to_cell(
+ cell_id,
+ samples,
+ reference,
+ self.volume_distributions,
+ self.height_distributions,
+ self.critical_volume_distributions,
+ self.critical_height_distributions,
+ self.state_thresholds,
+ )
+
+ self.assertDictEqual(expected_cell, cell)
+
+ def test_convert_to_cell_no_reference_with_region(self):
+ cell_id = 2
+ volumes = [15, 5]
+ height = 5
+ samples = pd.concat(
+ [
+ make_samples(cell_id, volume, height, f"REGION{index + 1}")
+ for index, volume in enumerate(volumes)
+ ]
+ )
+ reference = {}
+
+ expected_cell = {
+ "id": cell_id,
+ "parent": 0,
+ "pop": 1,
+ "age": 0,
+ "divisions": 0,
+ "state": "STATE2",
+ "phase": "STATE2_PHASE1",
+ "voxels": np.sum(volumes),
+ "criticals": [4, 10],
+ "regions": [
+ {"region": "REGION1", "voxels": volumes[0], "criticals": [6, 5]},
+ {"region": "REGION2", "voxels": volumes[1], "criticals": [4, 4]},
+ ],
+ }
+
+ cell = convert_to_cell(
+ cell_id,
+ samples,
+ reference,
+ self.volume_distributions,
+ self.height_distributions,
+ self.critical_volume_distributions,
+ self.critical_height_distributions,
+ self.state_thresholds,
+ )
+
+ self.assertDictEqual(expected_cell, cell)
+
+ def test_convert_to_cell_with_reference_with_region(self):
+ cell_id = 2
+ volumes = [15, 5]
+ height = 5
+ samples = pd.concat(
+ [
+ make_samples(cell_id, volume, height, f"REGION{index + 1}")
+ for index, volume in enumerate(volumes)
+ ]
+ )
+ reference = {
+ "volume": 10,
+ "height": 10,
+ "volume.REGION1": 10,
+ "height.REGION1": 4,
+ "volume.REGION2": 15,
+ "height.REGION2": 10,
+ }
+
+ expected_cell = {
+ "id": cell_id,
+ "parent": 0,
+ "pop": 1,
+ "age": 0,
+ "divisions": 0,
+ "state": "STATE2",
+ "phase": "STATE2_PHASE1",
+ "voxels": np.sum(volumes),
+ "criticals": [2, 20],
+ "regions": [
+ {"region": "REGION1", "voxels": volumes[0], "criticals": [2, 2]},
+ {"region": "REGION2", "voxels": volumes[1], "criticals": [5, 8]},
+ ],
+ }
+
+ cell = convert_to_cell(
+ cell_id,
+ samples,
+ reference,
+ self.volume_distributions,
+ self.height_distributions,
+ self.critical_volume_distributions,
+ self.critical_height_distributions,
+ self.state_thresholds,
+ )
+
+ self.assertDictEqual(expected_cell, cell)
+
+ def test_convert_to_cell_region_no_reference(self):
+ volume = 20
+ height = 5
+ region_samples = pd.DataFrame({"z": np.linspace(0, height, volume)})
+ reference = {}
+
+ expected_cell_region = {
+ "region": DEFAULT_REGION_NAME,
+ "voxels": volume,
+ "criticals": [4, 10],
+ }
+
+ cell_region = convert_to_cell_region(
+ DEFAULT_REGION_NAME,
+ region_samples,
+ reference,
+ self.volume_distributions,
+ self.height_distributions,
+ self.critical_volume_distributions,
+ self.critical_height_distributions,
+ )
+
+ self.assertDictEqual(expected_cell_region, cell_region)
+
+ def test_convert_to_cell_region_with_reference(self):
+ volume = 20
+ height = 5
+ region_samples = pd.DataFrame({"z": np.linspace(0, height, volume)})
+ reference = {f"volume.{DEFAULT_REGION_NAME}": 10, f"height.{DEFAULT_REGION_NAME}": 10}
+
+ expected_cell_region = {
+ "region": DEFAULT_REGION_NAME,
+ "voxels": volume,
+ "criticals": [2, 20],
+ }
+
+ cell_region = convert_to_cell_region(
+ DEFAULT_REGION_NAME,
+ region_samples,
+ reference,
+ self.volume_distributions,
+ self.height_distributions,
+ self.critical_volume_distributions,
+ self.critical_height_distributions,
+ )
+
+ self.assertDictEqual(expected_cell_region, cell_region)
+
+ def test_get_cell_state(self):
+ critical_volume = 10
+ threshold_fractions = {
+ "STATE_A": 0.2,
+ "STATE_B": 1.5,
+ "STATE_C": 2,
+ }
+
+ threshold_a = threshold_fractions["STATE_A"] * critical_volume
+ threshold_b = threshold_fractions["STATE_B"] * critical_volume
+ threshold_c = threshold_fractions["STATE_C"] * critical_volume
+
+ parameters = [
+ (threshold_a - EPSILON, "STATE_A"), # below A threshold
+ (threshold_a, "STATE_B"), # equal A threshold
+ ((threshold_a + threshold_b) / 2, "STATE_B"), # between A and B thresholds
+ (threshold_b, "STATE_C"), # equal B threshold
+ ((threshold_b + threshold_c) / 2, "STATE_C"), # between B and C thresholds
+ (threshold_c, "STATE_C"), # equal C threshold
+ (threshold_c + EPSILON, "STATE_C"), # above C threshold
+ ]
+
+ for volume, expected_phase in parameters:
+ with self.subTest(volume=volume, expected_phase=expected_phase):
+ phase = get_cell_state(volume, critical_volume, threshold_fractions)
+ self.assertEqual(expected_phase, phase)
+
+ def test_convert_value_distribution(self):
+ source_distribution = (10, 6)
+ target_distribution = (2, 0.6)
+
+ parameters = [
+ (10, 2), # means
+ (4, 1.4), # one standard deviation below
+ (16, 2.6), # one standard deviation above
+ (7, 1.7), # half standard deviation below
+ (13, 2.3), # half standard deviation above
+ ]
+
+ for source_value, expected_target_value in parameters:
+ with self.subTest(source_value=source_value):
+ target_value = convert_value_distribution(
+ source_value, source_distribution, target_distribution
+ )
+ self.assertEqual(expected_target_value, target_value)
+
+ def test_filter_cell_reference_cell_exists(self):
+ cell_ids = [1, 2, 3]
+ feature_a = [10, 20, 30]
+ feature_b = ["a", "b", "c"]
+ cell_id = 2
+ index = cell_ids.index(cell_id)
+
+ reference = pd.DataFrame({"ID": cell_ids, "FEATURE_A": feature_a, "FEATURE_B": feature_b})
+
+ expected_cell_reference = {
+ "ID": cell_id,
+ "FEATURE_A": feature_a[index],
+ "FEATURE_B": feature_b[index],
+ }
+
+ cell_reference = filter_cell_reference(cell_id, reference)
+
+ self.assertDictEqual(expected_cell_reference, cell_reference)
+
+ def test_filter_cell_reference_cell_does_not_exist(self):
+ cell_ids = [1, 2, 3]
+ feature_a = [10, 20, 30]
+ feature_b = ["a", "b", "c"]
+ cell_id = 4
+
+ reference = pd.DataFrame({"ID": cell_ids, "FEATURE_A": feature_a, "FEATURE_B": feature_b})
+
+ cell_reference = filter_cell_reference(cell_id, reference)
+
+ self.assertDictEqual({}, cell_reference)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/arcade_collection/input/test_convert_to_locations_file.py b/tests/arcade_collection/input/test_convert_to_locations_file.py
new file mode 100644
index 0000000..e690c9b
--- /dev/null
+++ b/tests/arcade_collection/input/test_convert_to_locations_file.py
@@ -0,0 +1,195 @@
+import unittest
+
+import pandas as pd
+
+from arcade_collection.input.convert_to_locations_file import (
+ convert_to_location,
+ convert_to_locations_file,
+ get_center_voxel,
+ get_location_voxels,
+)
+
+
+class TestConvertToLocationsFile(unittest.TestCase):
+ def test_convert_to_locations_file(self):
+ samples = pd.DataFrame(
+ {
+ "id": [10, 10, 10, 10, 10, 11, 11, 11, 11],
+ "x": [0, 1, 1, 2, 2, 30, 31, 31, 32],
+ "y": [3, 3, 4, 5, 5, 40, 42, 42, 44],
+ "z": [6, 6, 7, 7, 8, 50, 51, 52, 52],
+ "region": [None, None, None, None, None, "A", "B", "A", "B"],
+ }
+ )
+
+ expected_locations = [
+ {
+ "id": 1,
+ "center": (1, 4, 6),
+ "location": [
+ {
+ "region": "UNDEFINED",
+ "voxels": [
+ (0, 3, 6),
+ (1, 3, 6),
+ (1, 4, 7),
+ (2, 5, 7),
+ (2, 5, 8),
+ ],
+ }
+ ],
+ },
+ {
+ "id": 2,
+ "center": (31, 42, 51),
+ "location": [
+ {
+ "region": "A",
+ "voxels": [
+ (30, 40, 50),
+ (31, 42, 52),
+ ],
+ },
+ {
+ "region": "B",
+ "voxels": [
+ (31, 42, 51),
+ (32, 44, 52),
+ ],
+ },
+ ],
+ },
+ ]
+
+ locations = convert_to_locations_file(samples)
+
+ self.assertCountEqual(expected_locations, locations)
+
+ def test_convert_to_location_no_region(self):
+ cell_id = 2
+ samples = pd.DataFrame(
+ {
+ "x": [0, 1, 1, 2, 2],
+ "y": [3, 3, 4, 5, 5],
+ "z": [6, 6, 7, 7, 8],
+ }
+ )
+ center = (1, 4, 6)
+
+ expected_location = {
+ "id": cell_id,
+ "center": center,
+ "location": [
+ {
+ "region": "UNDEFINED",
+ "voxels": [
+ (0, 3, 6),
+ (1, 3, 6),
+ (1, 4, 7),
+ (2, 5, 7),
+ (2, 5, 8),
+ ],
+ }
+ ],
+ }
+
+ location = convert_to_location(cell_id, samples)
+
+ self.assertDictEqual(expected_location, location)
+
+ def test_convert_to_location_with_region(self):
+ cell_id = 2
+ samples = pd.DataFrame(
+ {
+ "x": [0, 1, 1, 2, 2],
+ "y": [3, 3, 4, 5, 5],
+ "z": [6, 6, 7, 7, 8],
+ "region": ["A", "B", "A", "B", "A"],
+ }
+ )
+ center = (1, 4, 6)
+
+ expected_location = {
+ "id": cell_id,
+ "center": center,
+ "location": [
+ {
+ "region": "A",
+ "voxels": [
+ (0, 3, 6),
+ (1, 4, 7),
+ (2, 5, 8),
+ ],
+ },
+ {
+ "region": "B",
+ "voxels": [
+ (1, 3, 6),
+ (2, 5, 7),
+ ],
+ },
+ ],
+ }
+
+ location = convert_to_location(cell_id, samples)
+
+ self.assertDictEqual(expected_location, location)
+
+ def test_get_center_voxel(self):
+ parameters = [
+ ([10, 12], [3, 5], [2, 4], (11, 4, 3)), # exact
+ ([10, 11], [3, 4], [2, 3], (10, 3, 2)), # rounded
+ ]
+
+ for x, y, z, expected_center in parameters:
+ with self.subTest(x=x, y=y, z=z):
+ samples = pd.DataFrame({"x": x, "y": y, "z": z})
+ center = get_center_voxel(samples)
+ self.assertTupleEqual(expected_center, center)
+
+ def test_get_location_voxels_no_region(self):
+ region = None
+ samples = pd.DataFrame(
+ {
+ "x": [0, 1, 1, 2, 2],
+ "y": [3, 3, 4, 5, 5],
+ "z": [6, 6, 7, 7, 8],
+ }
+ )
+
+ expected_voxels = [
+ (0, 3, 6),
+ (1, 3, 6),
+ (1, 4, 7),
+ (2, 5, 7),
+ (2, 5, 8),
+ ]
+
+ voxels = get_location_voxels(samples, region)
+
+ self.assertCountEqual(expected_voxels, voxels)
+
+ def test_get_location_voxels_with_region(self):
+ region = "A"
+ samples = pd.DataFrame(
+ {
+ "x": [0, 1, 1, 2, 2],
+ "y": [3, 3, 4, 5, 5],
+ "z": [6, 6, 7, 7, 8],
+ "region": ["A", "B", "A", "B", "A"],
+ }
+ )
+
+ expected_voxels = [
+ (0, 3, 6),
+ (1, 4, 7),
+ (2, 5, 8),
+ ]
+
+ voxels = get_location_voxels(samples, region)
+
+ self.assertCountEqual(expected_voxels, voxels)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/arcade_collection/input/test_generate_setup_file.py b/tests/arcade_collection/input/test_generate_setup_file.py
new file mode 100644
index 0000000..951488b
--- /dev/null
+++ b/tests/arcade_collection/input/test_generate_setup_file.py
@@ -0,0 +1,147 @@
+import unittest
+
+import pandas as pd
+
+from arcade_collection.input.generate_setup_file import (
+ DEFAULT_POPULATION_ID,
+ calculate_sample_bounds,
+ generate_setup_file,
+ make_setup_file,
+)
+
+
+class TestGenerateSetupFile(unittest.TestCase):
+ def setUp(self):
+ self.terms = ["term_a", "term_b", "term_c"]
+ self.margins = [10, 20, 30]
+
+ self.setup_template_no_region = (
+ "\n"
+ ' \n'
+ " \n"
+ ' \n'
+ ' \n'
+ ' \n'
+ " \n"
+ " \n"
+ " \n"
+ f' \n'
+ " \n"
+ " \n"
+ " \n"
+ ""
+ )
+
+ self.setup_template_with_region = (
+ "\n"
+ ' \n'
+ " \n"
+ ' \n'
+ ' \n'
+ ' \n'
+ " \n"
+ " \n"
+ " \n"
+ f' \n'
+ f' \n'
+ f' \n'
+ " \n"
+ " \n"
+ " \n"
+ " \n"
+ ""
+ )
+
+ def test_generate_setup_file_no_regions(self):
+ samples = pd.DataFrame(
+ {
+ "id": [1, 2],
+ "x": [0, 2],
+ "y": [3, 7],
+ "z": [6, 7],
+ }
+ )
+
+ expected_setup = self.setup_template_no_region % (25, 47, 64, 2)
+
+ setup = generate_setup_file(samples, self.margins, self.terms)
+
+ self.assertEqual(expected_setup, setup)
+
+ def test_generate_setup_file_invalid_regions(self):
+ samples = pd.DataFrame(
+ {
+ "id": [1, 2],
+ "x": [0, 2],
+ "y": [3, 7],
+ "z": [6, 7],
+ "region": [None, None],
+ }
+ )
+
+ expected_setup = self.setup_template_no_region % (25, 47, 64, 2)
+
+ setup = generate_setup_file(samples, self.margins, self.terms)
+
+ self.assertEqual(expected_setup, setup)
+
+ def test_generate_setup_file_with_regions(self):
+ samples = pd.DataFrame(
+ {
+ "id": [1, 2],
+ "x": [0, 2],
+ "y": [3, 7],
+ "z": [6, 7],
+ "region": ["A", "B"],
+ }
+ )
+
+ expected_setup = self.setup_template_with_region % (25, 47, 64, 2, "A", "B")
+
+ setup = generate_setup_file(samples, self.margins, self.terms)
+
+ self.assertEqual(expected_setup, setup)
+
+ def test_calculate_sample_bounds(self):
+ samples = pd.DataFrame(
+ {
+ "x": [0, 2],
+ "y": [3, 7],
+ "z": [6, 7],
+ }
+ )
+ margins = [10, 20, 30]
+
+ expected_bounds = (25, 47, 64)
+
+ bounds = calculate_sample_bounds(samples, margins)
+
+ self.assertTupleEqual(expected_bounds, bounds)
+
+ def test_make_setup_file_no_regions(self):
+ init = 100
+ bounds = (10, 20, 30)
+ regions = None
+
+ expected_setup = self.setup_template_no_region % (*bounds, init)
+
+ setup = make_setup_file(init, bounds, self.terms, regions)
+
+ self.assertEqual(expected_setup, setup)
+
+ def test_make_setup_file_with_regions(self):
+ init = 100
+ bounds = (10, 20, 30)
+ regions = ["REGION_A", "REGION_B"]
+
+ expected_setup = self.setup_template_with_region % (*bounds, init, regions[0], regions[1])
+
+ setup = make_setup_file(init, bounds, self.terms, regions)
+
+ self.assertEqual(expected_setup, setup)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/arcade_collection/input/test_group_template_conditions.py b/tests/arcade_collection/input/test_group_template_conditions.py
new file mode 100644
index 0000000..6600256
--- /dev/null
+++ b/tests/arcade_collection/input/test_group_template_conditions.py
@@ -0,0 +1,126 @@
+import unittest
+
+from arcade_collection.input.group_template_conditions import (
+ find_seed_ranges,
+ group_condition_sets,
+ group_seed_ranges,
+ group_template_conditions,
+)
+
+
+class TestGroupTemplateConditions(unittest.TestCase):
+ def test_group_template_conditions(self):
+ max_seeds = 3
+ conditions = [
+ {"key": "A", "seed": 1},
+ {"key": "A", "seed": 2},
+ {"key": "A", "seed": 4},
+ {"key": "B", "seed": 1},
+ {"key": "B", "seed": 2},
+ {"key": "B", "seed": 3},
+ {"key": "B", "seed": 4},
+ ]
+
+ expected_groups = [
+ {
+ "conditions": [
+ {"key": "A", "start_seed": 1, "end_seed": 2},
+ {"key": "A", "start_seed": 4, "end_seed": 4},
+ ]
+ },
+ {
+ "conditions": [
+ {"key": "B", "start_seed": 1, "end_seed": 3},
+ ]
+ },
+ {
+ "conditions": [
+ {"key": "B", "start_seed": 4, "end_seed": 4},
+ ]
+ },
+ ]
+
+ groups = group_template_conditions(conditions, max_seeds)
+
+ self.assertCountEqual(expected_groups, groups)
+
+ def test_find_seed_ranges_continuous(self):
+ seeds = [0, 1, 2, 3]
+ parameters = [
+ (2, [(0, 1), (2, 3)]), # below max, equal
+ (3, [(0, 2), (3, 3)]), # below max, unequal
+ (4, [(0, 3)]), # equal to max
+ (5, [(0, 3)]), # above max
+ ]
+
+ for max_seeds, expected_groups in parameters:
+ with self.subTest(max_seeds=max_seeds):
+ groups = find_seed_ranges(seeds, max_seeds)
+ self.assertCountEqual(expected_groups, groups)
+
+ def test_find_seed_ranges_discontinuous(self):
+ seeds = [0, 1, 2, 3, 5, 6, 7, 8]
+ parameters = [
+ (2, [(0, 1), (2, 3), (5, 6), (7, 8)]), # below max, equal
+ (3, [(0, 2), (3, 3), (5, 7), (8, 8)]), # below max, unequal
+ (4, [(0, 3), (5, 8)]), # equal to max
+ (5, [(0, 3), (5, 8)]), # above max
+ ]
+
+ for max_seeds, expected_groups in parameters:
+ with self.subTest(max_seeds=max_seeds):
+ groups = find_seed_ranges(seeds, max_seeds)
+ self.assertCountEqual(expected_groups, groups)
+
+ def test_group_seed_ranges(self):
+ max_seeds = 2
+ conditions = [
+ {"key": "A", "seed": 1},
+ {"key": "A", "seed": 2},
+ {"key": "A", "seed": 3},
+ {"key": "B", "seed": 1},
+ {"key": "B", "seed": 2},
+ {"key": "B", "seed": 3},
+ {"key": "B", "seed": 4},
+ ]
+
+ expected_groups = [
+ {"key": "A", "start_seed": 1, "end_seed": 2},
+ {"key": "A", "start_seed": 3, "end_seed": 3},
+ {"key": "B", "start_seed": 1, "end_seed": 2},
+ {"key": "B", "start_seed": 3, "end_seed": 4},
+ ]
+
+ groups = group_seed_ranges(conditions, max_seeds)
+
+ self.assertCountEqual(expected_groups, groups)
+
+ def test_group_condition_sets(self):
+ max_seeds = 3
+ conditions = [
+ {"key": "A", "start_seed": 1, "end_seed": 2},
+ {"key": "A", "start_seed": 3, "end_seed": 3},
+ {"key": "B", "start_seed": 1, "end_seed": 2},
+ {"key": "B", "start_seed": 3, "end_seed": 4},
+ ]
+
+ expected_groups = [
+ [
+ {"key": "A", "start_seed": 1, "end_seed": 2},
+ {"key": "A", "start_seed": 3, "end_seed": 3},
+ ],
+ [
+ {"key": "B", "start_seed": 1, "end_seed": 2},
+ ],
+ [
+ {"key": "B", "start_seed": 3, "end_seed": 4},
+ ],
+ ]
+
+ groups = group_condition_sets(conditions, max_seeds)
+
+ self.assertCountEqual(expected_groups, groups)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/arcade_collection/input/test_merge_region_samples.py b/tests/arcade_collection/input/test_merge_region_samples.py
new file mode 100644
index 0000000..391a422
--- /dev/null
+++ b/tests/arcade_collection/input/test_merge_region_samples.py
@@ -0,0 +1,215 @@
+import unittest
+
+import pandas as pd
+
+from arcade_collection.input.merge_region_samples import (
+ filter_valid_samples,
+ merge_region_samples,
+ transform_sample_coordinates,
+)
+
+
+class TestMergeRegionSamples(unittest.TestCase):
+ def test_merge_region_samples_no_regions(self):
+ samples = {
+ "DEFAULT": pd.DataFrame(
+ {
+ "id": [1, 1, 1, 2, 2],
+ "x": [0, 1, 1, 2, 2],
+ "y": [3, 3, 4, 5, 5],
+ "z": [6, 6, 7, 7, 8],
+ }
+ ),
+ }
+ margins = (10, 20, 30)
+
+ expected_merged = pd.DataFrame(
+ {
+ "id": [1, 1, 1, 2, 2],
+ "x": [11, 12, 12, 13, 13],
+ "y": [21, 21, 22, 23, 23],
+ "z": [31, 31, 32, 32, 33],
+ }
+ )
+
+ merged = merge_region_samples(samples, margins)
+
+ self.assertTrue(expected_merged.equals(merged))
+
+ def test_merge_region_samples_with_regions_no_fill(self):
+ samples = {
+ "DEFAULT": pd.DataFrame(
+ {
+ "id": [1, 1, 1, 2, 2],
+ "x": [0, 1, 1, 2, 2],
+ "y": [3, 3, 4, 5, 5],
+ "z": [6, 6, 7, 7, 8],
+ }
+ ),
+ "REGION_A": pd.DataFrame(
+ {"id": [1, 1, 2], "x": [0, 1, 2], "y": [3, 4, 5], "z": [6, 7, 7]}
+ ),
+ "REGION_B": pd.DataFrame({"id": [1, 2], "x": [1, 2], "y": [3, 5], "z": [6, 8]}),
+ }
+ margins = (10, 20, 30)
+
+ expected_merged = pd.DataFrame(
+ {
+ "id": [1, 1, 1, 2, 2],
+ "x": [11, 12, 12, 13, 13],
+ "y": [21, 21, 22, 23, 23],
+ "z": [31, 31, 32, 32, 33],
+ "region": ["REGION_A", "REGION_B", "REGION_A", "REGION_A", "REGION_B"],
+ }
+ )
+
+ merged = merge_region_samples(samples, margins)
+
+ self.assertTrue(expected_merged.equals(merged))
+
+ def test_merge_region_samples_with_regions_with_fill(self):
+ samples = {
+ "DEFAULT": pd.DataFrame(
+ {
+ "id": [1, 1, 1, 2, 2],
+ "x": [0, 1, 1, 2, 2],
+ "y": [3, 3, 4, 5, 5],
+ "z": [6, 6, 7, 7, 8],
+ }
+ ),
+ "REGION_A": pd.DataFrame(
+ {"id": [1, 1, 2], "x": [0, 1, 2], "y": [3, 4, 5], "z": [6, 7, 7]}
+ ),
+ }
+ margins = (10, 20, 30)
+
+ expected_merged = pd.DataFrame(
+ {
+ "id": [1, 1, 1, 2, 2],
+ "x": [11, 12, 12, 13, 13],
+ "y": [21, 21, 22, 23, 23],
+ "z": [31, 31, 32, 32, 33],
+ "region": ["REGION_A", "DEFAULT", "REGION_A", "REGION_A", "DEFAULT"],
+ }
+ )
+
+ merged = merge_region_samples(samples, margins)
+
+ self.assertTrue(expected_merged.equals(merged))
+
+ def test_transform_sample_coordinates_no_reference(self):
+ samples = pd.DataFrame(
+ {
+ "id": [1, 1, 1, 2, 2],
+ "x": [0, 1, 1, 2, 2],
+ "y": [3, 3, 4, 5, 5],
+ "z": [6, 6, 7, 7, 8],
+ }
+ )
+ margins = (10, 20, 30)
+ reference = None
+
+ expected_transformed = pd.DataFrame(
+ {
+ "id": [1, 1, 1, 2, 2],
+ "x": [11, 12, 12, 13, 13],
+ "y": [21, 21, 22, 23, 23],
+ "z": [31, 31, 32, 32, 33],
+ }
+ )
+
+ transformed = transform_sample_coordinates(samples, margins, reference)
+
+ self.assertTrue(expected_transformed.equals(transformed))
+
+ def test_transform_sample_coordinates_with_reference(self):
+ samples = pd.DataFrame(
+ {
+ "id": [1, 1, 1, 2, 2],
+ "x": [0, 1, 1, 2, 2],
+ "y": [3, 3, 4, 5, 5],
+ "z": [6, 6, 7, 7, 8],
+ }
+ )
+ margins = (10, 20, 30)
+ reference = pd.DataFrame(
+ {
+ "x": [0],
+ "y": [1],
+ "z": [2],
+ }
+ )
+
+ expected_transformed = pd.DataFrame(
+ {
+ "id": [1, 1, 1, 2, 2],
+ "x": [11, 12, 12, 13, 13],
+ "y": [23, 23, 24, 25, 25],
+ "z": [35, 35, 36, 36, 37],
+ }
+ )
+
+ transformed = transform_sample_coordinates(samples, margins, reference)
+
+ self.assertTrue(expected_transformed.equals(transformed))
+
+ def test_filter_valid_samples_no_region_all_valid(self):
+ samples = pd.DataFrame(
+ {
+ "x": [0, 1, 1, 2, 2],
+ "y": [3, 3, 4, 5, 5],
+ "z": [6, 6, 7, 7, 8],
+ }
+ )
+
+ expected_filtered = samples.copy()
+
+ filtered = filter_valid_samples(samples)
+
+ self.assertTrue(expected_filtered.equals(filtered))
+
+ def test_filter_valid_samples_with_region_all_valid(self):
+ samples = pd.DataFrame(
+ {
+ "id": [1, 1, 1, 2, 2],
+ "x": [0, 1, 1, 2, 2],
+ "y": [3, 3, 4, 5, 5],
+ "z": [6, 6, 7, 7, 8],
+ "region": ["A", "B", "B", "A", "B"],
+ }
+ )
+
+ expected_filtered = samples.copy()
+
+ filtered = filter_valid_samples(samples)
+
+ self.assertTrue(expected_filtered.equals(filtered))
+
+ def test_filter_valid_samples_sample_outside_region(self):
+ samples = pd.DataFrame(
+ {
+ "id": [1, 1, 1, 2, 2],
+ "x": [0, 1, 1, 2, 2],
+ "y": [3, 3, 4, 5, 5],
+ "z": [6, 6, 7, 7, 8],
+ "region": ["A", "B", "B", "A", "A"],
+ }
+ )
+
+ expected_filtered = pd.DataFrame(
+ {
+ "id": [1, 1, 1],
+ "x": [0, 1, 1],
+ "y": [3, 3, 4],
+ "z": [6, 6, 7],
+ "region": ["A", "B", "B"],
+ }
+ )
+
+ filtered = filter_valid_samples(samples)
+
+ self.assertTrue(expected_filtered.equals(filtered))
+
+
+if __name__ == "__main__":
+ unittest.main()