Skip to content

Commit

Permalink
Add mypy type checking
Browse files Browse the repository at this point in the history
  • Loading branch information
cmalinmayor committed Feb 11, 2025
1 parent 1f67406 commit f419d66
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 70 deletions.
5 changes: 5 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,8 @@ extra:
version:
provider: mike
default: latest

plugins:
- search
- mkdocs-autoapi
- mkdocstrings
89 changes: 87 additions & 2 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ dependencies =[
]
[project.optional-dependencies]
testing =["pytest", "pytest-cov"]
docs = ["mkdocs-material", "mike"]
dev = ["ruff", "pre-commit"]
docs = ["mkdocs-material", "mike", "mkdocstrings[python]", "mkdocs-autoapi"]
dev = ["ruff", "pre-commit", "mypy"]
all = ["funtracks[testing,docs,dev]"]

[project.urls]
Expand Down Expand Up @@ -102,3 +102,7 @@ testing = { features = ["testing"], solve-group = "default" }

[tool.pixi.feature.testing.tasks]
test = "pytest --cov=funtracks tests/"


[tool.pixi.feature.docs.tasks]
localdocs = "mike build"
2 changes: 1 addition & 1 deletion src/funtracks/data_model/action_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class ActionHistory:
the undone actions)
"""

def __init__(self):
def __init__(self) -> None:
self.undo_stack: list[TracksAction] = [] # list of actions that can be undone
self.redo_stack: list[TracksAction] = [] # list of actions that can be redone

Expand Down
33 changes: 21 additions & 12 deletions src/funtracks/data_model/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@

from __future__ import annotations

from typing import TYPE_CHECKING

from .graph_attributes import NodeAttr
from .solution_tracks import SolutionTracks
from .tracks import Attrs, Edge, Node, SegMask, Tracks

if TYPE_CHECKING:
from collections.abc import Iterable


class TracksAction:
def __init__(self, tracks: Tracks):
Expand Down Expand Up @@ -81,9 +86,9 @@ class AddNodes(TracksAction):
def __init__(
self,
tracks: Tracks,
nodes: list[Node],
nodes: Iterable[Node],
attributes: Attrs,
pixels: list[SegMask] | None = None,
pixels: Iterable[SegMask] | None = None,
):
"""Create an action to add new nodes, with optional segmentation
Expand All @@ -100,9 +105,9 @@ def __init__(
self.times = attributes.get(NodeAttr.TIME.value, None)
if NodeAttr.TIME.value in attributes:
del user_attrs[NodeAttr.TIME.value]
self.positions = attributes.get(tracks.pos_attr, None)
if tracks.pos_attr in attributes:
del user_attrs[tracks.pos_attr]
self.positions = attributes.get(NodeAttr.POS.value, None)
if NodeAttr.POS.value in attributes:
del user_attrs[NodeAttr.POS.value]
self.pixels = pixels
self.attributes = user_attrs
self._apply()
Expand All @@ -127,7 +132,10 @@ class DeleteNodes(TracksAction):
"""

def __init__(
self, tracks: Tracks, nodes: list[Node], pixels: list[SegMask] | None = None
self,
tracks: Tracks,
nodes: Iterable[Node],
pixels: Iterable[SegMask] | None = None,
):
super().__init__(tracks)
self.nodes = nodes
Expand Down Expand Up @@ -170,8 +178,8 @@ class UpdateNodeSegs(TracksAction):
def __init__(
self,
tracks: Tracks,
nodes: list[Node],
pixels: list[SegMask],
nodes: Iterable[Node],
pixels: Iterable[SegMask],
added: bool = True,
):
"""
Expand Down Expand Up @@ -211,13 +219,13 @@ class UpdateNodeAttrs(TracksAction):
def __init__(
self,
tracks: Tracks,
nodes: list[Node],
nodes: Iterable[Node],
attrs: Attrs,
):
"""
Args:
tracks (Tracks): The tracks to update the node attributes for
nodes (list[Node]): The nodes to update the attributes for
nodes (Iterable[Node]): The nodes to update the attributes for
attrs (Attrs): A mapping from attribute name to list of new attribute values
for the given nodes.
Expand Down Expand Up @@ -257,7 +265,7 @@ def _apply(self):
class AddEdges(TracksAction):
"""Action for adding new edges"""

def __init__(self, tracks: Tracks, edges: list[Edge]):
def __init__(self, tracks: Tracks, edges: Iterable[Edge]):
super().__init__(tracks)
self.edges = edges
self._apply()
Expand All @@ -277,7 +285,7 @@ def _apply(self):
class DeleteEdges(TracksAction):
"""Action for deleting edges"""

def __init__(self, tracks: Tracks, edges: list[Edge]):
def __init__(self, tracks: Tracks, edges: Iterable[Edge]):
super().__init__(tracks)
self.edges = edges
self._apply()
Expand All @@ -303,6 +311,7 @@ def __init__(self, tracks: SolutionTracks, start_node: Node, track_id: int):
track_id (int): The new track id to assign.
"""
super().__init__(tracks)
self.tracks: SolutionTracks = tracks
self.start_node = start_node
self.old_track_id = self.tracks.get_track_id(start_node)
self.new_track_id = track_id
Expand Down
14 changes: 7 additions & 7 deletions src/funtracks/data_model/graph_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ class NodeAttr(Enum):
implementations of commonly used ones, listed here.
"""

POS: str = "pos"
TIME: str = "time"
SEG_ID: str = "seg_id"
SEG_HYPO: str = "seg_hypo"
AREA: str = "area"
TRACK_ID: str = "track_id"
POS = "pos"
TIME = "time"
SEG_ID = "seg_id"
SEG_HYPO = "seg_hypo"
AREA = "area"
TRACK_ID = "track_id"


class EdgeAttr(Enum):
Expand All @@ -21,7 +21,7 @@ class EdgeAttr(Enum):
implementations of commonly used ones, listed here.
"""

IOU: str = "iou"
IOU = "iou"


class NodeType(Enum):
Expand Down
4 changes: 4 additions & 0 deletions src/funtracks/data_model/solution_tracks.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ def add_nodes(
positions: np.ndarray | None = None,
attrs: Attrs | None = None,
):
if attrs is None:
raise ValueError(
f"Node attributes cannot be None, must contain {NodeAttr.TRACK_ID.value}"
)
# overriding add_nodes to add new nodes to the track_id_to_node mapping
super().add_nodes(nodes, times, positions, attrs)
for node, track_id in zip(nodes, attrs[NodeAttr.TRACK_ID.value], strict=True):
Expand Down
Loading

0 comments on commit f419d66

Please sign in to comment.