From 0a4a59c5603dcd7b9377a6f1327f0c5809e31a1d Mon Sep 17 00:00:00 2001 From: "UnravelSports [JB]" Date: Mon, 22 Jul 2024 14:10:29 +0200 Subject: [PATCH] formatting --- tests/test_kloppy.py | 18 ++++++++++-------- unravel/utils/objects/default_ball.py | 7 ++++--- unravel/utils/objects/default_player.py | 9 +++++---- unravel/utils/objects/default_tracking.py | 3 +-- unravel/utils/utils.py | 2 +- 5 files changed, 21 insertions(+), 18 deletions(-) diff --git a/tests/test_kloppy.py b/tests/test_kloppy.py index 8fc0186..c34732b 100644 --- a/tests/test_kloppy.py +++ b/tests/test_kloppy.py @@ -117,17 +117,19 @@ def test_conversion(self, gnnc: GraphConverter): assert data.orientation == Orientation.STATIC_HOME_AWAY assert data.attacking_players == data.home_players assert data.defending_players == data.away_players - + hp = data.home_players[3] - assert -19.582426479899993 == pytest.approx(hp.x1, abs=1e-5) - assert 24.3039460863 == pytest.approx(hp.y1, abs=1e-5) - assert -19.6022318885 == pytest.approx(hp.x2, abs=1e-5) - assert 24.1632567814 == pytest.approx(hp.y2, abs=1e-5) - assert hp.position.shape == (2, ) - np.testing.assert_allclose(hp.position, np.asarray([hp.x1, hp.y1]), rtol=1e-4, atol=1e-4) + assert -19.582426479899993 == pytest.approx(hp.x1, abs=1e-5) + assert 24.3039460863 == pytest.approx(hp.y1, abs=1e-5) + assert -19.6022318885 == pytest.approx(hp.x2, abs=1e-5) + assert 24.1632567814 == pytest.approx(hp.y2, abs=1e-5) + assert hp.position.shape == (2,) + np.testing.assert_allclose( + hp.position, np.asarray([hp.x1, hp.y1]), rtol=1e-4, atol=1e-4 + ) assert hp.is_gk == False assert hp.next_position[0] - hp.position[0] - + assert data.ball_carrier_idx == 1 assert len(data.home_players) == 6 assert len(data.away_players) == 4 diff --git a/unravel/utils/objects/default_ball.py b/unravel/utils/objects/default_ball.py index 66d46ac..d5817f4 100644 --- a/unravel/utils/objects/default_ball.py +++ b/unravel/utils/objects/default_ball.py @@ -24,8 +24,10 @@ def __post_init__(self): def set_velocity(self): delta_time = 1.0 / self.fps - - if not (np.any(np.isnan(self.next_position3D)) or np.any(np.isnan(self.position3D))): + + if not ( + np.any(np.isnan(self.next_position3D)) or np.any(np.isnan(self.position3D)) + ): vx = (self.next_position3D[0] - self.position3D[0]) / delta_time vy = (self.next_position3D[1] - self.position3D[1]) / delta_time vz = (self.next_position3D[2] - self.position3D[2]) / delta_time @@ -43,7 +45,6 @@ def set_velocity(self): self.speed = np.sqrt(vx**2 + vy**2 + vz**2) - def invert_position(self): self.next_position = self.next_position * -1.0 self.position = self.position * -1.0 diff --git a/unravel/utils/objects/default_player.py b/unravel/utils/objects/default_player.py index ccb17eb..db18e2a 100644 --- a/unravel/utils/objects/default_player.py +++ b/unravel/utils/objects/default_player.py @@ -22,7 +22,6 @@ class DefaultPlayer(object): ) # velocity vector speed: float = 0.0 # actual speed in m/s is_gk: bool = False - def __post_init__(self): self.next_position = np.asarray([self.x2, self.y2], dtype=float) @@ -42,7 +41,9 @@ def invert_position(self): def set_velocity(self): dt = 1.0 / self.fps - if not (np.any(np.isnan(self.next_position)) or np.any(np.isnan(self.position))): + if not ( + np.any(np.isnan(self.next_position)) or np.any(np.isnan(self.position)) + ): vx = (self.next_position[0] - self.position[0]) / dt vy = (self.next_position[1] - self.position[1]) / dt else: @@ -50,9 +51,9 @@ def set_velocity(self): vy = 0 self.velocity = np.asarray([vx, vy], dtype=float) - + # Re-check if any component of velocity is NaN and set to zero if it is if np.any(np.isnan(self.velocity)): self.velocity = np.asarray([0.0, 0.0], dtype=float) - + self.speed = np.sqrt(vx**2 + vy**2) diff --git a/unravel/utils/objects/default_tracking.py b/unravel/utils/objects/default_tracking.py index ddab12d..a06dfad 100644 --- a/unravel/utils/objects/default_tracking.py +++ b/unravel/utils/objects/default_tracking.py @@ -35,7 +35,6 @@ class DefaultTrackingModel: ball_carrier_treshold: bool = 25.0 verbose: bool = False pad_n_players: bool = None - def __post_init__(self): self.home_players: List[DefaultPlayer] = list() @@ -207,7 +206,7 @@ def set_objects_from_frame( y1=coords.y, y2=next_coords.y, is_visible=True, - fps=self.fps + fps=self.fps, ) if pid.team.ground == Ground.HOME: diff --git a/unravel/utils/utils.py b/unravel/utils/utils.py index 67f8fe3..7ed098a 100644 --- a/unravel/utils/utils.py +++ b/unravel/utils/utils.py @@ -23,7 +23,7 @@ def dummy_graph_ids(dataset: TrackingDataset) -> Dict: """ if not isinstance(dataset, TrackingDataset): raise TypeError("dataset should be of type TrackingDataset (from kloppy)") - + from uuid import uuid4 graph_ids = dict()