Commit f898fd9

Log active params
Muennighoff committed Apr 16, 2024
1 parent 761d36a commit f898fd9
Showing 3 changed files with 11 additions and 2 deletions.
10 changes: 9 additions & 1 deletion olmo/model.py
@@ -1505,7 +1505,7 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
         else:
             raise NotImplementedError(wrap_strategy)
 
-    def num_params(self, include_embedding: bool = True) -> int:
+    def num_params(self, include_embedding: bool = True, include_inactivated_experts: bool = True) -> int:
         """
         Get the total number of parameters.
         """
@@ -1515,6 +1515,14 @@ def num_params(self, include_embedding: bool = True) -> int:
                 lambda np: ".wte." not in np[0] and ".wpe." not in np[0],
                 params,
             )
+        if not include_inactivated_experts:
+            # Reduce each expert weight block to the number of experts that are actually selected,
+            # e.g. 'transformer.blocks.0.ffn.experts.mlp.w1' has shape (total_experts, in_dim, out_dim)
+            # and is counted as if it had shape (selected_experts, in_dim, out_dim).
+            params = [
+                (np[0], np[1][: self.config.moe_top_k] if "experts.mlp" in np[0] else np[1])
+                for np in params
+            ]
         return sum(p.numel() for _, p in params)

     @property
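The counting trick above can be reproduced outside the model class. Below is a minimal standalone sketch (hypothetical tensor shapes and parameter names, not the repository's code), assuming expert weights are stacked along the first dimension as the diff comment describes, so slicing that dimension to the top-k routed experts before numel() gives an active-parameter count:

import torch

# Toy MoE layer: 8 experts stacked on dim 0, of which 2 are routed per token.
total_experts, top_k, d_in, d_out = 8, 2, 16, 64
named_params = [
    ("transformer.blocks.0.ffn.experts.mlp.w1", torch.zeros(total_experts, d_in, d_out)),
    ("transformer.blocks.0.att_proj.weight", torch.zeros(d_out, d_out)),  # dense weight, always active
]

total = sum(p.numel() for _, p in named_params)
active = sum(
    (p[:top_k] if "experts.mlp" in n else p).numel() for n, p in named_params
)
print(total, active)  # 12288 6144 -- only 2 of the 8 experts count toward the active total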
1 change: 0 additions & 1 deletion olmo/train.py
@@ -193,7 +193,6 @@ def fused_loss_fn(
 
             self.loss_fn = fused_loss_fn
 
-        print(self.cfg)
         if self.model.config.block_type == BlockType.moe:
             # these MoEArgs are necessary for logging load balancing.
             self.moe_args = MoEArgs(
2 changes: 2 additions & 0 deletions scripts/train.py
@@ -117,6 +117,8 @@ def main(cfg: TrainConfig) -> None:
     olmo_model = OLMo(cfg.model)
     log.info(f"Total number of parameters: {olmo_model.num_params():,d}")
     log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embedding=False):,d}")
+    if olmo_model.config.block_type == "moe":
+        log.info(f"Number of active parameters: {olmo_model.num_params(include_inactivated_experts=False):,d}")
     log.info(f"Peak GPU Memory (MB) before FSDP: {int(peak_gpu_memory() or 0)}")
 
     olmo_model.set_activation_checkpointing(cfg.activation_checkpointing)
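With this change, an MoE run reports an active-parameter count alongside the existing totals. A rough sanity check one could add locally (a sketch only; olmo_model and log refer to the objects in the script above, and the extra log line is illustrative):

total = olmo_model.num_params()
active = olmo_model.num_params(include_inactivated_experts=False)
assert active <= total  # slicing expert weights can only shrink the count
log.info(f"Active parameter fraction: {active / total:.2%}")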
