diff --git a/kloppy/domain/models/tracking.py b/kloppy/domain/models/tracking.py index 940692f5..ffae5f11 100644 --- a/kloppy/domain/models/tracking.py +++ b/kloppy/domain/models/tracking.py @@ -45,8 +45,8 @@ class Trajectory: """ trackable_object: Union[Player, str] - start_frame: int - end_frame: int + start_frame: "Frame" + end_frame: "Frame" detections: List[Detection] def __iter__(self): @@ -158,19 +158,19 @@ def trajectories(self, trackable_object: Union[Player, Literal["ball"]]): # a new trajectory current_trajectory = Trajectory( trackable_object=trackable_object, - start_frame=frame.frame_id, - end_frame=frame.frame_id, + start_frame=frame, + end_frame=frame, detections=[detection], ) elif ( frame.prev_record is not None and frame.prev_record.frame_id - == current_trajectory.end_frame + == current_trajectory.end_frame.frame_id and frame.prev_record.period.id == frame.period.id ): # and it was tracked in the previous frame --> extend the # current trajectory - current_trajectory.end_frame = frame.frame_id + current_trajectory.end_frame = frame current_trajectory.detections.append(detection) else: # but a frame is missing or a new period started --> finish @@ -178,8 +178,8 @@ def trajectories(self, trackable_object: Union[Player, Literal["ball"]]): trajectories.append(current_trajectory) current_trajectory = Trajectory( trackable_object=trackable_object, - start_frame=frame.frame_id, - end_frame=frame.frame_id, + start_frame=frame, + end_frame=frame, detections=[detection], ) else: diff --git a/kloppy/tests/test_tracking.py b/kloppy/tests/test_tracking.py index e7a8710f..bfd801ba 100644 --- a/kloppy/tests/test_tracking.py +++ b/kloppy/tests/test_tracking.py @@ -184,8 +184,8 @@ def test_trajectories(self): ball_trajectories = dataset.trajectories("ball") assert len(ball_trajectories) == 1 - assert ball_trajectories[0].start_frame == 0 - assert ball_trajectories[0].end_frame == 124 + assert ball_trajectories[0].start_frame.frame_id == 0 + assert ball_trajectories[0].end_frame.frame_id == 124 assert len(ball_trajectories[0].detections) == 125 assert ball_trajectories[0].detections[0] == Detection( coordinates=Point3D(x=0, y=0, z=0) @@ -193,8 +193,8 @@ def test_trajectories(self): player_trajectories = dataset.trajectories("home_1") assert len(player_trajectories) == 2 - assert player_trajectories[0].start_frame == 0 - assert player_trajectories[0].end_frame == 9 + assert player_trajectories[0].start_frame.frame_id == 0 + assert player_trajectories[0].end_frame.frame_id == 9 assert len(player_trajectories[0].detections) == 10 assert player_trajectories[0].detections[0] == Detection( coordinates=Point(x=0, y=0)