diff --git a/cares_reinforcement_learning/util/configurations.py b/cares_reinforcement_learning/util/configurations.py index b2463c6..a58a97f 100644 --- a/cares_reinforcement_learning/util/configurations.py +++ b/cares_reinforcement_learning/util/configurations.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from pydantic import BaseModel, Field @@ -10,7 +10,7 @@ # pylint disbale-next=unused-import -# NOTE: If a parameter is a list then don't wrap with Optional leave as implicit optional - List[type] = default +# NOTE: If a parameter is a list then don't wrap with Optional leave as implicit optional - list[type] = default class SubscriptableClass(BaseModel): @@ -27,14 +27,14 @@ class TrainingConfig(SubscriptableClass): Configuration class for training. Attributes: - seeds (List[int]): List of random seeds for reproducibility. Default is [10]. + seeds (list[int]): list of random seeds for reproducibility. Default is [10]. plot_frequency (Optional[int]): Frequency at which to plot training progress. Default is 100. checkpoint_frequency (Optional[int]): Frequency at which to save model checkpoints. Default is 100. number_steps_per_evaluation (Optional[int]): Number of steps per evaluation. Default is 10000. number_eval_episodes (Optional[int]): Number of episodes to evaluate during training. Default is 10. """ - seeds: List[int] = [10] + seeds: list[int] = [10] plot_frequency: Optional[int] = 100 checkpoint_frequency: Optional[int] = 100 number_steps_per_evaluation: Optional[int] = 10000 @@ -63,7 +63,7 @@ class AlgorithmConfig(SubscriptableClass): image_observation (Optional[int]): Whether the observation is an image. - hidden_size (List[int]): List of hidden layer sizes - e.g. [256, 256]. + hidden_size (list[int]): list of hidden layer sizes - e.g. [256, 256]. """ algorithm: str = Field(description="Name of the algorithm to be used") @@ -81,7 +81,7 @@ class AlgorithmConfig(SubscriptableClass): image_observation: Optional[int] = 0 - hidden_size: List[int] = None + hidden_size: list[int] | None = None class DQNConfig(AlgorithmConfig): @@ -181,7 +181,7 @@ class SACConfig(AlgorithmConfig): tau: Optional[float] = 0.005 reward_scale: Optional[float] = 1.0 - log_std_bounds: List[float] = [-20, 2] + log_std_bounds: list[float] = [-20, 2] class SACAEConfig(AlgorithmConfig): @@ -198,7 +198,7 @@ class SACAEConfig(AlgorithmConfig): tau: Optional[float] = 0.005 reward_scale: Optional[float] = 1.0 - log_std_bounds: List[float] = [-20, 2] + log_std_bounds: list[float] = [-20, 2] encoder_tau: Optional[float] = 0.05 decoder_update_freq: Optional[int] = 1 @@ -248,7 +248,7 @@ class DynaSACConfig(AlgorithmConfig): gamma: Optional[float] = 0.99 tau: Optional[float] = 0.005 - log_std_bounds: List[float] = [-20, 2] + log_std_bounds: list[float] = [-20, 2] horizon: Optional[int] = 3 num_samples: Optional[int] = 10 @@ -318,7 +318,7 @@ class TQCConfig(AlgorithmConfig): num_quantiles: Optional[int] = 25 num_nets: Optional[int] = 5 - log_std_bounds: List[float] = [-20, 2] + log_std_bounds: list[float] = [-20, 2] class CTD4Config(AlgorithmConfig): @@ -362,7 +362,7 @@ class PERSACConfig(AlgorithmConfig): per_alpha: Optional[float] = 0.6 min_priority: Optional[float] = 1e-6 - log_std_bounds: List[float] = [-20, 2] + log_std_bounds: list[float] = [-20, 2] class LAPTD3Config(AlgorithmConfig): @@ -391,7 +391,7 @@ class LAPSACConfig(AlgorithmConfig): reward_scale: Optional[float] = 1.0 min_priority: Optional[float] = 1.0 - log_std_bounds: List[float] = [-20, 2] + log_std_bounds: list[float] = [-20, 2] class PALTD3Config(AlgorithmConfig): @@ -436,7 +436,7 @@ class LA3PSACConfig(AlgorithmConfig): min_priority: Optional[float] = 1.0 prioritized_fraction: Optional[float] = 0.5 - log_std_bounds: List[float] = [-20, 2] + log_std_bounds: list[float] = [-20, 2] class MAPERTD3Config(AlgorithmConfig): @@ -477,8 +477,8 @@ class MAPERSACConfig(AlgorithmConfig): G: Optional[int] = 64 number_steps_per_train_policy: Optional[int] = 64 - hidden_size: List[int] = [400, 300] - log_std_bounds: List[float] = [-20, 2] + hidden_size: list[int] = [400, 300] + log_std_bounds: list[float] = [-20, 2] class RDTD3Config(AlgorithmConfig): @@ -506,4 +506,4 @@ class RDSACConfig(AlgorithmConfig): per_alpha: Optional[float] = 0.7 min_priority: Optional[float] = 1.0 - log_std_bounds: List[float] = [-20, 2] + log_std_bounds: list[float] = [-20, 2]