From edee570a0ac66b9e2f9ac5bbde64ab1b3b4bc54c Mon Sep 17 00:00:00 2001 From: Pieter Robberechts Date: Tue, 17 Dec 2024 21:01:59 +0100 Subject: [PATCH] refactor: remove numpy dependency from SkillCorner (#375) --- .../infra/serializers/tracking/skillcorner.py | 64 ++++++++++--------- 1 file changed, 33 insertions(+), 31 deletions(-) diff --git a/kloppy/infra/serializers/tracking/skillcorner.py b/kloppy/infra/serializers/tracking/skillcorner.py index f819a5af..32e2d670 100644 --- a/kloppy/infra/serializers/tracking/skillcorner.py +++ b/kloppy/infra/serializers/tracking/skillcorner.py @@ -1,13 +1,11 @@ import json import logging import warnings -from collections import Counter +from collections import Counter, defaultdict from datetime import datetime, timedelta, timezone from pathlib import Path from typing import IO, Dict, NamedTuple, Optional, Union -import numpy as np - from kloppy.domain import ( AttackingDirection, DatasetFlag, @@ -207,22 +205,21 @@ def _get_skillcorner_attacking_directions(cls, frames, periods): x-coords might not reflect the attacking direction. """ attacking_directions = {} - frame_period_ids = np.array([_frame.period.id for _frame in frames]) - frame_attacking_directions = np.array( - [ - attacking_direction_from_frame(frame) - if len(frame.players_data) > 0 - else AttackingDirection.NOT_SET - for frame in frames - ] - ) + # Group attacking directions by period ID + period_direction_map = defaultdict(list) + for frame in frames: + if len(frame.players_data) > 0: + direction = attacking_direction_from_frame(frame) + else: + direction = AttackingDirection.NOT_SET + period_direction_map[frame.period.id].append(direction) + + # Determine the most common attacking direction for each period for period_id in periods.keys(): - if period_id in frame_period_ids: - count = Counter( - frame_attacking_directions[frame_period_ids == period_id] - ) - attacking_directions[period_id] = count.most_common()[0][0] + if period_id in period_direction_map: + count = Counter(period_direction_map[period_id]) + attacking_directions[period_id] = count.most_common(1)[0][0] else: attacking_directions[period_id] = AttackingDirection.NOT_SET @@ -252,28 +249,33 @@ def __get_periods(cls, tracking): """gets the Periods contained in the tracking data""" periods = {} - _periods = np.array([f["period"] for f in tracking]) - unique_periods = set(_periods) - unique_periods = [ - period for period in unique_periods if period is not None - ] + # Extract unique periods while filtering out None values + unique_periods = { + frame["period"] + for frame in tracking + if frame["period"] is not None + } for period in unique_periods: + # Filter frames that belong to the current period and have valid "time" _frames = [ frame for frame in tracking if frame["period"] == period and frame["time"] is not None ] - periods[period] = Period( - id=period, - start_timestamp=timedelta( - seconds=_frames[0]["frame"] / frame_rate - ), - end_timestamp=timedelta( - seconds=_frames[-1]["frame"] / frame_rate - ), - ) + # Ensure _frames is not empty before accessing the first and last elements + if _frames: + periods[period] = Period( + id=period, + start_timestamp=timedelta( + seconds=_frames[0]["frame"] / frame_rate + ), + end_timestamp=timedelta( + seconds=_frames[-1]["frame"] / frame_rate + ), + ) + return periods @classmethod