Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SkillCorner] Remove numpy dependency #375

Merged
merged 2 commits into from
Dec 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 33 additions & 31 deletions kloppy/infra/serializers/tracking/skillcorner.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import json
import logging
import warnings
from collections import Counter
from collections import Counter, defaultdict
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import IO, Dict, NamedTuple, Optional, Union

import numpy as np

from kloppy.domain import (
AttackingDirection,
DatasetFlag,
Expand Down Expand Up @@ -207,22 +205,21 @@ def _get_skillcorner_attacking_directions(cls, frames, periods):
x-coords might not reflect the attacking direction.
"""
attacking_directions = {}
frame_period_ids = np.array([_frame.period.id for _frame in frames])
frame_attacking_directions = np.array(
[
attacking_direction_from_frame(frame)
if len(frame.players_data) > 0
else AttackingDirection.NOT_SET
for frame in frames
]
)

# Group attacking directions by period ID
period_direction_map = defaultdict(list)
for frame in frames:
if len(frame.players_data) > 0:
direction = attacking_direction_from_frame(frame)
else:
direction = AttackingDirection.NOT_SET
period_direction_map[frame.period.id].append(direction)

# Determine the most common attacking direction for each period
for period_id in periods.keys():
if period_id in frame_period_ids:
count = Counter(
frame_attacking_directions[frame_period_ids == period_id]
)
attacking_directions[period_id] = count.most_common()[0][0]
if period_id in period_direction_map:
count = Counter(period_direction_map[period_id])
attacking_directions[period_id] = count.most_common(1)[0][0]
else:
attacking_directions[period_id] = AttackingDirection.NOT_SET

Expand Down Expand Up @@ -252,28 +249,33 @@ def __get_periods(cls, tracking):
"""gets the Periods contained in the tracking data"""
periods = {}

_periods = np.array([f["period"] for f in tracking])
unique_periods = set(_periods)
unique_periods = [
period for period in unique_periods if period is not None
]
# Extract unique periods while filtering out None values
unique_periods = {
frame["period"]
for frame in tracking
if frame["period"] is not None
}

for period in unique_periods:
# Filter frames that belong to the current period and have valid "time"
_frames = [
frame
for frame in tracking
if frame["period"] == period and frame["time"] is not None
]

periods[period] = Period(
id=period,
start_timestamp=timedelta(
seconds=_frames[0]["frame"] / frame_rate
),
end_timestamp=timedelta(
seconds=_frames[-1]["frame"] / frame_rate
),
)
# Ensure _frames is not empty before accessing the first and last elements
if _frames:
periods[period] = Period(
id=period,
start_timestamp=timedelta(
seconds=_frames[0]["frame"] / frame_rate
),
end_timestamp=timedelta(
seconds=_frames[-1]["frame"] / frame_rate
),
)

return periods

@classmethod
Expand Down
Loading