From 222d287b7fdc68ff5891f8d8124cc81603c16326 Mon Sep 17 00:00:00 2001 From: YashviMehta03 <144887977+YashviMehta03@users.noreply.github.com> Date: Fri, 24 Jan 2025 14:58:48 +0530 Subject: [PATCH] [DOC] added type hints to 'classification->convolution_based' module (#2494) * added type hints to classification->convolution_based * Automatic `pre-commit` fixes --- .../classification/convolution_based/_arsenal.py | 16 ++++++++-------- aeon/classification/convolution_based/_hydra.py | 7 ++++++- .../convolution_based/_minirocket.py | 6 +++--- .../convolution_based/_mr_hydra.py | 7 ++++++- .../convolution_based/_multirocket.py | 8 ++++---- aeon/classification/convolution_based/_rocket.py | 4 ++-- 6 files changed, 29 insertions(+), 19 deletions(-) diff --git a/aeon/classification/convolution_based/_arsenal.py b/aeon/classification/convolution_based/_arsenal.py index 4a232e7a1e..8a3d42ee75 100644 --- a/aeon/classification/convolution_based/_arsenal.py +++ b/aeon/classification/convolution_based/_arsenal.py @@ -130,15 +130,15 @@ class Arsenal(BaseClassifier): def __init__( self, - n_kernels=2000, - n_estimators=25, - rocket_transform="rocket", - max_dilations_per_kernel=32, - n_features_per_kernel=4, - time_limit_in_minutes=0.0, - contract_max_n_estimators=100, + n_kernels: int = 2000, + n_estimators: int = 25, + rocket_transform: str = "rocket", + max_dilations_per_kernel: int = 32, + n_features_per_kernel: int = 4, + time_limit_in_minutes: float = 0.0, + contract_max_n_estimators: int = 100, class_weight=None, - n_jobs=1, + n_jobs: int = 1, random_state=None, ): self.n_kernels = n_kernels diff --git a/aeon/classification/convolution_based/_hydra.py b/aeon/classification/convolution_based/_hydra.py index 5a3c2d443d..2c890d4a69 100644 --- a/aeon/classification/convolution_based/_hydra.py +++ b/aeon/classification/convolution_based/_hydra.py @@ -96,7 +96,12 @@ class HydraClassifier(BaseClassifier): } def __init__( - self, n_kernels=8, n_groups=64, class_weight=None, n_jobs=1, random_state=None + self, + n_kernels: int = 8, + n_groups: int = 64, + class_weight=None, + n_jobs: int = 1, + random_state=None, ): self.n_kernels = n_kernels self.n_groups = n_groups diff --git a/aeon/classification/convolution_based/_minirocket.py b/aeon/classification/convolution_based/_minirocket.py index 6025121354..629c447c1c 100644 --- a/aeon/classification/convolution_based/_minirocket.py +++ b/aeon/classification/convolution_based/_minirocket.py @@ -89,11 +89,11 @@ class MiniRocketClassifier(BaseClassifier): def __init__( self, - n_kernels=10000, - max_dilations_per_kernel=32, + n_kernels: int = 10000, + max_dilations_per_kernel: int = 32, estimator=None, class_weight=None, - n_jobs=1, + n_jobs: int = 1, random_state=None, ): self.n_kernels = n_kernels diff --git a/aeon/classification/convolution_based/_mr_hydra.py b/aeon/classification/convolution_based/_mr_hydra.py index 00c89599be..04384055ec 100644 --- a/aeon/classification/convolution_based/_mr_hydra.py +++ b/aeon/classification/convolution_based/_mr_hydra.py @@ -89,7 +89,12 @@ class MultiRocketHydraClassifier(BaseClassifier): } def __init__( - self, n_kernels=8, n_groups=64, class_weight=None, n_jobs=1, random_state=None + self, + n_kernels: int = 8, + n_groups: int = 64, + class_weight=None, + n_jobs: int = 1, + random_state=None, ): self.n_kernels = n_kernels self.n_groups = n_groups diff --git a/aeon/classification/convolution_based/_multirocket.py b/aeon/classification/convolution_based/_multirocket.py index 0da780c76d..791cf8fe55 100644 --- a/aeon/classification/convolution_based/_multirocket.py +++ b/aeon/classification/convolution_based/_multirocket.py @@ -90,12 +90,12 @@ class MultiRocketClassifier(BaseClassifier): def __init__( self, - n_kernels=10000, - max_dilations_per_kernel=32, - n_features_per_kernel=4, + n_kernels: int = 10000, + max_dilations_per_kernel: int = 32, + n_features_per_kernel: int = 4, estimator=None, class_weight=None, - n_jobs=1, + n_jobs: int = 1, random_state=None, ): self.n_kernels = n_kernels diff --git a/aeon/classification/convolution_based/_rocket.py b/aeon/classification/convolution_based/_rocket.py index 152397a940..c6cf304c0c 100644 --- a/aeon/classification/convolution_based/_rocket.py +++ b/aeon/classification/convolution_based/_rocket.py @@ -93,10 +93,10 @@ class RocketClassifier(BaseClassifier): def __init__( self, - n_kernels=10000, + n_kernels: int = 10000, estimator=None, class_weight=None, - n_jobs=1, + n_jobs: int = 1, random_state=None, ): self.n_kernels = n_kernels