more doc strings
Tetracarbonylnickel committed Apr 10, 2024
1 parent 1fe6f2d commit 2fc94f4
Showing 1 changed file with 63 additions and 29 deletions.
92 changes: 63 additions & 29 deletions apax/config/train_config.py
@@ -139,18 +139,30 @@ class ModelConfig(BaseModel, extra="forbid"):
Parameters
----------
-n_basis :
-    | Number of uncontracted gaussian basis functions.
-n_radial :
-    | Number of contracted basis functions.
-r_min :
-    | Position of the first uncontracted basis function's mean.
-r_max :
-    | Cutoff radius of the descriptor.
-nn :
-    | Number of hidden layers and units in those layers.
-b_init :
-    | Initialization scheme for the neural network biases. Either `normal` or `zeros`.
+n_basis : PositiveInt, default = 7
+    Number of uncontracted Gaussian basis functions.
+n_radial : PositiveInt, default = 5
+    Number of contracted basis functions.
+r_min : NonNegativeFloat, default = 0.5
+    Position of the first uncontracted basis function's mean.
+r_max : PositiveFloat, default = 6.0
+    Cutoff radius of the descriptor.
+nn : List[PositiveInt], default = [512, 512]
+    Number of hidden layers and units in those layers.
+b_init : Literal["normal", "zeros"], default = "normal"
+    Initialization scheme for the neural network biases.
+emb_init : Optional[str], default = "uniform"
+    Initialization scheme for embedding layer weights.
+use_zbl : bool, default = False
+    Whether to include the ZBL correction.
+calc_stress : bool, default = False
+    Whether to calculate stress during model evaluation.
+descriptor_dtype : Literal["fp32", "fp64"], default = "fp64"
+    Data type for descriptor calculations.
+readout_dtype : Literal["fp32", "fp64"], default = "fp32"
+    Data type for readout calculations.
+scale_shift_dtype : Literal["fp32", "fp64"], default = "fp32"
+    Data type for scale and shift parameters.
"""

n_basis: PositiveInt = 7
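For orientation (not part of the commit itself), here is a minimal sketch of how the ModelConfig documented above might be instantiated. The import path mirrors the file shown in this diff; the non-default values are illustrative:

```python
from apax.config.train_config import ModelConfig

# Fields not passed here keep the defaults documented in the docstring above.
model_cfg = ModelConfig(
    n_basis=7,       # uncontracted Gaussian basis functions
    n_radial=5,      # contracted basis functions
    r_min=0.5,       # mean of the first uncontracted basis function
    r_max=6.0,       # descriptor cutoff radius
    nn=[256, 256],   # two hidden layers, 256 units each (default is [512, 512])
    use_zbl=True,    # enable the ZBL correction
)
```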
@@ -188,17 +200,27 @@ class OptimizerConfig(BaseModel, frozen=True, extra="forbid"):
"""
Configuration of the optimizer.
Learning rates of 0 will freeze the respective parameters.
Parameters
----------
-opt_name: Name of the optimizer. Can be any `optax` optimizer.
-emb_lr: Learning rate of the elemental embedding contraction coefficients.
-nn_lr: Learning rate of the neural network parameters.
-scale_lr: Learning rate of the elemental output scaling factors.
-shift_lr: Learning rate of the elemental output shifts.
-transition_begin: Number of training steps (not epochs) before the start of the
-    linear learning rate schedule.
-opt_kwargs: Optimizer keyword arguments. Passed to the `optax` optimizer.
+opt_name : str, default = "adam"
+    Name of the optimizer. Can be any `optax` optimizer.
+emb_lr : NonNegativeFloat, default = 0.02
+    Learning rate of the elemental embedding contraction coefficients.
+nn_lr : NonNegativeFloat, default = 0.03
+    Learning rate of the neural network parameters.
+scale_lr : NonNegativeFloat, default = 0.001
+    Learning rate of the elemental output scaling factors.
+shift_lr : NonNegativeFloat, default = 0.05
+    Learning rate of the elemental output shifts.
+zbl_lr : NonNegativeFloat, default = 0.001
+    Learning rate of the ZBL correction parameters.
+transition_begin : int, default = 0
+    Number of training steps (not epochs) before the start of the linear learning rate schedule.
+opt_kwargs : dict, default = {}
+    Optimizer keyword arguments. Passed to the `optax` optimizer.
+sam_rho : NonNegativeFloat, default = 0.0
+    Rho parameter for Sharpness-Aware Minimization.
"""

opt_name: str = "adam"
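Likewise, a hedged sketch of OptimizerConfig usage based only on the fields documented above; the values beyond the defaults are made up for illustration:

```python
from apax.config.train_config import OptimizerConfig

# Per the docstring above, a learning rate of 0 freezes the respective parameters.
opt_cfg = OptimizerConfig(
    opt_name="adam",                      # any optax optimizer name
    emb_lr=0.0,                           # freeze the elemental embedding
    nn_lr=0.03,
    transition_begin=1000,                # steps before the linear LR schedule starts
    opt_kwargs={"b1": 0.9, "b2": 0.999},  # forwarded to the optax optimizer
)
```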
@@ -218,10 +240,11 @@ class MetricsConfig(BaseModel, extra="forbid"):
Parameters
----------
-name: Keyword of the quantity e.g `energy`.
-reductions: List of reductions performed on the difference between
-    target and predictions. Can be mae, mse, rmse for energies and forces.
-    For forces it is also possible to use `angle`.
+name : str
+    Keyword of the quantity, e.g., 'energy'.
+reductions : List[str]
+    List of reductions performed on the difference between target and predictions.
+    Can be 'mae', 'mse', 'rmse' for energies and forces. For forces, 'angle' can also be used.
"""

name: str
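A short example of the MetricsConfig fields documented above; the quantity and reductions shown are illustrative choices, not taken from the commit:

```python
from apax.config.train_config import MetricsConfig

# Track MAE and RMSE of the force errors, plus the force-only angle reduction.
forces_metrics = MetricsConfig(
    name="forces",
    reductions=["mae", "rmse", "angle"],
)
```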
@@ -234,10 +257,21 @@ class LossConfig(BaseModel, extra="forbid"):
Parameters
----------
-name: Keyword of the quantity e.g `energy`.
-loss_type: Weighting scheme for atomic contributions. See the MLIP package
-    for reference 10.1088/2632-2153/abc9fe for details
-weight: Weighting factor in the overall loss function.
+name : str
+    Keyword of the quantity, e.g., 'energy'.
+loss_type : str, optional
+    Weighting scheme for atomic contributions. See the MLIP package
+    reference (DOI: 10.1088/2632-2153/abc9fe) for details, by default "mse".
+weight : NonNegativeFloat, optional
+    Weighting factor in the overall loss function, by default 1.0.
+atoms_exponent : NonNegativeFloat, optional
+    Exponent applied when weighting atomic contributions, by default 1.
+parameters : dict, optional
+    Additional parameters for configuring the loss function, by default {}.
+
+Notes
+-----
+This class specifies the configuration of the loss functions used during training.
"""

name: str
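Finally, an illustrative LossConfig entry assuming the defaults documented above; the quantity and weight shown are arbitrary examples:

```python
from apax.config.train_config import LossConfig

# One loss term; several such entries typically make up the overall loss.
force_loss = LossConfig(
    name="forces",
    loss_type="mse",    # atomic-contribution weighting scheme (see DOI above)
    weight=4.0,         # relative weight in the overall loss function
    atoms_exponent=1,
)
```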
