update cluster list #135

Closed. Wants to merge 100 commits.

Commits
695b61c
Add `distributed.checkpoint.load_state_dict()` function
epwalsh Dec 4, 2024
6135725
Move MoE callback logic to a new `MoEHandler` class
epwalsh Dec 4, 2024
cdf13c1
Add the Float8Handler to the Transformer model
epwalsh Dec 4, 2024
ce50477
Add `namespace` arg to `Trainer.record_metric()`
epwalsh Dec 4, 2024
3cdc667
Add `TrainModule` abstraction
epwalsh Dec 4, 2024
1a68888
fix
epwalsh Dec 4, 2024
24a1382
fix
epwalsh Dec 4, 2024
1473564
fix
epwalsh Dec 4, 2024
638b372
clean up
epwalsh Dec 4, 2024
cafd0d1
fix evaluator callback
epwalsh Dec 5, 2024
f8bdf84
train config
epwalsh Dec 5, 2024
8d33250
prepare for pipeline parallel
epwalsh Dec 5, 2024
d04dd66
more progress towards pipeline parallel
epwalsh Dec 5, 2024
a40248c
fix
epwalsh Dec 5, 2024
cd634a6
Merge branch 'main' into epwalsh/train-module
epwalsh Dec 9, 2024
e0d66c4
Merge branch 'epwalsh/train-module' into epwalsh/train-module2
epwalsh Dec 9, 2024
64a32fc
update docker image tags
epwalsh Dec 9, 2024
cdb1fa1
Merge branch 'main' into epwalsh/train-module
epwalsh Dec 9, 2024
2254dc3
fix merge conflicts
epwalsh Dec 9, 2024
9eee883
clean up
epwalsh Dec 9, 2024
f153cf9
Merge branch 'main' into epwalsh/train-module
epwalsh Dec 10, 2024
ca915e3
fix merge conflicts
epwalsh Dec 10, 2024
ecb7b67
Move more model configs stuff to TransformerTrainModule
epwalsh Dec 10, 2024
d568fc2
log callback order
epwalsh Dec 11, 2024
4dbdf80
increase priority of checkpointer callback
epwalsh Dec 11, 2024
ea3241c
also sort by order added
epwalsh Dec 11, 2024
87c81ce
load strict
epwalsh Dec 11, 2024
14ee57b
fix
epwalsh Dec 11, 2024
dbe6105
do not allow compiling fused loss
epwalsh Dec 11, 2024
bc9f2df
Merge branch 'main' into epwalsh/train-module
epwalsh Dec 11, 2024
acc8b4a
Merge branch 'epwalsh/train-module' into epwalsh/train-module2
epwalsh Dec 11, 2024
14c59a1
Finish implemented pipeline parallelism, I think
epwalsh Dec 11, 2024
6f77813
add another barrier
epwalsh Dec 11, 2024
a3f60fe
update some callback priorities
epwalsh Dec 11, 2024
e279fd8
update docs
epwalsh Dec 11, 2024
4fca439
fix building optimizer for PP
epwalsh Dec 12, 2024
f87158e
logging improvements
epwalsh Dec 12, 2024
5550a9b
log memory usage before training
epwalsh Dec 12, 2024
627796f
log data parallel world size
epwalsh Dec 12, 2024
442c084
reformat logging
epwalsh Dec 12, 2024
e4fb54c
clean up
epwalsh Dec 12, 2024
29a0dc9
broadcast loss
epwalsh Dec 13, 2024
1dbcb32
fix
epwalsh Dec 13, 2024
498ce26
clean up
epwalsh Dec 13, 2024
115a835
log more info about param group
epwalsh Dec 13, 2024
5b964c4
improve logging
epwalsh Dec 13, 2024
8ba1009
more logging
epwalsh Dec 13, 2024
cd0f4d8
more logging
epwalsh Dec 13, 2024
4bc329f
fix
epwalsh Dec 13, 2024
f432c5b
fix?
epwalsh Dec 13, 2024
d5345eb
fix
epwalsh Dec 13, 2024
795106a
flatten
epwalsh Dec 13, 2024
4e652a7
allow different opts for save/load
epwalsh Dec 13, 2024
9577a0c
clean up
epwalsh Dec 13, 2024
995c83e
add link to upgrade guide
epwalsh Dec 13, 2024
e0149ab
fix eval batch size guess
epwalsh Dec 13, 2024
1fd4262
use seperate schedule for eval
epwalsh Dec 13, 2024
b4931e2
fix
epwalsh Dec 13, 2024
6da7f29
fix
epwalsh Dec 13, 2024
7df1a22
fix
epwalsh Dec 13, 2024
f6eb6b6
fix?
epwalsh Dec 13, 2024
6fad112
fix?
epwalsh Dec 13, 2024
4cfed66
fix
epwalsh Dec 13, 2024
1fd78b1
reduce logs from google clients
epwalsh Dec 13, 2024
b942df1
fix
epwalsh Dec 14, 2024
c067af0
avoid broadcast
epwalsh Dec 14, 2024
9d77350
fix merge conflict
epwalsh Dec 16, 2024
fb3019b
Merge branch 'epwalsh/train-module' into epwalsh/train-module2
epwalsh Dec 16, 2024
bba1fe4
train module specifies eval batch size
epwalsh Dec 18, 2024
0bd6b50
Fix eval seq length when PP enabled
epwalsh Dec 18, 2024
41a055e
add a default downstream evaluator
epwalsh Dec 18, 2024
3dd3f7e
drop last incomplete batch
epwalsh Dec 18, 2024
c2d86f8
install debug branch
epwalsh Dec 18, 2024
dadea11
fix?
epwalsh Dec 18, 2024
42d0124
ha fix
epwalsh Dec 18, 2024
8153e81
asserts
epwalsh Dec 18, 2024
fb89bf2
print
epwalsh Dec 18, 2024
eb5f4fe
fix?
epwalsh Dec 18, 2024
616687a
fix
epwalsh Dec 19, 2024
9752ec7
fix
epwalsh Dec 19, 2024
77bfab8
clean up
epwalsh Dec 19, 2024
7435cb8
Merge branch 'main' into epwalsh/train-module
epwalsh Dec 19, 2024
54fc9be
Merge branch 'epwalsh/train-module' into epwalsh/train-module2
epwalsh Dec 19, 2024
c44a87a
fix merge conflict
epwalsh Dec 19, 2024
bf95b88
Merge branch 'epwalsh/train-module' into epwalsh/train-module2
epwalsh Dec 19, 2024
d94e6ee
Merge branch 'main' into epwalsh/train-module
epwalsh Dec 19, 2024
88bd158
Merge branch 'epwalsh/train-module' into epwalsh/train-module2
epwalsh Dec 19, 2024
e67727d
Merge branch 'main' into epwalsh/train-module
epwalsh Dec 20, 2024
da773ca
fix merge conflict
epwalsh Dec 20, 2024
ce83447
fix typo
epwalsh Dec 21, 2024
009119b
Merge branch 'main' into epwalsh/train-module
epwalsh Jan 8, 2025
8882a3c
Merge branch 'epwalsh/train-module' into epwalsh/train-module2
epwalsh Jan 8, 2025
533f016
clean up, fix total grad norm reporting
epwalsh Jan 8, 2025
a51907e
trigger workflows for v2 branch
epwalsh Jan 8, 2025
803d306
Merge branch 'main' into v2
epwalsh Jan 9, 2025
3a40345
Fix table formatting
epwalsh Jan 9, 2025
cc84cfb
clean up changelog
epwalsh Jan 9, 2025
9d0c622
support load key mapping, auto determine flat optim
epwalsh Jan 9, 2025
933300a
Merge branch 'main' into v2
epwalsh Jan 9, 2025
8ef02d7
update cluster list
epwalsh Jan 10, 2025
6 changes: 5 additions & 1 deletion .github/workflows/main.yml
@@ -8,9 +8,11 @@ on:
pull_request:
branches:
- main
- v2
push:
branches:
- main
- v2
tags:
- 'v*.*.*'

@@ -185,9 +187,11 @@ jobs:
# H100 clusters
- ai2/jupiter-cirrascale-2
- ai2/augusta-google-1
- ai2/allennlp-elara-cirrascale
- ai2/ganymede-cirrascale
- ai2/ceres-cirrascale
# A100 clusters
- ai2/saturn-cirrascale
- ai2/allennlp-cirrascale
# - ai2/allennlp-elanding-a100-40g
envVars:
- name: CUBLAS_WORKSPACE_CONFIG
1 change: 1 addition & 0 deletions .github/workflows/pr_checks.yml
@@ -8,6 +8,7 @@ on:
pull_request:
branches:
- main
- v2
paths:
- 'src/**'

20 changes: 20 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,26 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## v2

This major release introduces a few breaking changes. As such, we've provided an upgrade guide here: [OLMo-core upgrade guide](https://docs.google.com/document/d/1LvANhNzA-MdtiD2pLniLTqB9wxSSuqY435WuJIADeFM/edit?usp=sharing).

### Added

- Added `TrainModule` abstraction with `TransformerTrainModule` implementation, which encapsulates both a model and optimizer.
- Added `namespace` argument to `Trainer.record_metric()`.

### Changed

- The `Trainer` now takes a `TrainModule` instead of a model and optimizer, and several configuration options have been moved to `TransformerTrainModule`, including `rank_microbatch_size`, `fused_loss`, `compile_loss`, `z_loss_multiplier`, and `autocast_precision`.
- Several `TransformerModelConfig` options have been moved to `TransformerTrainModule` / `TransformerTrainModuleConfig`, including `dp_config`, `tp_config`, `float8_config`, and `compile`.

### Removed

- Removed the following callbacks: `MoEHandlerCallback`, `SchedulerCallback`, `MatrixNormalizerCallback`, `GradClipperCallback`, and `Float8HandlerCallback`.
The functionality from all of those callbacks has been moved to the `TransformerTrainModule` class.
- Removed the callback methods `.pre_eval_batch()` and `.post_eval_batch()`.

## Unreleased

### Added
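
Taken together, the v2 entries above amount to moving the optimizer, scheduler, gradient clipping, and loss settings off the `Trainer` and its callbacks and onto the new train module. A rough before/after sketch, with config fields taken from the updated example script further down this PR; everything else (variable setup, the exact `record_metric()` call shape) is assumed rather than confirmed:

```python
from olmo_core.optim import AdamWConfig, CosWithWarmup
from olmo_core.train.train_module import TransformerTrainModuleConfig

# v1 (before this PR): the Trainer was built from a model and an optimizer,
# with the scheduler and grad clipper attached as callbacks:
#   trainer = trainer_config.build(model, optim, data_loader)

# v2: a TransformerTrainModule wraps the model, optimizer, scheduler, and
# grad clipping, and the Trainer is built from it plus a data loader.
train_module_config = TransformerTrainModuleConfig(
    rank_microbatch_size=16 * 1024,             # moved off TrainerConfig
    max_sequence_length=1024,
    optim=AdamWConfig(lr=1e-3),
    compile_model=True,
    max_grad_norm=1.0,                          # replaces GradClipperCallback
    scheduler=CosWithWarmup(warmup_steps=100),  # replaces SchedulerCallback
)
train_module = train_module_config.build(model)
trainer = trainer_config.build(train_module, data_loader)

# New `namespace` argument on Trainer.record_metric(), typically called from a
# callback during training (exact signature assumed):
trainer.record_metric("my_metric", 1.0, namespace="custom")
```
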
10 changes: 5 additions & 5 deletions README.md
@@ -45,16 +45,16 @@ To see the exact usage for each script, run the script without any arguments.

Throughput numbers from these scripts with various different configuration settings are reported below, measured on a cluster with NVIDIA H100 GPUs.

| Model size | Model arch.   | Context length | Precision | Throughput[^1] | Training   script | Commandline overrides                                    |
| Model size | Model arch.   | Context length | Precision | Throughput[^1] | Training   script | Commandline overrides                                                   |
| :--------: | :--------: | :------------: | :-------: | -----------: | :----------- | :-------- |
| **1B** | OLMo-1124 | 4096 | BF16 | 55,000 TPS | `OLMo2-1B.py` | |
| | | 4096 | BF16/FP8[^2] | 65,000 TPS | `OLMo2-1B.py` | `--model.float8_config.enabled=true` |
| | | 4096 | BF16/FP8[^2] | 65,000 TPS | `OLMo2-1B.py` | `--train_module.float8_config.enabled=true` |
| **7B** | OLMo-1124 | 4096 | BF16 | 10,000 TPS | `OLMo2-7B.py` | |
| | | 4096 | BF16/FP8 | 13,000 TPS | `OLMo2-7B.py` | `--model.float8_config.enabled=true` |
| | | 4096 | BF16/FP8 | 13,000 TPS | `OLMo2-7B.py` | `--train_module.float8_config.enabled=true` |
| **8B** | Llama | 4096 | BF16 | 9,500 TPS | `Llama3-8B.py` | |
| | | 4096 | BF16/FP8 | 12,500 TPS | `Llama3-8B.py` | `--model.float8_config.enabled=true` |
| | | 4096 | BF16/FP8 | 12,500 TPS | `Llama3-8B.py` | `--train_module.float8_config.enabled=true` |
| **13B** | OLMo-1124 | 4096 | BF16 | 4,600 TPS | `OLMo2-13B.py` | |
| | | 4096 | BF16/FP8 | 5,500 TPS | `OLMo2-13B.py` | `--model.float8_config.enabled=true` |
| | | 4096 | BF16/FP8 | 5,500 TPS | `OLMo2-13B.py` | `--train_module.float8_config.enabled=true` |

[^1]: Throughput reported in tokens per second per device.
[^2]: In this setup most matrix multiplications are computed in `float8`, everything else is in `bfloat16`.
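
The changed commandline overrides in the table reflect the same move: the Float8 toggle now lives under the train module config instead of the model config. A minimal sketch of what such an override does, based on the `ExperimentConfig.merge(overrides)` pattern in the example script below; the exact override string format is an assumption:

```python
# Dotted overrides are merged into the experiment config, so enabling float8
# now targets `train_module.float8_config.enabled` rather than
# `model.float8_config.enabled` (field names per the changelog above).
config = ExperimentConfig(
    model=model_config,
    dataset=dataset_config,
    data_loader=data_loader_config,
    train_module=train_module_config,
    trainer=trainer_config,
).merge(["train_module.float8_config.enabled=true"])
```
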
23 changes: 16 additions & 7 deletions docs/source/conf.py
@@ -135,13 +135,22 @@ def autodoc_skip_member(app, what, name, obj, skip, options):

module = inspect.getmodule(obj)
module_name = None if module is None else module.__name__
if (
what == "class"
and module_name is not None
and module_name.startswith("olmo_core.train.callbacks")
and module_name != "olmo_core.train.callbacks.callback"
):
if inspect.isfunction(obj) or inspect.ismethod(obj):

if what == "class" and module_name is not None:
# Skip documenting callback subclass methods.
if (
module_name.startswith("olmo_core.train.callbacks.")
and module_name != "olmo_core.train.callbacks.callback"
and (inspect.isfunction(obj) or inspect.ismethod(obj))
):
return True

# Skip documenting train module subclass methods.
if (
module_name.startswith("olmo_core.train.train_module.")
and module_name != "olmo_core.train.train_module.train_module"
and (inspect.isfunction(obj) or inspect.ismethod(obj) or isinstance(obj, property))
):
return True

return skip
1 change: 1 addition & 0 deletions docs/source/nn/transformer.rst
@@ -3,3 +3,4 @@

.. automodule:: olmo_core.nn.transformer
:members:
:exclude-members: TransformerDataParallelWrappingStrategy,TransformerActivationCheckpointingMode
11 changes: 6 additions & 5 deletions docs/source/overview/introduction.rst
@@ -17,18 +17,19 @@ Most users will likely follow a workflow that looks like this:
For example::

model_config = TransformerConfig.llama2_7B(...)
optim_config = AdamWConfig(lr=1e-3, ...)
train_module_config = TransformerTrainModuleConfig(...)
data_config = NumpyDatasetConfig(...)
data_loader_config = NumpyDataLoaderConfig(...)
trainer_config = TrainerConfig(...)

2. Build the corresponding components within a ``main()`` function at runtime and then call :meth:`Trainer.fit() <olmo_core.train.Trainer.fit>`.
For example::

def main(model_config, optim_config, data_config, trainer_config):
def main():
model = model_config.build()
optim = optim_config.build()
dataset = data_config.build()
trainer = trainer_config.build(model, optim, dataset)
train_module = train_module_config.build(model)
data_loader = data_loader_config.build(data_config.build(), dp_process_group=train_module.dp_process_group)
trainer = trainer_config.build(train_module, data_loader)

trainer.fit()

1 change: 1 addition & 0 deletions docs/source/train/index.rst
@@ -9,3 +9,4 @@
:caption: Submodules

callbacks
train_module
5 changes: 5 additions & 0 deletions docs/source/train/train_module.rst
@@ -0,0 +1,5 @@
``train.train_module``
======================

.. automodule:: olmo_core.train.train_module
:members:
4 changes: 3 additions & 1 deletion src/examples/huggingface/convert_checkpoint.py
@@ -144,7 +144,9 @@ def validate_conversion(hf_model):

del hf_model

model = MODEL_CONFIG.build(device=device, max_seq_len=131072).eval()
model = MODEL_CONFIG.build()
model.init_weights(device=device, max_seq_len=131072)
model.eval()

log.info("Loading converted checkpoint for validation...")
load_model_and_optim_state(SAVE_PATH, model)
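
This hunk reflects the broader v2 change that model construction and weight initialization are now separate steps: `build()` no longer takes `device` or `max_seq_len`, and `init_weights()` does (the same pattern appears in the updated `src/examples/llama/train.py` below). A standalone sketch, with `MODEL_CONFIG` and `SAVE_PATH` assumed to be defined earlier in the script and the import path of the load helper an assumption:

```python
import torch

from olmo_core.distributed.checkpoint import load_model_and_optim_state

device = torch.device("cuda")

# Build the model first, then materialize its weights on the target device
# with an explicit maximum sequence length, then switch to eval mode.
model = MODEL_CONFIG.build()
model.init_weights(device=device, max_seq_len=131072)
model.eval()

# Load the converted checkpoint into the initialized model for validation.
load_model_and_optim_state(SAVE_PATH, model)
```
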
82 changes: 30 additions & 52 deletions src/examples/llama/train.py
@@ -18,7 +18,7 @@
TokenizerConfig,
)
from olmo_core.distributed.parallel import DataParallelType
from olmo_core.nn.transformer import TransformerConfig, TransformerDataParallelConfig
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride
from olmo_core.train import (
Duration,
@@ -32,22 +32,23 @@
ConfigSaverCallback,
DownstreamEvaluatorCallbackConfig,
GPUMemoryMonitorCallback,
GradClipperCallback,
LMEvaluatorCallbackConfig,
ProfilerCallback,
SchedulerCallback,
SequenceLengthSchedulerCallback,
WandBCallback,
)
from olmo_core.utils import get_default_device, seed_all
from olmo_core.train.train_module import (
TransformerDataParallelConfig,
TransformerTrainModuleConfig,
)
from olmo_core.utils import seed_all


@dataclass
class ExperimentConfig(Config):
model: TransformerConfig
optim: AdamWConfig
dataset: NumpyDatasetConfig
data_loader: NumpyDataLoaderConfig
train_module: TransformerTrainModuleConfig
trainer: TrainerConfig
init_seed: int = 12536

@@ -57,30 +58,13 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig:

model_config = TransformerConfig.llama2_271M(
vocab_size=tokenizer_config.padded_vocab_size(), # a little bigger than actual vocab size to make it a multiple of 128
compile=True,
fused_ops=False,
use_flash=False,
dp_config=TransformerDataParallelConfig(
name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32
),
)

optim_config = AdamWConfig(
lr=1e-3,
group_overrides=[
OptimGroupOverride(params=["embeddings.weight"], opts=dict(weight_decay=0.0))
],
)

dataset_config = NumpyDatasetConfig.glob(
"/net/nfs/allennlp/llm-data/c4/en/c4-train.*.npy", # can be globs
name=NumpyDatasetType.fsl,
sequence_length=1024,
max_target_sequence_length=8192,
# name=NumpyDatasetType.vsl,
# max_sequence_length=2048,
# min_sequence_length=256,
# vsl_curriculum=VSLCurriculumConfig(name=VSLCurriculumType.grow_p2, num_cycles=4),
tokenizer=tokenizer_config,
work_dir="/tmp/dataset-cache",
)
@@ -91,28 +75,32 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig:
num_workers=4,
)

train_module_config = TransformerTrainModuleConfig(
rank_microbatch_size=16 * 1024,
max_sequence_length=dataset_config.effective_sequence_length,
optim=AdamWConfig(
lr=1e-3,
group_overrides=[
OptimGroupOverride(params=["embeddings.weight"], opts=dict(weight_decay=0.0))
],
),
compile_model=True,
dp_config=TransformerDataParallelConfig(
name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32
),
compile_loss=True,
max_grad_norm=1.0,
scheduler=CosWithWarmup(warmup_steps=100),
)

trainer_config = (
TrainerConfig(
save_folder=f"/tmp/{run_name}",
rank_microbatch_size=16 * 1024,
save_overwrite=True,
metrics_collect_interval=5,
cancel_check_interval=5,
load_key_mapping={
# For backwards compatibility when loading older checkpoints.
"lm_head.w_out.weight": "w_out.weight",
"lm_head.norm.weight": "norm.weight",
},
)
.with_callback("lr_scheduler", SchedulerCallback(scheduler=CosWithWarmup(warmup_steps=100)))
.with_callback(
"seq_len_scheduler",
SequenceLengthSchedulerCallback(
min_sequence_length=128, warmup_steps=100, enabled=False
),
)
.with_callback("gpu_monitor", GPUMemoryMonitorCallback())
.with_callback("grad_clipper", GradClipperCallback(max_grad_norm=1.0))
.with_callback(
"checkpointer",
CheckpointerCallback(
@@ -166,9 +154,9 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig:

return ExperimentConfig(
model=model_config,
optim=optim_config,
dataset=dataset_config,
data_loader=data_loader_config,
train_module=train_module_config,
trainer=trainer_config,
).merge(overrides)

@@ -179,22 +167,12 @@ def main(run_name: str, overrides: List[str]):
# Set RNG states on all devices.
seed_all(config.init_seed)

device = get_default_device()

# Build the world mesh, if needed.
world_mesh = config.model.build_mesh(device=device)

# Build components.
model = config.model.build(
init_device="meta",
device=device,
max_seq_len=config.dataset.sequence_length,
mesh=world_mesh,
)
optim = config.optim.build(model)
model = config.model.build(init_device="meta")
train_module = config.train_module.build(model)
dataset = config.dataset.build()
data_loader = config.data_loader.build(dataset, mesh=world_mesh)
trainer = config.trainer.build(model, optim, data_loader, mesh=world_mesh)
data_loader = config.data_loader.build(dataset, dp_process_group=train_module.dp_process_group)
trainer = config.trainer.build(train_module, data_loader)

# Save config to W&B and each checkpoint dir.
config_dict = config.as_config_dict()