Skip to content

Commit

Permalink
Link to frame in trajectories
Browse files Browse the repository at this point in the history
  • Loading branch information
probberechts committed Dec 18, 2024
1 parent 79bb6d9 commit 41871ba
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
16 changes: 8 additions & 8 deletions kloppy/domain/models/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ class Trajectory:
"""

trackable_object: Union[Player, str]
start_frame: int
end_frame: int
start_frame: "Frame"
end_frame: "Frame"
detections: List[Detection]

def __iter__(self):
Expand Down Expand Up @@ -158,28 +158,28 @@ def trajectories(self, trackable_object: Union[Player, Literal["ball"]]):
# a new trajectory
current_trajectory = Trajectory(
trackable_object=trackable_object,
start_frame=frame.frame_id,
end_frame=frame.frame_id,
start_frame=frame,
end_frame=frame,
detections=[detection],
)
elif (
frame.prev_record is not None
and frame.prev_record.frame_id
== current_trajectory.end_frame
== current_trajectory.end_frame.frame_id
and frame.prev_record.period.id == frame.period.id
):
# and it was tracked in the previous frame --> extend the
# current trajectory
current_trajectory.end_frame = frame.frame_id
current_trajectory.end_frame = frame
current_trajectory.detections.append(detection)
else:
# but a frame is missing or a new period started --> finish
# the current trajectory and start a new one
trajectories.append(current_trajectory)
current_trajectory = Trajectory(
trackable_object=trackable_object,
start_frame=frame.frame_id,
end_frame=frame.frame_id,
start_frame=frame,
end_frame=frame,
detections=[detection],
)
else:
Expand Down
8 changes: 4 additions & 4 deletions kloppy/tests/test_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,17 +184,17 @@ def test_trajectories(self):

ball_trajectories = dataset.trajectories("ball")
assert len(ball_trajectories) == 1
assert ball_trajectories[0].start_frame == 0
assert ball_trajectories[0].end_frame == 124
assert ball_trajectories[0].start_frame.frame_id == 0
assert ball_trajectories[0].end_frame.frame_id == 124
assert len(ball_trajectories[0].detections) == 125
assert ball_trajectories[0].detections[0] == Detection(
coordinates=Point3D(x=0, y=0, z=0)
)

player_trajectories = dataset.trajectories("home_1")
assert len(player_trajectories) == 2
assert player_trajectories[0].start_frame == 0
assert player_trajectories[0].end_frame == 9
assert player_trajectories[0].start_frame.frame_id == 0
assert player_trajectories[0].end_frame.frame_id == 9
assert len(player_trajectories[0].detections) == 10
assert player_trajectories[0].detections[0] == Detection(
coordinates=Point(x=0, y=0)
Expand Down

0 comments on commit 41871ba

Please sign in to comment.