From ec013c2183870d7e20e2adc58cacc99556bc3aa8 Mon Sep 17 00:00:00 2001 From: Pieter Robberechts Date: Tue, 17 Dec 2024 10:36:02 +0100 Subject: [PATCH] refactor: remove numpy dependency from SkillCorner --- .../infra/serializers/tracking/skillcorner.py | 76 ++++++++++--------- 1 file changed, 40 insertions(+), 36 deletions(-) diff --git a/kloppy/infra/serializers/tracking/skillcorner.py b/kloppy/infra/serializers/tracking/skillcorner.py index b5cc0306..65cfd4db 100644 --- a/kloppy/infra/serializers/tracking/skillcorner.py +++ b/kloppy/infra/serializers/tracking/skillcorner.py @@ -1,15 +1,14 @@ +import json import logging -from datetime import timedelta, timezone -from dateutil.parser import parse import warnings -from typing import NamedTuple, IO, Optional, Union, Dict -from collections import Counter -import numpy as np -import json +from collections import Counter, defaultdict +from datetime import timedelta, timezone from pathlib import Path +from typing import IO, Dict, NamedTuple, Optional, Union + +from dateutil.parser import parse from kloppy.domain import ( - attacking_direction_from_frame, AttackingDirection, DatasetFlag, Frame, @@ -18,6 +17,7 @@ Orientation, Period, Player, + PlayerData, Point, Point3D, PositionType, @@ -25,7 +25,7 @@ Score, Team, TrackingDataset, - PlayerData, + attacking_direction_from_frame, ) from kloppy.infra.serializers.tracking.deserializer import ( TrackingDataDeserializer, @@ -207,22 +207,21 @@ def _get_skillcorner_attacking_directions(cls, frames, periods): x-coords might not reflect the attacking direction. """ attacking_directions = {} - frame_period_ids = np.array([_frame.period.id for _frame in frames]) - frame_attacking_directions = np.array( - [ - attacking_direction_from_frame(frame) - if len(frame.players_data) > 0 - else AttackingDirection.NOT_SET - for frame in frames - ] - ) + # Group attacking directions by period ID + period_direction_map = defaultdict(list) + for frame in frames: + if len(frame.players_data) > 0: + direction = attacking_direction_from_frame(frame) + else: + direction = AttackingDirection.NOT_SET + period_direction_map[frame.period.id].append(direction) + + # Determine the most common attacking direction for each period for period_id in periods.keys(): - if period_id in frame_period_ids: - count = Counter( - frame_attacking_directions[frame_period_ids == period_id] - ) - attacking_directions[period_id] = count.most_common()[0][0] + if period_id in period_direction_map: + count = Counter(period_direction_map[period_id]) + attacking_directions[period_id] = count.most_common(1)[0][0] else: attacking_directions[period_id] = AttackingDirection.NOT_SET @@ -252,28 +251,33 @@ def __get_periods(cls, tracking): """gets the Periods contained in the tracking data""" periods = {} - _periods = np.array([f["period"] for f in tracking]) - unique_periods = set(_periods) - unique_periods = [ - period for period in unique_periods if period is not None - ] + # Extract unique periods while filtering out None values + unique_periods = { + frame["period"] + for frame in tracking + if frame["period"] is not None + } for period in unique_periods: + # Filter frames that belong to the current period and have valid "time" _frames = [ frame for frame in tracking if frame["period"] == period and frame["time"] is not None ] - periods[period] = Period( - id=period, - start_timestamp=timedelta( - seconds=_frames[0]["frame"] / frame_rate - ), - end_timestamp=timedelta( - seconds=_frames[-1]["frame"] / frame_rate - ), - ) + # Ensure _frames is not empty before accessing the first and last elements + if _frames: + periods[period] = Period( + id=period, + start_timestamp=timedelta( + seconds=_frames[0]["frame"] / frame_rate + ), + end_timestamp=timedelta( + seconds=_frames[-1]["frame"] / frame_rate + ), + ) + return periods @classmethod