From f898fd9a02be65521aaf31d9dade49c2d5a3907a Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Tue, 16 Apr 2024 14:04:53 +0000
Subject: [PATCH] Log active params

---
 olmo/model.py    | 10 +++++++++-
 olmo/train.py    |  1 -
 scripts/train.py |  2 ++
 3 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/olmo/model.py b/olmo/model.py
index b694fd87b..ba17252b0 100644
--- a/olmo/model.py
+++ b/olmo/model.py
@@ -1505,7 +1505,7 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
         else:
             raise NotImplementedError(wrap_strategy)
 
-    def num_params(self, include_embedding: bool = True) -> int:
+    def num_params(self, include_embedding: bool = True, include_inactivated_experts: bool = True) -> int:
         """
         Get the total number of parameters.
         """
@@ -1515,6 +1515,14 @@ def num_params(self, include_embedding: bool = True) -> int:
                 lambda np: ".wte." not in np[0] and ".wpe." not in np[0],
                 params,
             )
+        if not include_inactivated_experts:
+            # Need to reduce each expert weight block to only the experts that are selected,
+            # e.g. 'transformer.blocks.0.ffn.experts.mlp.w1' has shape (total_experts, in_dim, out_dim)
+            # and is sliced down to shape (selected_experts, in_dim, out_dim)
+            params = [
+                (np[0], np[1][: self.config.moe_top_k] if "experts.mlp" in np[0] else np[1])
+                for np in params
+            ]
         return sum(p.numel() for _, p in params)
 
     @property
diff --git a/olmo/train.py b/olmo/train.py
index 60376eca8..6ab5377dc 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -193,7 +193,6 @@ def fused_loss_fn(
 
         self.loss_fn = fused_loss_fn
 
-        print(self.cfg)
         if self.model.config.block_type == BlockType.moe:
             # these MoEArgs are necessary for logging load balancing.
             self.moe_args = MoEArgs(
diff --git a/scripts/train.py b/scripts/train.py
index f93734c0b..67444fc31 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -117,6 +117,8 @@ def main(cfg: TrainConfig) -> None:
     olmo_model = OLMo(cfg.model)
     log.info(f"Total number of parameters: {olmo_model.num_params():,d}")
     log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embedding=False):,d}")
+    if olmo_model.config.block_type == "moe":
+        log.info(f"Number of active parameters: {olmo_model.num_params(include_inactivated_experts=False):,d}")
     log.info(f"Peak GPU Memory (MB) before FSDP: {int(peak_gpu_memory() or 0)}")
     olmo_model.set_activation_checkpointing(cfg.activation_checkpointing)
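
For context, the counting trick applied in num_params can be illustrated with a minimal, self-contained sketch: when inactive experts are excluded, any stacked expert weight is sliced down to the first top-k experts before numel() is summed, which works because per-expert weights are stacked along the leading dimension. The class, parameter, and function names below are hypothetical stand-ins, not OLMo's actual code.

import torch
from torch import nn

# Toy stand-in for a block with stacked expert weights, mirroring the
# (total_experts, in_dim, out_dim) layout described in the patch comments.
# All names here are illustrative, not OLMo's actual classes or config fields.
class ToyMoEBlock(nn.Module):
    def __init__(self, num_experts: int, d_in: int, d_out: int):
        super().__init__()
        self.experts_w1 = nn.Parameter(torch.randn(num_experts, d_in, d_out))

def count_params(model: nn.Module, moe_top_k: int | None = None) -> int:
    """Count parameters; if moe_top_k is given, slice stacked expert
    weights down to the top-k experts so only 'active' params are counted."""
    total = 0
    for name, p in model.named_parameters():
        if moe_top_k is not None and "experts" in name:
            p = p[:moe_top_k]  # keep only as many experts as a token routes to
        total += p.numel()
    return total

block = ToyMoEBlock(num_experts=8, d_in=4, d_out=4)
print(count_params(block))               # 8 * 4 * 4 = 128 total parameters
print(count_params(block, moe_top_k=2))  # 2 * 4 * 4 = 32 "active" parameters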