diff --git a/olmo/config.py b/olmo/config.py index 17c463f04..5dc44cfcf 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -669,11 +669,29 @@ class ShardedCheckpointerType(StrEnum): class ActivationCheckpointingStrategy(StrEnum): whole_layer = "whole_layer" + """ + Checkpoint every transformer layer. + """ + one_in_two = "one_in_two" + """ + Checkpoint one in two transformer layers. + """ + one_in_three = "one_in_three" + """ + Checkpoint one in three transformer layers. + """ + one_in_four = "one_in_four" + """ + Checkpoint one in four transformer layers. + """ + fine_grained = "fine_grained" - + """ + Focus checkpointing on where it is cheap to recompute and saves most memory. + """ @dataclass class TrainConfig(BaseConfig):