Skip to content

Commit

Permalink
Start adding unit tests and documentation for convert to simularium task
Browse files Browse the repository at this point in the history
  • Loading branch information
jessicasyu committed Sep 9, 2024
1 parent 510d162 commit 5f94bdb
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 27 deletions.
114 changes: 87 additions & 27 deletions src/arcade_collection/convert/convert_to_simularium.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ def convert_to_simularium(
dt: float,
colors: dict[str, str],
url: Optional[str] = None,
jitter: float = 1.0,
) -> str:
meta_data = get_meta_data(series_key, simulation_type, length, width, height, ds, dz)
agent_data = get_agent_data(data)
agent_data.display_data = get_display_data(data, colors, url)
agent_data.display_data = get_display_data(data, colors, url, jitter)

for index, (frame, group) in enumerate(data.groupby("frame")):
n_agents = len(group)
Expand Down Expand Up @@ -114,43 +115,102 @@ def get_agent_data(data: pd.DataFrame) -> AgentData:


def get_display_data(
data: pd.DataFrame, colors: dict[str, str], url: Optional[str] = None
data: pd.DataFrame,
colors: dict[str, str],
url: str = "",
jitter: float = 1.0,
) -> DisplayData:
"""
Create map of DisplayData objects.
Method uses the "name" and "display_type" columns in data to generate the
DisplayData objects.
The "name" column should be a string in one of the following forms:
- ``(index)#(color_key)``
- ``(group)#(color_key)#(index)``
- ``(group)#(color_key)#(index)#(frame)``
where ``(index)`` becomes DisplayData object name and ``(color_key)`` is
passed to the color mapping to select the DisplayData color (optional color
jitter may be applied).
The "display_type" column should be a valid ``DISPLAY_TYPE``. For the
``DISPLAY_TYPE.OBJ`` type, a URL prefix must be used and names should be in
the form ``(group)#(color_key)#(index)#(frame)``, which is used to generate
the full URL formatted as: ``(url)/(frame)_(group)_(index).MESH.obj``. Note
that ``(frame)`` is zero-padded to six digits and ``(index)`` is zero-padded
to three digits.
Parameters
----------
data
Simulation trajectory data.
colors
Color mapping.
url
Url prefix for meshes.
jitter
Jitter applied to colors.
Returns
-------
:
Map of DisplayData objects.
"""

display_data = {}
display_types = sorted(set(zip(data["name"], data["display_type"])))

for name in data["name"].unique():
if name.count("#") == 3:
group, color_key, index, frame = name.split("#")
for name, display_type in display_types:
if name.count("#") == 1:
index, color_key = name.split("#")
elif name.count("#") == 2:
group, index, color_key = name.split("#")
_, color_key, index = name.split("#")
elif name.count("#") == 3:
group, color_key, index, frame = name.split("#")

random.seed(index)
jitter = (random.random() - 0.5) / 2

if url is not None:
display_data[name] = DisplayData(
name=index,
display_type=DISPLAY_TYPE.OBJ,
url=f"{url}/{int(frame):06d}_{group}_{int(index):03d}.MESH.obj",
color=shade_color(colors[color_key], jitter),
)
elif index is None:
display_data[name] = DisplayData(
name=group,
display_type=DISPLAY_TYPE.FIBER,
color=colors[color_key],
)
if url != "":
full_url = f"{url}/{int(frame):06d}_{group}_{int(index):03d}.MESH.obj"
else:
display_data[name] = DisplayData(
name=index,
display_type=DISPLAY_TYPE.SPHERE,
color=shade_color(colors[color_key], jitter),
)
full_url = ""

random.seed(index)
alpha = jitter * (random.random() - 0.5) / 2

display_data[name] = DisplayData(
name=index,
display_type=DISPLAY_TYPE[display_type],
color=shade_color(colors[color_key], alpha),
url=full_url,
)

return display_data


def shade_color(color: str, alpha: float) -> str:
"""
Shade color by specified alpha.
Positive values of alpha will blend the given color with white (alpha = 1.0
returns pure white), while negative values of alpha will blend the given
color with black (alpha = -1.0 returns pure black). An alpha = 0.0 will
leave the color unchanged.
Parameters
----------
color
Original color as hex string.
alpha
Shading value between -1 and +1.
Returns
-------
:
Shaded color as hex string.
"""

old_color = color.replace("#", "")
old_red, old_green, old_blue = [int(old_color[i : i + 2], 16) for i in (0, 2, 4)]
layer_color = 0 if alpha < 0 else 255
Expand Down
108 changes: 108 additions & 0 deletions tests/arcade_collection/convert/test_convert_to_simularium.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import sys
import unittest
from unittest import mock

import pandas as pd
from simulariumio import DISPLAY_TYPE

from arcade_collection.convert.convert_to_simularium import get_display_data, shade_color


class TestConvertToSimularium(unittest.TestCase):
def test_get_display_data_no_url(self) -> None:
data = pd.DataFrame(
{
"name": ["GROUP#A", "GROUP#A#4", "GROUP#B#3", "GROUP#A#2", "GROUP#B#1"],
"display_type": ["SPHERE", "SPHERE", "FIBER", "SPHERE", "FIBER"],
}
)
colors = {"A": "#ff0000", "B": "#0000ff"}

expected_data = [
("GROUP#A", "GROUP", colors["A"], DISPLAY_TYPE.SPHERE, ""),
("GROUP#A#2", "2", colors["A"], DISPLAY_TYPE.SPHERE, ""),
("GROUP#A#4", "4", colors["A"], DISPLAY_TYPE.SPHERE, ""),
("GROUP#B#1", "1", colors["B"], DISPLAY_TYPE.FIBER, ""),
("GROUP#B#3", "3", colors["B"], DISPLAY_TYPE.FIBER, ""),
]

display_data = get_display_data(data, colors, url="", jitter=0.0)

for expected, (key, display) in zip(expected_data, display_data.items()):
self.assertTupleEqual(
expected, (key, display.name, display.color, display.display_type, display.url)
)

def test_get_display_data_with_url(self) -> None:
url = "https://url/"
data = pd.DataFrame(
{
"name": ["GROUP#A#3#1", "GROUP#A#2#1", "GROUP#B#1#1", "GROUP#A#2#0", "GROUP#B#1#0"],
"display_type": ["OBJ", "OBJ", "OBJ", "OBJ", "OBJ"],
}
)
colors = {"A": "#ff0000", "B": "#0000ff"}

expected_data = [
("GROUP#A#2#0", "2", colors["A"], DISPLAY_TYPE.OBJ, f"{url}/000000_GROUP_002.MESH.obj"),
("GROUP#A#2#1", "2", colors["A"], DISPLAY_TYPE.OBJ, f"{url}/000001_GROUP_002.MESH.obj"),
("GROUP#A#3#1", "3", colors["A"], DISPLAY_TYPE.OBJ, f"{url}/000001_GROUP_003.MESH.obj"),
("GROUP#B#1#0", "1", colors["B"], DISPLAY_TYPE.OBJ, f"{url}/000000_GROUP_001.MESH.obj"),
("GROUP#B#1#1", "1", colors["B"], DISPLAY_TYPE.OBJ, f"{url}/000001_GROUP_001.MESH.obj"),
]

display_data = get_display_data(data, colors, url=url, jitter=0.0)

for expected, (key, display) in zip(expected_data, display_data.items()):
self.assertTupleEqual(
expected, (key, display.name, display.color, display.display_type, display.url)
)

@mock.patch.object(
sys.modules["arcade_collection.convert.convert_to_simularium"],
"random",
return_value=mock.Mock(),
)
def test_get_display_data_with_jitter(self, random_mock) -> None:
random_mock.random.side_effect = [0.1, 0.3, 0.7, 0.9]

data = pd.DataFrame(
{
"name": ["GROUP#A#1", "GROUP#A#2", "GROUP#A#3", "GROUP#A#4"],
"display_type": ["SPHERE", "SPHERE", "SPHERE", "SPHERE"],
}
)
jitter = 0.5
color = "#ff55ee"
colors = {"A": color}

expected_colors = [
shade_color(color, -0.2 * jitter),
shade_color(color, -0.1 * jitter),
shade_color(color, 0.1 * jitter),
shade_color(color, 0.2 * jitter),
]

display_data = get_display_data(data, colors, url="", jitter=jitter)

for expected_color, display in zip(expected_colors, display_data.values()):
self.assertEqual(expected_color, display.color)

def test_shade_color(self) -> None:
original_color = "#F0F00F"
parameters = [
(0.0, "#F0F00F"), # unchanged
(-1.0, "#000000"), # full shade to black
(1.0, "#FFFFFF"), # full shade to white
(-0.5, "#787808"), # half shade to black
(0.5, "#F8F887"), # half shade to white
]

for alpha, expected_color in parameters:
with self.subTest(alpha=alpha):
color = shade_color(original_color, alpha)
self.assertEqual(expected_color.lower(), color.lower())


if __name__ == "__main__":
unittest.main()

0 comments on commit 5f94bdb

Please sign in to comment.