Component type | Component Version | Implementation | Configuration | Component Interface | Description |
---|---|---|---|---|---|
model | gpt2 | GPT2LLM | GPT2LLMConfig | NNModel | GPT2 model for language modeling |
model | huggingface_pretrained_model | HuggingFacePretrainedModel | HuggingFacePretrainedModelConfig | NNModel | HuggingFace pretrained model for language modeling |
model | checkpointed | ModelFactory.get_checkpointed_model | CheckpointedModelConfig | nn.Module | Checkpointed Model instance |
model | fsdp_wrapped | ModelFactory.get_fsdp_wrapped_model | FSDPWrappedModelConfig | NNModel | Model that has been sharded via FSDP |
model | model_initialized | ModelFactory.get_weight_initalized_model | WeightInitializedModelConfig | nn.Module | Model with initialized weights |
model | coca | CoCa | CoCaConfig | NNModel | CoCa Model (Contrastive Captioners) |

Component type | Component Version | Implementation | Configuration | Component Interface | Description |
---|---|---|---|---|---|
model_initialization | composed | ComposedInitializationRoutines.get_composed_model_initializer | ComposedModelInitializationConfig | ModelInitializationIF | Component for initializing model weights in place |

Component type | Component Version | Implementation | Configuration | Component Interface | Description |
---|---|---|---|---|---|
loss | clm_cross_entropy_loss | CLMCrossEntropyLoss | CLMCrossEntropyLossConfig | Loss | Cross-entropy loss function |

Component type | Component Version | Implementation | Configuration | Component Interface | Description |
---|---|---|---|---|---|
optimizer | adam | OptimizerFactory.get_adam | AdamOptimizerConfig | Optimizer | Adam optimizer |
optimizer | adam_w | OptimizerFactory.get_adam_w | AdamWOptimizerConfig | Optimizer | AdamW optimizer |
optimizer | checkpointed | OptimizerFactory.get_checkpointed_optimizer | CheckpointedOptimizerConfig | Optimizer | Optimizer instantiated from checkpoint |

Component type | Component Version | Implementation | Configuration | Component Interface | Description |
---|---|---|---|---|---|
scheduler | dummy_lr | DummyLRScheduler | DummyLRSchedulerConfig | LRScheduler | Dummy LR scheduler that leaves the learning rate unchanged |
scheduler | step_lr | StepLR | StepLRSchedulerConfig | LRScheduler | Decays the learning rate of each parameter group by gamma every step_size steps |
scheduler | constant_lr | ConstantLR | ConstantLRSchedulerConfig | LRScheduler | Multiplies the learning rate of each parameter group by a small constant factor until the number of steps reaches a pre-defined milestone |
scheduler | onecycle_lr | OneCycleLR | OneCycleLRSchedulerConfig | LRScheduler | Sets the learning rate of each parameter group according to the 1cycle learning rate policy. |
scheduler | cosine_annealing_lr | CosineAnnealingLR | CosineAnnealingLRSchedulerConfig | LRScheduler | Sets the learning rate of each parameter group using a cosine annealing schedule |
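
The `step_lr` and `cosine_annealing_lr` descriptions above correspond to simple closed-form schedules. The following is a minimal pure-Python sketch of those formulas, not the schedulers' actual implementation. The function and parameter names (`step_lr`, `cosine_annealing_lr`, `t_max`, `eta_min`) are chosen here for illustration.

```python
import math


def step_lr(base_lr: float, step: int, step_size: int, gamma: float) -> float:
    """Multiply base_lr by gamma once for every completed block of step_size steps."""
    return base_lr * gamma ** (step // step_size)


def cosine_annealing_lr(base_lr: float, step: int, t_max: int, eta_min: float = 0.0) -> float:
    """Anneal from base_lr at step 0 down to eta_min at step t_max along a half cosine."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


if __name__ == "__main__":
    # Print both schedules side by side for a few steps.
    for step in (0, 10, 20, 50, 100):
        print(step, step_lr(0.1, step, step_size=10, gamma=0.5),
              cosine_annealing_lr(0.1, step, t_max=100))
```

With `step_size=10` and `gamma=0.5`, the step schedule halves the learning rate every 10 steps, while the cosine schedule decays smoothly and reaches `eta_min` exactly at `t_max`.
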

Component type | Component Version | Implementation | Configuration | Component Interface | Description |
---|---|---|---|---|---|
tokenizer | pretrained_hf_tokenizer | PreTrainedHFTokenizer | PreTrainedHFTokenizerConfig | TokenizerWrapper | Pretrained Huggingface tokenizer |
tokenizer | pretrained_sp_tokenizer | PreTrainedSPTokenizer | PreTrainedSPTokenizerConfig | TokenizerWrapper | Pretrained SentencePiece tokenizer |

Component type | Component Version | Implementation | Configuration | Component Interface | Description |
---|---|---|---|---|---|
dataset | mem_map_dataset | DatasetFactory.get_mem_map_dataset | MemMapDatasetConfig | Dataset | MemMap Dataset |
dataset | packed_mem_map_dataset_continuous | DatasetFactory.get_packed_mem_map_dataset_continuous | PackedMemMapDatasetContinuousConfig | Dataset | Packed Memory Mapped Dataset Continuous |
dataset | dummy_dataset | DatasetFactory.get_dummy_dataset | DummyDatasetConfig | Dataset | Dummy dataset creating random samples of specified shape |
dataset | combined | DatasetFactory.get_combined_dataset | CombinedDatasetConfig | Dataset | Dataset implementation combining multiple datasets into one. |

Component type | Component Version | Implementation | Configuration | Component Interface | Description |
---|---|---|---|---|---|
sampler | distributed_sampler | DistributedSampler | DistributedSamplerConfig | Sampler | Sampler that restricts data loading to a subset of the dataset for distributed training |
batch_sampler | default | BatchSampler | BatchSamplerConfig | Sampler | Wraps another sampler to yield a mini-batch of indices. |
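
To make the two sampler roles concrete, here is a simplified pure-Python sketch: first each rank is restricted to its own subset of indices, then those indices are grouped into mini-batches. The helper names `rank_indices` and `batch_indices` are hypothetical. Shuffling and the padding or dropping used to equalize shard sizes across ranks are deliberately left out.

```python
from typing import List


def rank_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Strided split: rank r gets indices r, r + num_replicas, r + 2 * num_replicas, ..."""
    return list(range(rank, dataset_len, num_replicas))


def batch_indices(indices: List[int], batch_size: int, drop_last: bool) -> List[List[int]]:
    """Group indices into consecutive mini-batches, optionally dropping an incomplete tail."""
    batches = [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


if __name__ == "__main__":
    shard = rank_indices(dataset_len=10, num_replicas=2, rank=0)
    print(shard)
    print(batch_indices(shard, batch_size=2, drop_last=True))
```

Because the ranks read disjoint index sets, no sample is processed twice within an epoch in this simplified setting.
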

Component type | Component Version | Implementation | Configuration | Component Interface | Description |
---|---|---|---|---|---|
collate_fn | gpt_2_llm_collator | GPT2LLMCollateFn | GPT2LLMCollateFnConfig | CollateFnIF | Data collator for the GPT2 model |
collate_fn | coca_collator | CoCaCollatorFn | CoCaCollateFnConfig | CollateFnIF | Data collator for the CoCa model |

Component type | Component Version | Implementation | Configuration | Component Interface | Description |
---|---|---|---|---|---|
data_loader | default | DataloaderFactory.get_dataloader | LLMDataLoaderConfig | DataLoader | LLM Data loader extending pytorch data loader functionality |

Component type | Component Version | Implementation | Configuration | Component Interface | Description |
---|---|---|---|---|---|
checkpoint_saving | default | CheckpointSaving | CheckpointSavingConfig | -- | Component for saving checkpoints based on a saving and execution strategy. |
checkpoint_saving_strategy | save_every_k_steps_checkpointing_strategy | SaveEveryKStepsCheckpointingStrategy | SaveEveryKStepsCheckpointingStrategyConfig | CheckpointSavingStrategyIF | Checkpointing strategy saving a checkpoint every k steps |
checkpoint_saving_strategy | save_k_most_recent_checkpoints_strategy | SaveKMostRecentCheckpointsStrategy | SaveKMostRecentCheckpointsStrategyConfig | CheckpointSavingStrategyIF | Checkpointing strategy saving only the last k checkpoints and deleting the previous ones |
checkpoint_saving_execution | fsdp | FSDPCheckpointSaving | FSDPCheckpointSavingConfig | CheckpointSavingExecutionABC | FSDPCheckpointSaving class for saving checkpoints of FSDP models and optimizers. |
checkpoint_loading | fsdp | FSDPCheckpointLoading | FSDPCheckpointLoadingConfig | CheckpointLoadingIF | Component for loading FSDP checkpoints |
checkpoint_loading | torch | TorchCheckpointLoading | TorchCheckpointLoadingConfig | CheckpointLoadingIF | Component for loading PyTorch checkpoints |
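
The two saving strategies above decide when to save and which older checkpoints to delete. The sketch below illustrates that decision logic only. The names `should_save_every_k_steps` and `KeepKMostRecent` are hypothetical and are not the library's classes, and no files are actually written or removed.

```python
from collections import deque
from typing import Deque, List


def should_save_every_k_steps(step: int, k: int) -> bool:
    """Save whenever the step counter is a positive multiple of k."""
    return step > 0 and step % k == 0


class KeepKMostRecent:
    """Tracks saved checkpoint steps and reports which ones fall out of the k-sized window."""

    def __init__(self, k: int) -> None:
        self.k = k
        self.saved: Deque[int] = deque()

    def on_checkpoint(self, step: int) -> List[int]:
        """Record a new checkpoint and return the steps whose checkpoints should be deleted."""
        self.saved.append(step)
        to_delete = []
        while len(self.saved) > self.k:
            to_delete.append(self.saved.popleft())
        return to_delete


if __name__ == "__main__":
    strategy = KeepKMostRecent(k=2)
    for step in range(1, 401):
        if should_save_every_k_steps(step, k=100):
            print(step, "delete:", strategy.on_checkpoint(step))
```

Separating the decision (strategy) from the actual write (execution) mirrors the split between the `checkpoint_saving_strategy` and `checkpoint_saving_execution` rows in the table.
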

Component type | Component Version | Implementation | Configuration | Component Interface | Description |
---|---|---|---|---|---|
progress_subscriber | dummy | ProgressSubscriberFactory.get_dummy_progress_subscriber | DummyProgressSubscriberConfig | MessageSubscriberIF | Dummy Progress subscriber not consuming any messages |
progress_subscriber | rich | ProgressSubscriberFactory.get_rich_progress_subscriber | RichProgressSubscriberConfig | MessageSubscriberIF | Subscriber for writing rich-formatted console output on training and evaluation progress |
results_subscriber | wandb | ProgressSubscriberFactory.get_wandb_result_subscriber | WandBEvaluationResultSubscriberConfig | MessageSubscriberIF | Subscriber for logging evaluation results to Weights and Biases |

Component type | Component Version | Implementation | Configuration | Component Interface | Description |
---|---|---|---|---|---|
layer_norm | rms_norm | RMSLayerNorm | RMSLayerNormConfig | nn.Module | RMS Layer norm |
layer_norm | layer_norm | nn.LayerNorm | LayerNormConfig | nn.Module | Layer norm |
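
The difference between the two normalization variants is easiest to see on a single vector. The sketch below uses the standard textbook formulas in plain Python. It assumes `RMSLayerNorm` follows the usual RMSNorm definition (no mean subtraction), which the table does not spell out, and it omits the bias term. The function names are illustrative only.

```python
import math
from typing import List


def rms_norm(x: List[float], weight: List[float], eps: float = 1e-6) -> List[float]:
    """Scale x by the reciprocal of its root mean square, then apply a per-element gain."""
    mean_square = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_square + eps)
    return [w * v * scale for v, w in zip(x, weight)]


def layer_norm(x: List[float], weight: List[float], eps: float = 1e-5) -> List[float]:
    """Subtract the mean, divide by the standard deviation, then apply a per-element gain."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    scale = 1.0 / math.sqrt(var + eps)
    return [w * (v - mean) * scale for v, w in zip(x, weight)]


if __name__ == "__main__":
    x = [1.0, 2.0, 3.0]
    ones = [1.0, 1.0, 1.0]
    print(rms_norm(x, ones))
    print(layer_norm(x, ones))
```

RMS normalization skips the mean computation, so its output is not zero-centered, whereas layer normalization output has zero mean before the gain is applied.
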

Component type | Component Version | Implementation | Configuration | Component Interface | Description |
---|---|---|---|---|---|
gradient_clipper | fsdp | FSDPGradientClipper | FSDPGradientClipperConfig | GradientClipperIF | FSDP Gradient Clipper |
gradient_clipper | fsdp_logging_only | FSDPLoggingOnlyGradientClipper | FSDPGradientClipperConfig | GradientClipperIF | Clipper that is responsible for logging the gradient norms without actually clipping the gradients |
gradient_clipper | dummy | DummyGradientClipper | DummyGradientClipperConfig | GradientClipperIF | Dummy clipper that does not apply any gradient clipping. |
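
The three clipper variants differ only in what they do after the gradient norm is computed. The following plain-Python sketch shows norm-based clipping next to a logging-only variant, assuming the common clip-by-global-norm rule. The function names are hypothetical, and the FSDP-specific reduction of the norm across ranks is not modeled.

```python
from typing import List, Tuple


def grad_norm(grads: List[float], norm_type: float = 2.0) -> float:
    """p-norm over all gradient entries, treated as one flat vector."""
    return sum(abs(g) ** norm_type for g in grads) ** (1.0 / norm_type)


def clip_by_norm(grads: List[float], max_norm: float, norm_type: float = 2.0) -> Tuple[List[float], float]:
    """Rescale gradients so their norm is at most max_norm; return them with the pre-clip norm."""
    total = grad_norm(grads, norm_type)
    if total > max_norm:
        scale = max_norm / total
        grads = [g * scale for g in grads]
    return grads, total


def log_only(grads: List[float], norm_type: float = 2.0) -> Tuple[List[float], float]:
    """Report the norm but leave the gradients untouched."""
    return grads, grad_norm(grads, norm_type)


if __name__ == "__main__":
    print(clip_by_norm([3.0, 4.0], max_norm=1.0))
    print(log_only([3.0, 4.0]))
```

A dummy clipper, by contrast, would return the gradients unchanged without computing any norm.
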

Component type | Component Version | Implementation | Configuration | Component Interface | Description |
---|---|---|---|---|---|
number_conversion | local_num_batches_from_num_samples | NumberConversion.get_local_num_batches_from_num_samples | LocalNumBatchesFromNumSamplesConfig | -- | Calculates the number of local batches for each rank, given the global number of samples and number of ranks. |
number_conversion | local_num_batches_from_num_tokens | NumberConversion.get_local_num_batches_from_num_tokens | LocalNumBatchesFromNumTokensConfig | -- | Calculates the number of local batches for each rank, given the global number of tokens and number of ranks. |
number_conversion | num_samples_from_num_tokens | NumberConversion.get_num_samples_from_num_tokens | NumSamplesFromNumTokensConfig | -- | Calculates the number of global samples, given the global number of tokens and sequence length |
number_conversion | num_steps_from_num_samples | NumberConversion.get_num_steps_from_num_samples | NumStepsFromNumSamplesConfig | -- | Calculates the number of steps given the global number of samples, local micro batch size and number of ranks. |
number_conversion | num_steps_from_num_tokens | NumberConversion.get_num_steps_from_num_tokens | NumStepsFromNumTokensConfig | -- | Calculates the number of steps given the global number of tokens, local micro batch size and number of ranks. |
number_conversion | num_tokens_from_num_steps | NumberConversion.get_num_tokens_from_num_steps | NumTokensFromNumStepsConfig | -- | Calculates the number of tokens from the number of steps, number of ranks, local micro batch size, global number of tokens, sequence length and gradient accumulation steps |
number_conversion | last_step_from_checkpoint_path | NumberConversion.get_num_seen_steps_from_checkpoint_path | NumberConversionFromCheckpointPathConfig | -- | Get the last step id from a model or checkpoint file path. |
number_conversion | global_num_target_tokens_from_checkpoint_path | NumberConversion.get_global_num_target_tokens_from_checkpoint_path | NumberConversionFromCheckpointPathConfig | -- | Get the number of target tokens from a model or checkpoint file path. |
number_conversion | num_tokens_from_packed_mem_map_dataset_continuous | NumberConversion.get_num_tokens_from_packed_mem_map_dataset_continuous | NumTokensFromPackedMemMapDatasetContinuousConfig | -- | Get the number of tokens stored in a packed mem map continuous dataset from the respective dataset file path. |
number_conversion | num_steps_from_raw_dataset_index | NumberConversion.get_num_steps_from_raw_dataset_index | NumStepsFromRawDatasetIndexConfig | -- | Get the number of steps partially from the raw index of a raw JSONL dataset. Requires the file path to index, number of ranks, local micro batch size and gradient accumulation steps. |
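
The conversions in this table are small pieces of integer arithmetic relating tokens, samples and steps. The sketch below shows one plausible reading of three of them. The use of floor division and the way gradient accumulation enters the step count are assumptions made for illustration, not the library's confirmed behavior, and the standalone function names are hypothetical.

```python
def num_samples_from_num_tokens(num_tokens: int, sequence_length: int) -> int:
    """Each sample is assumed to hold sequence_length tokens; a partial sample is dropped."""
    return num_tokens // sequence_length


def num_steps_from_num_samples(
    num_ranks: int,
    local_micro_batch_size: int,
    global_num_samples: int,
    gradient_accumulation_steps: int = 1,
) -> int:
    """One optimizer step consumes num_ranks * micro batch size * accumulation steps samples."""
    samples_per_step = num_ranks * local_micro_batch_size * gradient_accumulation_steps
    return global_num_samples // samples_per_step


def num_tokens_from_num_steps(
    num_steps: int,
    num_ranks: int,
    local_micro_batch_size: int,
    sequence_length: int,
    gradient_accumulation_steps: int = 1,
) -> int:
    """Inverse direction: tokens consumed after num_steps optimizer steps."""
    return (num_steps * num_ranks * local_micro_batch_size
            * sequence_length * gradient_accumulation_steps)


if __name__ == "__main__":
    samples = num_samples_from_num_tokens(1_048_576, sequence_length=1024)
    steps = num_steps_from_num_samples(4, 8, samples, gradient_accumulation_steps=2)
    print(samples, steps, num_tokens_from_num_steps(steps, 4, 8, 1024, 2))
```

In this example the token budget divides evenly, so converting tokens to steps and back returns the original token count. With budgets that do not divide evenly, floor division would drop the remainder.
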