diff --git a/olmo/config.py b/olmo/config.py
index 042c704ce..bbc0000b7 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -181,6 +181,8 @@ class BlockType(StrEnum):
     sequential = "sequential"
 
     llama = "llama"
+
+    moe = "moe"
     """
     A block similar to the sequential block with slightly different implementations of operations like
     attention to imitate the behavior of Llama.
diff --git a/olmo/model.py b/olmo/model.py
index 555e0ca81..e6cd7df38 100644
--- a/olmo/model.py
+++ b/olmo/model.py
@@ -72,6 +72,12 @@
 
 log = logging.getLogger(__name__)
 
+try:
+    from megablocks.layers.moe import MoE
+    from megablocks.layers.arguments import Arguments as MoEArgs
+except ImportError:
+    log.warning("megablocks not installed, MoE layers will not be available.")
+
 
 def activation_checkpoint_function(cfg: ModelConfig):
     preserve_rng_state = (
@@ -626,10 +632,164 @@ def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OLMoBl
             return OLMoSequentialBlock(layer_id, config, cache)
         elif config.block_type == BlockType.llama:
             return OLMoLlamaBlock(layer_id, config, cache)
+        elif config.block_type == BlockType.moe:
+            return OLMoEBlock(layer_id, config, cache)
         else:
             raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
 
 
+class OLMoEBlock(OLMoBlock):
+    """
+    This is a transformer MoE block where the output is computed as ``MoE(LN(x + Attention(LN(x))))``
+    (plus another skip connection).
+    """
+
+    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
+        # Skip OLMoBlock.__init__ (which builds the dense feed-forward projections) and initialize
+        # nn.Module directly; the attributes below mirror OLMoBlock with the FFN replaced by an MoE layer.
+        nn.Module.__init__(self)
+        self.layer_id = layer_id
+        self.config = config
+        self.hidden_size = (
+            config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
+        )
+        self.__cache = cache
+        assert config.d_model % config.n_heads == 0
+
+        self._activation_checkpoint_fn = None
+
+        # Dropout.
+        self.dropout = Dropout(config.residual_dropout)
+
+        # Layer norms.
+        self.k_norm: Optional[LayerNormBase] = None
+        self.q_norm: Optional[LayerNormBase] = None
+        if config.attention_layer_norm:
+            assert config.effective_n_kv_heads is not None
+            self.k_norm = LayerNormBase.build(
+                config,
+                size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
+                elementwise_affine=config.attention_layer_norm_with_affine,
+            )
+            self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
+
+        # Make sure QKV clip coefficient is positive, otherwise it's not well-defined.
+        if config.clip_qkv is not None:
+            assert config.clip_qkv > 0
+
+        # Activation function.
+        self.act = Activation.build(config)
+        assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
+
+        # Attention output projection.
+        self.attn_out = nn.Linear(
+            config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
+        )
+
+        # Block layer norms and fused attention input projection, mirroring OLMoSequentialBlock;
+        # forward() and reset_parameters() below rely on these.
+        self.attn_norm = LayerNorm.build(config)
+        self.ff_norm = LayerNorm.build(config)
+        head_dim = config.d_model // config.n_heads
+        self.fused_dims = (
+            config.d_model,
+            config.effective_n_kv_heads * head_dim,
+            config.effective_n_kv_heads * head_dim,
+        )
+        self.att_proj = nn.Linear(
+            config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device
+        )
+
+        # MoE feed-forward block (replaces the dense MLP).
+        moe_args = MoEArgs(
+            hidden_size=config.d_model,
+            ffn_hidden_size=config.d_model * 4,  # TODO: int(self.act.output_multiplier * self.hidden_size)
+            moe_num_experts=8,  # TODO: config.moe_num_experts
+            moe_weight_parallelism=False,  # TODO: config.moe_weight_parallelism
+            moe_expert_model_parallelism=False,  # TODO: config.moe_expert_model_parallelism
+            moe_top_k=2,  # TODO: config.moe_top_k
+            moe_capacity_factor=1.25,  # TODO: config.moe_capacity_factor
+            moe_loss_weight=0.1,  # TODO: config.moe_loss_weight
+            device=torch.cuda.current_device(),
+            # Handled by FSDP
+            bf16=False,
+            fp16=False,
+        )
+        self.ffn = MoE(moe_args)
+
+        # Rotary embeddings.
+        if self.config.rope:
+            self.rotary_emb = RotaryEmbedding(config, self.__cache)
+
+        self.flash_attn_func = None
+        if config.flash_attention:
+            try:
+                from flash_attn import flash_attn_func  # type: ignore
+
+                self.flash_attn_func = flash_attn_func
+            except ModuleNotFoundError:
+                pass
+
+    def reset_parameters(self):
+        if self.k_norm is not None:
+            self.k_norm.reset_parameters()
+        if self.q_norm is not None:
+            self.q_norm.reset_parameters()
+        init_weights(
+            self.config,
+            self.attn_out,
+            d=self.config.d_model,
+            layer_id=self.layer_id,
+            type_of_module=ModuleType.out_module,
+        )
+        self.attn_norm.reset_parameters()
+        self.ff_norm.reset_parameters()
+        # NOTE: the standard deviation for these weights does not depend on the layer.
+        init_weights(
+            self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
+        )
+        init_weights(
+            self.config, self.ffn, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_bias: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        # Get query, key, value projections.
+        # shape:
+        #  - for regular attn q, k, v: (batch_size, seq_len, d_model)
+        #  - for multi-query attn q: (batch_size, seq_len, d_model)
+        #                      k, v: (batch_size, seq_len, d_model // n_heads)
+        #  - for group query attn q: (batch_size, seq_len, d_model)
+        #                      k, v: (batch_size, seq_len, d_model // n_kv_heads)
+        if self._activation_checkpoint_fn is not None:
+            qkv = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x))
+        else:
+            qkv = self.att_proj(self.attn_norm(x))
+
+        if self.config.clip_qkv is not None:
+            qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
+        q, k, v = qkv.split(self.fused_dims, dim=-1)
+
+        # Get attention scores.
+        if self._activation_checkpoint_fn is not None:
+            att, cache = self._activation_checkpoint_fn(  # type: ignore
+                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
+            )
+        else:
+            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
+
+        # Add attention scores.
+        # shape: (B, T, C)
+        x = x + self.dropout(att)
+
+        # Add feed-forward projection.
+        # shape: (batch_size, seq_len, d_model)
+        og_x = x
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.ff_norm, x)  # type: ignore
+        else:
+            x = self.ff_norm(x)
+
+        # The MoE layer returns a tuple; only the first element (the expert-mixed output) is used here.
+        if self._activation_checkpoint_fn is not None:
+            x, _ = self._activation_checkpoint_fn(self.ffn, x)  # type: ignore
+        else:
+            x, _ = self.ffn(x)
+
+        x = self.dropout(x)
+        x = og_x + x
+
+        return x, cache
+
+
 class OLMoSequentialBlock(OLMoBlock):
     """
     This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
diff --git a/olmo/train.py b/olmo/train.py
index 71a45312e..3e1623bd0 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -26,6 +26,7 @@
 from .aliases import PathOrStr
 from .checkpoint import Checkpointer, FullCheckpointer, build_sharded_checkpointer
 from .config import (
+    BlockType,
     CheckpointType,
     SchedulerUnits,
     ShardedCheckpointerType,
@@ -54,6 +55,12 @@
 
 log = logging.getLogger(__name__)
 
+try:
+    from megablocks.layers.moe import batched_load_balancing_loss, clear_load_balancing_loss
+    from megablocks.layers.arguments import Arguments as MoEArgs
+except ImportError:
+    log.warning("Megablocks not installed. To train MoE models, install it with `pip install megablocks`.")
+
 
 @dataclass
 class SpeedMonitor:
@@ -186,6 +193,22 @@ def fused_loss_fn(
 
             self.loss_fn = fused_loss_fn
 
+        if self.cfg.model.block_type == BlockType.moe:  # TODO: self.cfg.moe_freq > 0
+            # These MoEArgs are only needed for computing the load-balancing loss for logging.
+            self.moe_args = MoEArgs(
+                hidden_size=self.cfg.model.d_model,
+                ffn_hidden_size=self.cfg.model.d_model * 4,
+                moe_num_experts=8,  # TODO: self.cfg.moe_num_experts
+                num_layers=self.cfg.model.n_layers,  # if params.moe_freq > 0 and layer_id % params.moe_freq == 0
+                moe_expert_model_parallelism=True,
+                moe_top_k=2,  # TODO: self.cfg.moe_top_k
+                device=torch.cuda.current_device(),
+                moe_capacity_factor=1.25,  # TODO: self.cfg.moe_capacity_factor
+                moe_loss_weight=0.1,  # TODO: self.cfg.moe_loss_weight
+                fp16=False,
+                bf16=False,
+            )
+
     @property
     def dataset(self) -> IterableDataset:
         assert isinstance(self.train_loader.dataset, IterableDataset)
@@ -643,6 +666,7 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor
 
         ce_batch_loss = torch.tensor(0.0, device=self.device)
         z_batch_loss = None if not self.cfg.softmax_auxiliary_loss else torch.tensor(0.0, device=self.device)
+        lb_batch_loss = None if self.cfg.model.block_type != BlockType.moe else torch.tensor(0.0, device=self.device)
         for micro_batch in micro_batches:
             with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
                 # Run forward pass.
@@ -669,12 +693,17 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor
                 else:
                     loss = ce_loss
 
+                if self.cfg.model.block_type == BlockType.moe:
+                    lb_batch_loss = batched_load_balancing_loss(self.moe_args)
+                    clear_load_balancing_loss()
+                    loss += lb_batch_loss
+
                 del logits
 
             # Run backward pass.
             loss.backward()
 
-        return ce_batch_loss, z_batch_loss
+        return ce_batch_loss, z_batch_loss, lb_batch_loss
 
     def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> Dict[str, float]:
         metrics: Dict[str, float] = {}
@@ -691,7 +720,7 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
         batch = move_to_device(batch, self.device)
 
         # Run forward-backward pass.
-        ce_batch_loss, z_batch_loss = self.train_batch(batch)
+        ce_batch_loss, z_batch_loss, lb_batch_loss = self.train_batch(batch)
 
         # Collect loss, potentially reducing over all ranks.
         if reduce_global_loss:
@@ -700,6 +729,9 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
             if z_batch_loss is not None:
                 dist.reduce(z_batch_loss, 0)
                 z_batch_loss.div_(get_world_size())
+            if lb_batch_loss is not None:
+                dist.reduce(lb_batch_loss, 0)
+                lb_batch_loss.div_(get_world_size())
 
         # Clip gradient norms and collect param/gradient/optim metrics.
         should_log_optim_metrics_this_step = self.should_log_optim_metrics_this_step()
@@ -728,9 +760,11 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
         # Collect metrics and check for NaN loss.
         # NOTE: this involves a bunch of host-device syncs so we wait until the last moment to do this.
         if torch.isnan(ce_batch_loss):
-            raise ValueError("nan loss encountered")
+            raise ValueError("nan ce loss encountered")
         if z_batch_loss is not None and torch.isnan(z_batch_loss):
-            raise ValueError("nan loss encountered")
+            raise ValueError("nan z loss encountered")
+        if lb_batch_loss is not None and torch.isnan(lb_batch_loss):
+            raise ValueError("nan lb loss encountered")
         for key, value in optim_metrics.items():
             metrics[f"optim/{key}"] = value.item()
         self.cur_train_loss = ce_batch_loss.item()
@@ -739,6 +773,8 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
         metrics["train/Perplexity"] = math.exp(self.cur_train_loss)
         if z_batch_loss is not None:
             metrics["train/ZLoss"] = z_batch_loss.item()
+        if lb_batch_loss is not None:
+            metrics["train/LoadBalancingLoss"] = lb_batch_loss.item()
 
         # Maybe collect post-step optimizer-specific metrics.
         if should_log_optim_metrics_this_step:
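For reference, a minimal standalone sketch of how the megablocks calls used in this diff compose: build an MoE layer from MoEArgs, run a forward pass (which records routing statistics), then collect the auxiliary loss with batched_load_balancing_loss and reset the accumulator with clear_load_balancing_loss, just as train_batch does once per micro-batch. All shapes and hyperparameters below are placeholders, and it assumes megablocks is installed and a CUDA device is available.

import torch

from megablocks.layers.arguments import Arguments as MoEArgs
from megablocks.layers.moe import MoE, batched_load_balancing_loss, clear_load_balancing_loss

# Placeholder hyperparameters mirroring the hardcoded values in the diff.
moe_args = MoEArgs(
    hidden_size=1024,
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2,
    moe_capacity_factor=1.25,
    moe_loss_weight=0.1,
    num_layers=1,  # number of MoE layers contributing losses before each batched_load_balancing_loss call
    device=torch.cuda.current_device(),
    fp16=False,
    bf16=False,
)
ffn = MoE(moe_args)

x = torch.randn(2, 16, 1024, device="cuda")  # (batch_size, seq_len, d_model) placeholder activations
out, _ = ffn(x)  # forward pass records per-layer routing statistics
ce_loss = out.float().pow(2).mean()  # stand-in for the real cross-entropy loss
lb_loss = batched_load_balancing_loss(moe_args)  # aggregate the auxiliary load-balancing loss
clear_load_balancing_loss()  # reset the accumulator before the next micro-batch
(ce_loss + lb_loss).backward()

This is the same pattern train_batch follows: compute the load-balancing loss after the forward pass, clear the accumulator, and fold it into the loss before loss.backward(). The Trainer passes num_layers=self.cfg.model.n_layers, presumably so the layer count agrees with one full forward pass of the model.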