Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
UnravelSports [JB] committed Jul 22, 2024
1 parent 64462e6 commit 0a4a59c
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 18 deletions.
18 changes: 10 additions & 8 deletions tests/test_kloppy.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,19 @@ def test_conversion(self, gnnc: GraphConverter):
assert data.orientation == Orientation.STATIC_HOME_AWAY
assert data.attacking_players == data.home_players
assert data.defending_players == data.away_players

hp = data.home_players[3]
assert -19.582426479899993 == pytest.approx(hp.x1, abs=1e-5)
assert 24.3039460863 == pytest.approx(hp.y1, abs=1e-5)
assert -19.6022318885 == pytest.approx(hp.x2, abs=1e-5)
assert 24.1632567814 == pytest.approx(hp.y2, abs=1e-5)
assert hp.position.shape == (2, )
np.testing.assert_allclose(hp.position, np.asarray([hp.x1, hp.y1]), rtol=1e-4, atol=1e-4)
assert -19.582426479899993 == pytest.approx(hp.x1, abs=1e-5)
assert 24.3039460863 == pytest.approx(hp.y1, abs=1e-5)
assert -19.6022318885 == pytest.approx(hp.x2, abs=1e-5)
assert 24.1632567814 == pytest.approx(hp.y2, abs=1e-5)
assert hp.position.shape == (2,)
np.testing.assert_allclose(
hp.position, np.asarray([hp.x1, hp.y1]), rtol=1e-4, atol=1e-4
)
assert hp.is_gk == False
assert hp.next_position[0] - hp.position[0]

assert data.ball_carrier_idx == 1
assert len(data.home_players) == 6
assert len(data.away_players) == 4
Expand Down
7 changes: 4 additions & 3 deletions unravel/utils/objects/default_ball.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ def __post_init__(self):

def set_velocity(self):
delta_time = 1.0 / self.fps

if not (np.any(np.isnan(self.next_position3D)) or np.any(np.isnan(self.position3D))):

if not (
np.any(np.isnan(self.next_position3D)) or np.any(np.isnan(self.position3D))
):
vx = (self.next_position3D[0] - self.position3D[0]) / delta_time
vy = (self.next_position3D[1] - self.position3D[1]) / delta_time
vz = (self.next_position3D[2] - self.position3D[2]) / delta_time
Expand All @@ -43,7 +45,6 @@ def set_velocity(self):

self.speed = np.sqrt(vx**2 + vy**2 + vz**2)


def invert_position(self):
self.next_position = self.next_position * -1.0
self.position = self.position * -1.0
Expand Down
9 changes: 5 additions & 4 deletions unravel/utils/objects/default_player.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class DefaultPlayer(object):
) # velocity vector
speed: float = 0.0 # actual speed in m/s
is_gk: bool = False


def __post_init__(self):
self.next_position = np.asarray([self.x2, self.y2], dtype=float)
Expand All @@ -42,17 +41,19 @@ def invert_position(self):

def set_velocity(self):
dt = 1.0 / self.fps
if not (np.any(np.isnan(self.next_position)) or np.any(np.isnan(self.position))):
if not (
np.any(np.isnan(self.next_position)) or np.any(np.isnan(self.position))
):
vx = (self.next_position[0] - self.position[0]) / dt
vy = (self.next_position[1] - self.position[1]) / dt
else:
vx = 0
vy = 0

self.velocity = np.asarray([vx, vy], dtype=float)

# Re-check if any component of velocity is NaN and set to zero if it is
if np.any(np.isnan(self.velocity)):
self.velocity = np.asarray([0.0, 0.0], dtype=float)

self.speed = np.sqrt(vx**2 + vy**2)
3 changes: 1 addition & 2 deletions unravel/utils/objects/default_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ class DefaultTrackingModel:
ball_carrier_treshold: bool = 25.0
verbose: bool = False
pad_n_players: bool = None


def __post_init__(self):
self.home_players: List[DefaultPlayer] = list()
Expand Down Expand Up @@ -207,7 +206,7 @@ def set_objects_from_frame(
y1=coords.y,
y2=next_coords.y,
is_visible=True,
fps=self.fps
fps=self.fps,
)

if pid.team.ground == Ground.HOME:
Expand Down
2 changes: 1 addition & 1 deletion unravel/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def dummy_graph_ids(dataset: TrackingDataset) -> Dict:
"""
if not isinstance(dataset, TrackingDataset):
raise TypeError("dataset should be of type TrackingDataset (from kloppy)")

from uuid import uuid4

graph_ids = dict()
Expand Down

0 comments on commit 0a4a59c

Please sign in to comment.