Skip to content

Commit

Permalink
List -> list
Browse files Browse the repository at this point in the history
  • Loading branch information
beardyFace committed Oct 9, 2024
1 parent 60a8084 commit 8e0973d
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions cares_reinforcement_learning/util/configurations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Optional

from pydantic import BaseModel, Field

Expand All @@ -10,7 +10,7 @@

# pylint: disable-next=unused-import

# NOTE: If a parameter is a list then don't wrap with Optional leave as implicit optional - List[type] = default
# NOTE: If a parameter is a list, don't wrap it with Optional; leave it as an implicit optional - list[type] = default


class SubscriptableClass(BaseModel):
Expand All @@ -27,14 +27,14 @@ class TrainingConfig(SubscriptableClass):
Configuration class for training.
Attributes:
seeds (List[int]): List of random seeds for reproducibility. Default is [10].
seeds (list[int]): List of random seeds for reproducibility. Default is [10].
plot_frequency (Optional[int]): Frequency at which to plot training progress. Default is 100.
checkpoint_frequency (Optional[int]): Frequency at which to save model checkpoints. Default is 100.
number_steps_per_evaluation (Optional[int]): Number of steps per evaluation. Default is 10000.
number_eval_episodes (Optional[int]): Number of episodes to evaluate during training. Default is 10.
"""

seeds: List[int] = [10]
seeds: list[int] = [10]
plot_frequency: Optional[int] = 100
checkpoint_frequency: Optional[int] = 100
number_steps_per_evaluation: Optional[int] = 10000
Expand Down Expand Up @@ -63,7 +63,7 @@ class AlgorithmConfig(SubscriptableClass):
image_observation (Optional[int]): Whether the observation is an image.
hidden_size (List[int]): List of hidden layer sizes - e.g. [256, 256].
hidden_size (list[int]): List of hidden layer sizes - e.g. [256, 256].
"""

algorithm: str = Field(description="Name of the algorithm to be used")
Expand All @@ -81,7 +81,7 @@ class AlgorithmConfig(SubscriptableClass):

image_observation: Optional[int] = 0

hidden_size: List[int] = None
hidden_size: list[int] | None = None


class DQNConfig(AlgorithmConfig):
Expand Down Expand Up @@ -181,7 +181,7 @@ class SACConfig(AlgorithmConfig):
tau: Optional[float] = 0.005
reward_scale: Optional[float] = 1.0

log_std_bounds: List[float] = [-20, 2]
log_std_bounds: list[float] = [-20, 2]


class SACAEConfig(AlgorithmConfig):
Expand All @@ -198,7 +198,7 @@ class SACAEConfig(AlgorithmConfig):
tau: Optional[float] = 0.005
reward_scale: Optional[float] = 1.0

log_std_bounds: List[float] = [-20, 2]
log_std_bounds: list[float] = [-20, 2]

encoder_tau: Optional[float] = 0.05
decoder_update_freq: Optional[int] = 1
Expand Down Expand Up @@ -248,7 +248,7 @@ class DynaSACConfig(AlgorithmConfig):
gamma: Optional[float] = 0.99
tau: Optional[float] = 0.005

log_std_bounds: List[float] = [-20, 2]
log_std_bounds: list[float] = [-20, 2]

horizon: Optional[int] = 3
num_samples: Optional[int] = 10
Expand Down Expand Up @@ -318,7 +318,7 @@ class TQCConfig(AlgorithmConfig):
num_quantiles: Optional[int] = 25
num_nets: Optional[int] = 5

log_std_bounds: List[float] = [-20, 2]
log_std_bounds: list[float] = [-20, 2]


class CTD4Config(AlgorithmConfig):
Expand Down Expand Up @@ -362,7 +362,7 @@ class PERSACConfig(AlgorithmConfig):
per_alpha: Optional[float] = 0.6
min_priority: Optional[float] = 1e-6

log_std_bounds: List[float] = [-20, 2]
log_std_bounds: list[float] = [-20, 2]


class LAPTD3Config(AlgorithmConfig):
Expand Down Expand Up @@ -391,7 +391,7 @@ class LAPSACConfig(AlgorithmConfig):
reward_scale: Optional[float] = 1.0
min_priority: Optional[float] = 1.0

log_std_bounds: List[float] = [-20, 2]
log_std_bounds: list[float] = [-20, 2]


class PALTD3Config(AlgorithmConfig):
Expand Down Expand Up @@ -436,7 +436,7 @@ class LA3PSACConfig(AlgorithmConfig):
min_priority: Optional[float] = 1.0
prioritized_fraction: Optional[float] = 0.5

log_std_bounds: List[float] = [-20, 2]
log_std_bounds: list[float] = [-20, 2]


class MAPERTD3Config(AlgorithmConfig):
Expand Down Expand Up @@ -477,8 +477,8 @@ class MAPERSACConfig(AlgorithmConfig):
G: Optional[int] = 64
number_steps_per_train_policy: Optional[int] = 64

hidden_size: List[int] = [400, 300]
log_std_bounds: List[float] = [-20, 2]
hidden_size: list[int] = [400, 300]
log_std_bounds: list[float] = [-20, 2]


class RDTD3Config(AlgorithmConfig):
Expand Down Expand Up @@ -506,4 +506,4 @@ class RDSACConfig(AlgorithmConfig):
per_alpha: Optional[float] = 0.7
min_priority: Optional[float] = 1.0

log_std_bounds: List[float] = [-20, 2]
log_std_bounds: list[float] = [-20, 2]

0 comments on commit 8e0973d

Please sign in to comment.