Skip to content

Commit

Permalink
refactor: remove numpy dependency from SkillCorner (#375)
Browse files Browse the repository at this point in the history
  • Loading branch information
probberechts authored Dec 17, 2024
1 parent b0f56e1 commit edee570
Showing 1 changed file with 33 additions and 31 deletions.
64 changes: 33 additions & 31 deletions kloppy/infra/serializers/tracking/skillcorner.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import json
import logging
import warnings
from collections import Counter
from collections import Counter, defaultdict
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import IO, Dict, NamedTuple, Optional, Union

import numpy as np

from kloppy.domain import (
AttackingDirection,
DatasetFlag,
Expand Down Expand Up @@ -207,22 +205,21 @@ def _get_skillcorner_attacking_directions(cls, frames, periods):
x-coords might not reflect the attacking direction.
"""
attacking_directions = {}
frame_period_ids = np.array([_frame.period.id for _frame in frames])
frame_attacking_directions = np.array(
[
attacking_direction_from_frame(frame)
if len(frame.players_data) > 0
else AttackingDirection.NOT_SET
for frame in frames
]
)

# Group attacking directions by period ID
period_direction_map = defaultdict(list)
for frame in frames:
if len(frame.players_data) > 0:
direction = attacking_direction_from_frame(frame)
else:
direction = AttackingDirection.NOT_SET
period_direction_map[frame.period.id].append(direction)

# Determine the most common attacking direction for each period
for period_id in periods.keys():
if period_id in frame_period_ids:
count = Counter(
frame_attacking_directions[frame_period_ids == period_id]
)
attacking_directions[period_id] = count.most_common()[0][0]
if period_id in period_direction_map:
count = Counter(period_direction_map[period_id])
attacking_directions[period_id] = count.most_common(1)[0][0]
else:
attacking_directions[period_id] = AttackingDirection.NOT_SET

Expand Down Expand Up @@ -252,28 +249,33 @@ def __get_periods(cls, tracking):
"""gets the Periods contained in the tracking data"""
periods = {}

_periods = np.array([f["period"] for f in tracking])
unique_periods = set(_periods)
unique_periods = [
period for period in unique_periods if period is not None
]
# Extract unique periods while filtering out None values
unique_periods = {
frame["period"]
for frame in tracking
if frame["period"] is not None
}

for period in unique_periods:
# Filter frames that belong to the current period and have valid "time"
_frames = [
frame
for frame in tracking
if frame["period"] == period and frame["time"] is not None
]

periods[period] = Period(
id=period,
start_timestamp=timedelta(
seconds=_frames[0]["frame"] / frame_rate
),
end_timestamp=timedelta(
seconds=_frames[-1]["frame"] / frame_rate
),
)
# Ensure _frames is not empty before accessing the first and last elements
if _frames:
periods[period] = Period(
id=period,
start_timestamp=timedelta(
seconds=_frames[0]["frame"] / frame_rate
),
end_timestamp=timedelta(
seconds=_frames[-1]["frame"] / frame_rate
),
)

return periods

@classmethod
Expand Down

0 comments on commit edee570

Please sign in to comment.