diff --git a/src/arcade_collection/convert/convert_to_simularium.py b/src/arcade_collection/convert/convert_to_simularium.py index 30ee9f0..a713583 100644 --- a/src/arcade_collection/convert/convert_to_simularium.py +++ b/src/arcade_collection/convert/convert_to_simularium.py @@ -40,10 +40,11 @@ def convert_to_simularium( dt: float, colors: dict[str, str], url: Optional[str] = None, + jitter: float = 1.0, ) -> str: meta_data = get_meta_data(series_key, simulation_type, length, width, height, ds, dz) agent_data = get_agent_data(data) - agent_data.display_data = get_display_data(data, colors, url) + agent_data.display_data = get_display_data(data, colors, url, jitter) for index, (frame, group) in enumerate(data.groupby("frame")): n_agents = len(group) @@ -114,43 +115,102 @@ def get_agent_data(data: pd.DataFrame) -> AgentData: def get_display_data( - data: pd.DataFrame, colors: dict[str, str], url: Optional[str] = None + data: pd.DataFrame, + colors: dict[str, str], + url: str = "", + jitter: float = 1.0, ) -> DisplayData: + """ + Create map of DisplayData objects. + + Method uses the "name" and "display_type" columns in data to generate the + DisplayData objects. + + The "name" column should be a string in one of the following forms: + + - ``(index)#(color_key)`` + - ``(group)#(color_key)#(index)`` + - ``(group)#(color_key)#(index)#(frame)`` + + where ``(index)`` becomes DisplayData object name and ``(color_key)`` is + passed to the color mapping to select the DisplayData color (optional color + jitter may be applied). + + The "display_type" column should be a valid ``DISPLAY_TYPE``. For the + ``DISPLAY_TYPE.OBJ`` type, a URL prefix must be used and names should be in + the form ``(group)#(color_key)#(index)#(frame)``, which is used to generate + the full URL formatted as: ``(url)/(frame)_(group)_(index).MESH.obj``. Note + that ``(frame)`` is zero-padded to six digits and ``(index)`` is zero-padded + to three digits. + + Parameters + ---------- + data + Simulation trajectory data. + colors + Color mapping. + url + Url prefix for meshes. + jitter + Jitter applied to colors. + + Returns + ------- + : + Map of DisplayData objects. + """ + display_data = {} + display_types = sorted(set(zip(data["name"], data["display_type"]))) - for name in data["name"].unique(): - if name.count("#") == 3: - group, color_key, index, frame = name.split("#") + for name, display_type in display_types: + if name.count("#") == 1: + index, color_key = name.split("#") elif name.count("#") == 2: - group, index, color_key = name.split("#") + _, color_key, index = name.split("#") + elif name.count("#") == 3: + group, color_key, index, frame = name.split("#") - random.seed(index) - jitter = (random.random() - 0.5) / 2 - - if url is not None: - display_data[name] = DisplayData( - name=index, - display_type=DISPLAY_TYPE.OBJ, - url=f"{url}/{int(frame):06d}_{group}_{int(index):03d}.MESH.obj", - color=shade_color(colors[color_key], jitter), - ) - elif index is None: - display_data[name] = DisplayData( - name=group, - display_type=DISPLAY_TYPE.FIBER, - color=colors[color_key], - ) + if url != "": + full_url = f"{url}/{int(frame):06d}_{group}_{int(index):03d}.MESH.obj" else: - display_data[name] = DisplayData( - name=index, - display_type=DISPLAY_TYPE.SPHERE, - color=shade_color(colors[color_key], jitter), - ) + full_url = "" + + random.seed(index) + alpha = jitter * (random.random() - 0.5) / 2 + + display_data[name] = DisplayData( + name=index, + display_type=DISPLAY_TYPE[display_type], + color=shade_color(colors[color_key], alpha), + url=full_url, + ) return display_data def shade_color(color: str, alpha: float) -> str: + """ + Shade color by specified alpha. + + Positive values of alpha will blend the given color with white (alpha = 1.0 + returns pure white), while negative values of alpha will blend the given + color with black (alpha = -1.0 returns pure black). An alpha = 0.0 will + leave the color unchanged. + + Parameters + ---------- + color + Original color as hex string. + alpha + Shading value between -1 and +1. + + Returns + ------- + : + Shaded color as hex string. + """ + old_color = color.replace("#", "") old_red, old_green, old_blue = [int(old_color[i : i + 2], 16) for i in (0, 2, 4)] layer_color = 0 if alpha < 0 else 255 diff --git a/tests/arcade_collection/convert/test_convert_to_simularium.py b/tests/arcade_collection/convert/test_convert_to_simularium.py new file mode 100644 index 0000000..171cf56 --- /dev/null +++ b/tests/arcade_collection/convert/test_convert_to_simularium.py @@ -0,0 +1,108 @@ +import sys +import unittest +from unittest import mock + +import pandas as pd +from simulariumio import DISPLAY_TYPE + +from arcade_collection.convert.convert_to_simularium import get_display_data, shade_color + + +class TestConvertToSimularium(unittest.TestCase): + def test_get_display_data_no_url(self) -> None: + data = pd.DataFrame( + { + "name": ["GROUP#A", "GROUP#A#4", "GROUP#B#3", "GROUP#A#2", "GROUP#B#1"], + "display_type": ["SPHERE", "SPHERE", "FIBER", "SPHERE", "FIBER"], + } + ) + colors = {"A": "#ff0000", "B": "#0000ff"} + + expected_data = [ + ("GROUP#A", "GROUP", colors["A"], DISPLAY_TYPE.SPHERE, ""), + ("GROUP#A#2", "2", colors["A"], DISPLAY_TYPE.SPHERE, ""), + ("GROUP#A#4", "4", colors["A"], DISPLAY_TYPE.SPHERE, ""), + ("GROUP#B#1", "1", colors["B"], DISPLAY_TYPE.FIBER, ""), + ("GROUP#B#3", "3", colors["B"], DISPLAY_TYPE.FIBER, ""), + ] + + display_data = get_display_data(data, colors, url="", jitter=0.0) + + for expected, (key, display) in zip(expected_data, display_data.items()): + self.assertTupleEqual( + expected, (key, display.name, display.color, display.display_type, display.url) + ) + + def test_get_display_data_with_url(self) -> None: + url = "https://url/" + data = pd.DataFrame( + { + "name": ["GROUP#A#3#1", "GROUP#A#2#1", "GROUP#B#1#1", "GROUP#A#2#0", "GROUP#B#1#0"], + "display_type": ["OBJ", "OBJ", "OBJ", "OBJ", "OBJ"], + } + ) + colors = {"A": "#ff0000", "B": "#0000ff"} + + expected_data = [ + ("GROUP#A#2#0", "2", colors["A"], DISPLAY_TYPE.OBJ, f"{url}/000000_GROUP_002.MESH.obj"), + ("GROUP#A#2#1", "2", colors["A"], DISPLAY_TYPE.OBJ, f"{url}/000001_GROUP_002.MESH.obj"), + ("GROUP#A#3#1", "3", colors["A"], DISPLAY_TYPE.OBJ, f"{url}/000001_GROUP_003.MESH.obj"), + ("GROUP#B#1#0", "1", colors["B"], DISPLAY_TYPE.OBJ, f"{url}/000000_GROUP_001.MESH.obj"), + ("GROUP#B#1#1", "1", colors["B"], DISPLAY_TYPE.OBJ, f"{url}/000001_GROUP_001.MESH.obj"), + ] + + display_data = get_display_data(data, colors, url=url, jitter=0.0) + + for expected, (key, display) in zip(expected_data, display_data.items()): + self.assertTupleEqual( + expected, (key, display.name, display.color, display.display_type, display.url) + ) + + @mock.patch.object( + sys.modules["arcade_collection.convert.convert_to_simularium"], + "random", + return_value=mock.Mock(), + ) + def test_get_display_data_with_jitter(self, random_mock) -> None: + random_mock.random.side_effect = [0.1, 0.3, 0.7, 0.9] + + data = pd.DataFrame( + { + "name": ["GROUP#A#1", "GROUP#A#2", "GROUP#A#3", "GROUP#A#4"], + "display_type": ["SPHERE", "SPHERE", "SPHERE", "SPHERE"], + } + ) + jitter = 0.5 + color = "#ff55ee" + colors = {"A": color} + + expected_colors = [ + shade_color(color, -0.2 * jitter), + shade_color(color, -0.1 * jitter), + shade_color(color, 0.1 * jitter), + shade_color(color, 0.2 * jitter), + ] + + display_data = get_display_data(data, colors, url="", jitter=jitter) + + for expected_color, display in zip(expected_colors, display_data.values()): + self.assertEqual(expected_color, display.color) + + def test_shade_color(self) -> None: + original_color = "#F0F00F" + parameters = [ + (0.0, "#F0F00F"), # unchanged + (-1.0, "#000000"), # full shade to black + (1.0, "#FFFFFF"), # full shade to white + (-0.5, "#787808"), # half shade to black + (0.5, "#F8F887"), # half shade to white + ] + + for alpha, expected_color in parameters: + with self.subTest(alpha=alpha): + color = shade_color(original_color, alpha) + self.assertEqual(expected_color.lower(), color.lower()) + + +if __name__ == "__main__": + unittest.main()