From e1e4260e4d9b8e780a3589a4edde7275be703d14 Mon Sep 17 00:00:00 2001
From: jessicasyu <15913767+jessicasyu@users.noreply.github.com>
Date: Tue, 3 Sep 2024 14:49:46 -0400
Subject: [PATCH] Add docstrings and unit tests for convert to cells file task

---
 .../input/convert_to_cells_file.py            | 125 ++++-
 .../input/test_convert_to_cells_file.py       | 442 ++++++++++++++++++
 2 files changed, 549 insertions(+), 18 deletions(-)
 create mode 100644 tests/arcade_collection/input/test_convert_to_cells_file.py

diff --git a/src/arcade_collection/input/convert_to_cells_file.py b/src/arcade_collection/input/convert_to_cells_file.py
index b3261d3..5c75155 100644
--- a/src/arcade_collection/input/convert_to_cells_file.py
+++ b/src/arcade_collection/input/convert_to_cells_file.py
@@ -10,6 +10,44 @@ def convert_to_cells_file(
     critical_height_distributions: dict[str, tuple[float, float]],
     state_thresholds: dict[str, float],
 ) -> list[dict]:
+    """
+    Convert all samples to cell objects.
+
+    For each cell id in samples, current volume and height are rescaled to
+    critical volume and critical height based on distribution means and standard
+    deviations. If reference volume and/or height exist for the cell id, those
+    values are used as the current values to be rescaled. Otherwise, current
+    volume is calculated from the number of voxel samples and current height is
+    calculated from the range of voxel coordinates along the z axis.
+
+    Initial cell state and cell state phase for each cell are estimated based on
+    state thresholds, the current cell volume, and the critical cell volume.
+
+    Cell object ids are reindexed starting with cell id 1.
+
+    Parameters
+    ----------
+    samples
+        Sample cell ids and coordinates.
+    reference
+        Reference values for volumes and heights.
+    volume_distributions
+        Map of volume means and standard deviations.
+    height_distributions
+        Map of height means and standard deviations.
+    critical_volume_distributions
+        Map of critical volume means and standard deviations.
+    critical_height_distributions
+        Map of critical height means and standard deviations.
+    state_thresholds
+        Critical volume fractions defining threshold between states.
+
+    Returns
+    -------
+    :
+        List of cell objects formatted for ARCADE.
+    """
+
     cells: list[dict] = []
     samples_by_id = samples.groupby("id")
 
@@ -42,30 +80,43 @@ def convert_to_cell(
     state_thresholds: dict[str, float],
 ) -> dict:
     """
-    Convert samples to ARCADE .CELLS json format.
+    Convert samples to cell object.
+
+    Current volume and height are rescaled to critical volume and critical
+    height based on the "DEFAULT" distribution means and standard deviations.
+    If reference "volume" and/or "height" values are provided, those values are
+    used as the current values to be rescaled. Otherwise, current volume is
+    calculated from the number of voxel samples and current height is calculated
+    from the range of voxel coordinates along the z axis.
+
+    Initial cell state and cell state phase are estimated based on state
+    thresholds, the current cell volume, and the critical cell volume.
 
     Parameters
     ----------
     cell_id
         Unique cell id.
     samples
-        Sample cell ids and coordinates.
+        Sample coordinates for a single object.
     reference
-        Reference data for conversion.
-    volume_distribution
-        Average and standard deviation of volume distributions.
-    height_distribution
-        Average and standard deviation of height distributions.
-    critical_volume_distribution
-        Average and standard deviation of critical volume distributions.
-    critical_height_distribution
-        Average and standard deviation of critical height distributions.
+        Reference data for cell.
+    volume_distributions
+        Map of volume means and standard deviations.
+    height_distributions
+        Map of height means and standard deviations.
+    critical_volume_distributions
+        Map of critical volume means and standard deviations.
+    critical_height_distributions
+        Map of critical height means and standard deviations.
+    state_thresholds
+        Critical volume fractions defining threshold between states.
 
     Returns
     -------
     :
-        Dictionary in ARCADE .CELLS json format.
+        Cell object formatted for ARCADE.
     """
+
     volume = len(samples)
     height = samples.z.max() - samples.z.min()
 
@@ -95,7 +146,7 @@ def convert_to_cell(
         "criticals": [critical_volume, critical_height],
     }
 
-    if "region" in samples.columns:
+    if "region" in samples.columns and not samples["region"].isnull().all():
         regions = [
             convert_to_cell_region(
                 region,
@@ -122,6 +173,39 @@ def convert_to_cell_region(
     critical_volume_distributions: dict[str, tuple[float, float]],
     critical_height_distributions: dict[str, tuple[float, float]],
 ) -> dict:
+    """
+    Convert region samples to cell region object.
+
+    Current region volume and height are rescaled to critical volume and
+    critical height based on distribution means and standard deviations. If
+    reference region volume and/or height are provided, those values are used as
+    the current values to be rescaled. Otherwise, current region volume is
+    calculated from the number of voxel samples and current region height is
+    calculated from the range of voxel coordinates along the z axis.
+
+    Parameters
+    ----------
+    region
+        Region name.
+    region_samples
+        Sample coordinates for region of a single object.
+    reference
+        Reference data for cell region.
+    volume_distributions
+        Map of volume means and standard deviations.
+    height_distributions
+        Map of height means and standard deviations.
+    critical_volume_distributions
+        Map of critical volume means and standard deviations.
+    critical_height_distributions
+        Map of critical height means and standard deviations.
+
+    Returns
+    -------
+    :
+        Cell region object formatted for ARCADE.
+    """
+
     region_volume = len(region_samples)
     region_height = region_samples.z.max() - region_samples.z.min()
 
@@ -152,12 +236,14 @@ def get_cell_state(
     """
     Estimates cell state based on cell volume.
 
-    The threshold fractions dictionary defines the monotonic thresholds
-    between different cell states.
-    For a given volume v, critical volume V, and states X1, X2, ..., XN with
-    corresponding, monotonic threshold fractions f1, f2, ..., fN, a cell is
-    assigned state Xi such that [f(i - 1) * V] <= v < [fi * V].
+    The threshold fractions dictionary defines the monotonic thresholds between
+    different cell states. For a given volume v, critical volume V, and states
+    X1, X2, ..., XN with corresponding, monotonic threshold fractions f1, f2,
+    ..., fN, a cell is assigned state Xi such that [f(i - 1) * V] <= v < [fi *
+    V].
+
     Cells with v < f1 * V are assigned state X1.
+
     Cells with v > fN * V are assigned state XN.
 
     Parameters
@@ -174,6 +260,7 @@ def get_cell_state(
     :
         Cell state.
     """
+
     thresholds = [fraction * critical_volume for fraction in threshold_fractions.values()]
     states = list(threshold_fractions.keys())
 
@@ -203,6 +290,7 @@ def convert_value_distribution(
     :
         Estimated critical value.
     """
+
     source_avg, source_std = source_distribution
     target_avg, target_std = target_distribution
     z_scored_value = (value - source_avg) / source_std
@@ -226,6 +314,7 @@ def filter_cell_reference(cell_id: int, reference: pd.DataFrame) -> dict:
     :
         Reference data for given cell id.
     """
+
     cell_reference = reference[reference["ID"] == cell_id].squeeze()
     cell_reference = cell_reference.to_dict() if not cell_reference.empty else {}
     return cell_reference
diff --git a/tests/arcade_collection/input/test_convert_to_cells_file.py b/tests/arcade_collection/input/test_convert_to_cells_file.py
new file mode 100644
index 0000000..f029d2a
--- /dev/null
+++ b/tests/arcade_collection/input/test_convert_to_cells_file.py
@@ -0,0 +1,442 @@
+import unittest
+
+import numpy as np
+import pandas as pd
+
+from arcade_collection.input.convert_to_cells_file import (
+    convert_to_cell,
+    convert_to_cell_region,
+    convert_to_cells_file,
+    convert_value_distribution,
+    filter_cell_reference,
+    get_cell_state,
+)
+
+EPSILON = 1e-10
+DEFAULT_REGION_NAME = "DEFAULT"
+
+
+def make_samples(cell_id, volume, height, region):
+    """Build a samples frame with `volume` rows whose z values span 0 to `height`."""
+    columns = {
+        "id": [cell_id] * volume,
+        "z": np.linspace(0, height, volume),
+        "region": [region] * volume,
+    }
+    return pd.DataFrame(columns)
+
+
+class TestConvertToCellsFile(unittest.TestCase):
+    def setUp(self):
+        self.volume_distributions = {
+            DEFAULT_REGION_NAME: (10, 10),
+            "REGION1": (10, 5),
+            "REGION2": (25, 10),
+        }
+        self.height_distributions = {
+            DEFAULT_REGION_NAME: (5, 5),
+            "REGION1": (4, 1),
+            "REGION2": (10, 5),
+        }
+        self.critical_volume_distributions = {
+            DEFAULT_REGION_NAME: (2, 2),
+            "REGION1": (2, 4),
+            "REGION2": (6, 1),
+        }
+        self.critical_height_distributions = {
+            DEFAULT_REGION_NAME: (10, 10),
+            "REGION1": (2, 3),
+            "REGION2": (8, 4),
+        }
+        self.state_thresholds = {
+            "STATE1_PHASE1": 0.2,
+            "STATE1_PHASE2": 1.5,
+            "STATE2_PHASE1": 2,
+        }
+
+        self.reference = {
+            "volume": 10,
+            "height": 10,
+            "volume.REGION1": 10,
+            "height.REGION1": 4,
+            "volume.REGION2": 15,
+            "height.REGION2": 10,
+        }
+
+    def test_convert_to_cells_file(self):
+        cell_ids = [10, 11, 12, 13]
+        volumes = [[20], [20], [15, 5], [40, 10]]
+        heights = [5, 5, 5, 20]
+
+        samples = pd.concat(
+            [
+                (
+                    make_samples(cell_id, volume, height, f"REGION{index + 1}")
+                    if len(region_volumes) > 1
+                    else make_samples(cell_id, volume, height, None)
+                )
+                for cell_id, height, region_volumes in zip(cell_ids, heights, volumes)
+                for index, volume in enumerate(region_volumes)
+            ]
+        )
+
+        reference = pd.DataFrame(
+            {
+                "ID": [11, 13],
+                "volume": [10, 10],
+                "height": [10, 10],
+                "volume.REGION1": [None, 10],
+                "height.REGION1": [None, 4],
+                "volume.REGION2": [None, 15],
+                "height.REGION2": [None, 10],
+            }
+        )
+
+        expected_cells = [
+            {
+                "id": 1,
+                "parent": 0,
+                "pop": 1,
+                "age": 0,
+                "divisions": 0,
+                "state": "STATE2",
+                "phase": "STATE2_PHASE1",
+                "voxels": np.sum(volumes[0]),
+                "criticals": [4, 10],
+            },
+            {
+                "id": 2,
+                "parent": 0,
+                "pop": 1,
+                "age": 0,
+                "divisions": 0,
+                "state": "STATE2",
+                "phase": "STATE2_PHASE1",
+                "voxels": np.sum(volumes[1]),
+                "criticals": [2, 20],
+            },
+            {
+                "id": 3,
+                "parent": 0,
+                "pop": 1,
+                "age": 0,
+                "divisions": 0,
+                "state": "STATE2",
+                "phase": "STATE2_PHASE1",
+                "voxels": np.sum(volumes[2]),
+                "criticals": [4, 10],
+                "regions": [
+                    {"region": "REGION1", "voxels": volumes[2][0], "criticals": [6, 5]},
+                    {"region": "REGION2", "voxels": volumes[2][1], "criticals": [4, 4]},
+                ],
+            },
+            {
+                "id": 4,
+                "parent": 0,
+                "pop": 1,
+                "age": 0,
+                "divisions": 0,
+                "state": "STATE2",
+                "phase": "STATE2_PHASE1",
+                "voxels": np.sum(volumes[3]),
+                "criticals": [2, 20],
+                "regions": [
+                    {"region": "REGION1", "voxels": volumes[3][0], "criticals": [2, 2]},
+                    {"region": "REGION2", "voxels": volumes[3][1], "criticals": [5, 8]},
+                ],
+            },
+        ]
+
+        cells = convert_to_cells_file(
+            samples,
+            reference,
+            self.volume_distributions,
+            self.height_distributions,
+            self.critical_volume_distributions,
+            self.critical_height_distributions,
+            self.state_thresholds,
+        )
+
+        self.assertCountEqual(expected_cells, cells)
+
+    def test_convert_to_cell_no_reference_no_region(self):
+        cell_id = 2
+        volume = 20
+        height = 5
+        samples = make_samples(cell_id, volume, height, None)
+        reference = {}
+        # Sampled volume 20 is +1 sd -> critical 2 + 2 = 4; height 5 is the mean -> critical 10.
+        expected_cell = {
+            "id": cell_id,
+            "parent": 0,
+            "pop": 1,
+            "age": 0,
+            "divisions": 0,
+            "state": "STATE2",
+            "phase": "STATE2_PHASE1",
+            "voxels": volume,
+            "criticals": [4, 10],
+        }
+
+        cell = convert_to_cell(
+            cell_id,
+            samples,
+            reference,
+            self.volume_distributions,
+            self.height_distributions,
+            self.critical_volume_distributions,
+            self.critical_height_distributions,
+            self.state_thresholds,
+        )
+
+        self.assertDictEqual(expected_cell, cell)
+
+    def test_convert_to_cell_with_reference_no_region(self):
+        cell_id = 2
+        volume = 20
+        height = 5
+        samples = make_samples(cell_id, volume, height, None)
+        reference = {"volume": 10, "height": 10}
+        # Reference volume 10 is the mean -> critical 2; height 10 is +1 sd -> critical 20.
+        expected_cell = {
+            "id": cell_id,
+            "parent": 0,
+            "pop": 1,
+            "age": 0,
+            "divisions": 0,
+            "state": "STATE2",
+            "phase": "STATE2_PHASE1",
+            "voxels": volume,
+            "criticals": [2, 20],
+        }
+
+        cell = convert_to_cell(
+            cell_id,
+            samples,
+            reference,
+            self.volume_distributions,
+            self.height_distributions,
+            self.critical_volume_distributions,
+            self.critical_height_distributions,
+            self.state_thresholds,
+        )
+
+        self.assertDictEqual(expected_cell, cell)
+
+    def test_convert_to_cell_no_reference_with_region(self):
+        cell_id = 2
+        volumes = [15, 5]
+        height = 5
+        samples = pd.concat(
+            [
+                make_samples(cell_id, volume, height, f"REGION{index + 1}")
+                for index, volume in enumerate(volumes)
+            ]
+        )
+        reference = {}
+
+        expected_cell = {
+            "id": cell_id,
+            "parent": 0,
+            "pop": 1,
+            "age": 0,
+            "divisions": 0,
+            "state": "STATE2",
+            "phase": "STATE2_PHASE1",
+            "voxels": np.sum(volumes),
+            "criticals": [4, 10],
+            "regions": [
+                {"region": "REGION1", "voxels": volumes[0], "criticals": [6, 5]},
+                {"region": "REGION2", "voxels": volumes[1], "criticals": [4, 4]},
+            ],
+        }
+
+        cell = convert_to_cell(
+            cell_id,
+            samples,
+            reference,
+            self.volume_distributions,
+            self.height_distributions,
+            self.critical_volume_distributions,
+            self.critical_height_distributions,
+            self.state_thresholds,
+        )
+
+        self.assertDictEqual(expected_cell, cell)
+
+    def test_convert_to_cell_with_reference_with_region(self):
+        cell_id = 2
+        volumes = [15, 5]
+        height = 5
+        samples = pd.concat(
+            [
+                make_samples(cell_id, volume, height, f"REGION{index + 1}")
+                for index, volume in enumerate(volumes)
+            ]
+        )
+        reference = {
+            "volume": 10,
+            "height": 10,
+            "volume.REGION1": 10,
+            "height.REGION1": 4,
+            "volume.REGION2": 15,
+            "height.REGION2": 10,
+        }
+
+        expected_cell = {
+            "id": cell_id,
+            "parent": 0,
+            "pop": 1,
+            "age": 0,
+            "divisions": 0,
+            "state": "STATE2",
+            "phase": "STATE2_PHASE1",
+            "voxels": np.sum(volumes),
+            "criticals": [2, 20],
+            "regions": [
+                {"region": "REGION1", "voxels": volumes[0], "criticals": [2, 2]},
+                {"region": "REGION2", "voxels": volumes[1], "criticals": [5, 8]},
+            ],
+        }
+
+        cell = convert_to_cell(
+            cell_id,
+            samples,
+            reference,
+            self.volume_distributions,
+            self.height_distributions,
+            self.critical_volume_distributions,
+            self.critical_height_distributions,
+            self.state_thresholds,
+        )
+
+        self.assertDictEqual(expected_cell, cell)
+
+    def test_convert_to_cell_region_no_reference(self):
+        volume = 20
+        height = 5
+        region_samples = pd.DataFrame({"z": np.linspace(0, height, volume)})
+        reference = {}
+        # Sampled volume 20 is +1 sd -> critical 2 + 2 = 4; height 5 is the mean -> critical 10.
+        expected_cell_region = {
+            "region": DEFAULT_REGION_NAME,
+            "voxels": volume,
+            "criticals": [4, 10],
+        }
+
+        cell_region = convert_to_cell_region(
+            DEFAULT_REGION_NAME,
+            region_samples,
+            reference,
+            self.volume_distributions,
+            self.height_distributions,
+            self.critical_volume_distributions,
+            self.critical_height_distributions,
+        )
+
+        self.assertDictEqual(expected_cell_region, cell_region)
+
+    def test_convert_to_cell_region_with_reference(self):
+        volume = 20
+        height = 5
+        region_samples = pd.DataFrame({"z": np.linspace(0, height, volume)})
+        reference = {f"volume.{DEFAULT_REGION_NAME}": 10, f"height.{DEFAULT_REGION_NAME}": 10}
+        # Reference volume 10 is the mean -> critical 2; height 10 is +1 sd -> critical 20.
+        expected_cell_region = {
+            "region": DEFAULT_REGION_NAME,
+            "voxels": volume,
+            "criticals": [2, 20],
+        }
+
+        cell_region = convert_to_cell_region(
+            DEFAULT_REGION_NAME,
+            region_samples,
+            reference,
+            self.volume_distributions,
+            self.height_distributions,
+            self.critical_volume_distributions,
+            self.critical_height_distributions,
+        )
+
+        self.assertDictEqual(expected_cell_region, cell_region)
+
+    def test_get_cell_state(self):
+        critical_volume = 10
+        threshold_fractions = {
+            "STATE_A": 0.2,
+            "STATE_B": 1.5,
+            "STATE_C": 2,
+        }
+
+        threshold_a = threshold_fractions["STATE_A"] * critical_volume
+        threshold_b = threshold_fractions["STATE_B"] * critical_volume
+        threshold_c = threshold_fractions["STATE_C"] * critical_volume
+        # Lower bounds are inclusive: a volume equal to a threshold falls in the next state.
+        parameters = [
+            (threshold_a - EPSILON, "STATE_A"),  # below A threshold
+            (threshold_a, "STATE_B"),  # equal A threshold
+            ((threshold_a + threshold_b) / 2, "STATE_B"),  # between A and B thresholds
+            (threshold_b, "STATE_C"),  # equal B threshold
+            ((threshold_b + threshold_c) / 2, "STATE_C"),  # between B and C thresholds
+            (threshold_c, "STATE_C"),  # equal C threshold
+            (threshold_c + EPSILON, "STATE_C"),  # above C threshold
+        ]
+
+        for volume, expected_state in parameters:
+            with self.subTest(volume=volume, expected_state=expected_state):
+                state = get_cell_state(volume, critical_volume, threshold_fractions)
+                self.assertEqual(expected_state, state)
+
+    def test_convert_value_distribution(self):
+        source_distribution = (10, 6)
+        target_distribution = (2, 0.6)
+        # Expected values follow z-score rescaling: (value - 10) / 6 * 0.6 + 2.
+        parameters = [
+            (10, 2),  # means
+            (4, 1.4),  # one standard deviation below
+            (16, 2.6),  # one standard deviation above
+            (7, 1.7),  # half standard deviation below
+            (13, 2.3),  # half standard deviation above
+        ]
+        # Compare approximately because the rescaling is floating point arithmetic.
+        for source_value, expected_target_value in parameters:
+            with self.subTest(source_value=source_value):
+                target_value = convert_value_distribution(
+                    source_value, source_distribution, target_distribution
+                )
+                self.assertAlmostEqual(expected_target_value, target_value)
+
+    def test_filter_cell_reference_cell_exists(self):
+        cell_ids = [1, 2, 3]
+        feature_a = [10, 20, 30]
+        feature_b = ["a", "b", "c"]
+        cell_id = 2
+        index = cell_ids.index(cell_id)
+
+        reference = pd.DataFrame({"ID": cell_ids, "FEATURE_A": feature_a, "FEATURE_B": feature_b})
+
+        expected_cell_reference = {
+            "ID": cell_id,
+            "FEATURE_A": feature_a[index],
+            "FEATURE_B": feature_b[index],
+        }
+
+        cell_reference = filter_cell_reference(cell_id, reference)
+
+        self.assertDictEqual(expected_cell_reference, cell_reference)
+
+    def test_filter_cell_reference_cell_does_not_exist(self):
+        cell_ids = [1, 2, 3]
+        feature_a = [10, 20, 30]
+        feature_b = ["a", "b", "c"]
+        cell_id = 4
+
+        reference = pd.DataFrame({"ID": cell_ids, "FEATURE_A": feature_a, "FEATURE_B": feature_b})
+
+        cell_reference = filter_cell_reference(cell_id, reference)
+
+        self.assertDictEqual({}, cell_reference)
+
+
+if __name__ == "__main__":
+    unittest.main()