Skip to content

Commit

Permalink
refactor: remove numpy dependency from SkillCorner
Browse files Browse the repository at this point in the history
  • Loading branch information
probberechts committed Dec 17, 2024
1 parent dff0204 commit ec013c2
Showing 1 changed file with 40 additions and 36 deletions.
76 changes: 40 additions & 36 deletions kloppy/infra/serializers/tracking/skillcorner.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import json
import logging
from datetime import timedelta, timezone
from dateutil.parser import parse
import warnings
from typing import NamedTuple, IO, Optional, Union, Dict
from collections import Counter
import numpy as np
import json
from collections import Counter, defaultdict
from datetime import timedelta, timezone
from pathlib import Path
from typing import IO, Dict, NamedTuple, Optional, Union

from dateutil.parser import parse

from kloppy.domain import (
attacking_direction_from_frame,
AttackingDirection,
DatasetFlag,
Frame,
Expand All @@ -18,14 +17,15 @@
Orientation,
Period,
Player,
PlayerData,
Point,
Point3D,
PositionType,
Provider,
Score,
Team,
TrackingDataset,
PlayerData,
attacking_direction_from_frame,
)
from kloppy.infra.serializers.tracking.deserializer import (
TrackingDataDeserializer,
Expand Down Expand Up @@ -207,22 +207,21 @@ def _get_skillcorner_attacking_directions(cls, frames, periods):
x-coords might not reflect the attacking direction.
"""
attacking_directions = {}
frame_period_ids = np.array([_frame.period.id for _frame in frames])
frame_attacking_directions = np.array(
[
attacking_direction_from_frame(frame)
if len(frame.players_data) > 0
else AttackingDirection.NOT_SET
for frame in frames
]
)

# Group attacking directions by period ID
period_direction_map = defaultdict(list)
for frame in frames:
if len(frame.players_data) > 0:
direction = attacking_direction_from_frame(frame)
else:
direction = AttackingDirection.NOT_SET
period_direction_map[frame.period.id].append(direction)

# Determine the most common attacking direction for each period
for period_id in periods.keys():
if period_id in frame_period_ids:
count = Counter(
frame_attacking_directions[frame_period_ids == period_id]
)
attacking_directions[period_id] = count.most_common()[0][0]
if period_id in period_direction_map:
count = Counter(period_direction_map[period_id])
attacking_directions[period_id] = count.most_common(1)[0][0]
else:
attacking_directions[period_id] = AttackingDirection.NOT_SET

Expand Down Expand Up @@ -252,28 +251,33 @@ def __get_periods(cls, tracking):
"""gets the Periods contained in the tracking data"""
periods = {}

_periods = np.array([f["period"] for f in tracking])
unique_periods = set(_periods)
unique_periods = [
period for period in unique_periods if period is not None
]
# Extract unique periods while filtering out None values
unique_periods = {
frame["period"]
for frame in tracking
if frame["period"] is not None
}

for period in unique_periods:
# Filter frames that belong to the current period and have valid "time"
_frames = [
frame
for frame in tracking
if frame["period"] == period and frame["time"] is not None
]

periods[period] = Period(
id=period,
start_timestamp=timedelta(
seconds=_frames[0]["frame"] / frame_rate
),
end_timestamp=timedelta(
seconds=_frames[-1]["frame"] / frame_rate
),
)
# Ensure _frames is not empty before accessing the first and last elements
if _frames:
periods[period] = Period(
id=period,
start_timestamp=timedelta(
seconds=_frames[0]["frame"] / frame_rate
),
end_timestamp=timedelta(
seconds=_frames[-1]["frame"] / frame_rate
),
)

return periods

@classmethod
Expand Down

0 comments on commit ec013c2

Please sign in to comment.