Preliminary Megablocks
Muennighoff committed Apr 9, 2024
1 parent 62c7954 commit f907988
Showing 3 changed files with 202 additions and 4 deletions.
2 changes: 2 additions & 0 deletions olmo/config.py
@@ -181,6 +181,8 @@ class BlockType(StrEnum):
sequential = "sequential"

    moe = "moe"
    """
    A block that replaces the dense MLP with a megablocks mixture-of-experts layer.
    """

    llama = "llama"
"""
A block similar to the sequential block with slightly different
implementations of operations like attention to imitate the behavior of Llama.
160 changes: 160 additions & 0 deletions olmo/model.py
@@ -72,6 +72,12 @@

log = logging.getLogger(__name__)

try:
from megablocks.layers.moe import MoE
from megablocks.layers.arguments import Arguments as MoEArgs
except ImportError:
log.warning("megablocks not installed, MoE layers will not be available.")


def activation_checkpoint_function(cfg: ModelConfig):
preserve_rng_state = (
@@ -626,10 +632,164 @@ def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OLMoBl
return OLMoSequentialBlock(layer_id, config, cache)
elif config.block_type == BlockType.llama:
return OLMoLlamaBlock(layer_id, config, cache)
elif config.block_type == BlockType.moe:
return OLMoEBlock(layer_id, config, cache)
else:
raise NotImplementedError(f"Unknown block type: '{config.block_type}'")


class OLMoEBlock(OLMoBlock):
"""
    This is a transformer MoE block where the output is computed as ``MoE(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
        # OLMoBlock.__init__ requires (layer_id, config, cache) and would also allocate the dense
        # feed-forward weights this block does not use, so initialize nn.Module directly and build
        # every submodule here.
        nn.Module.__init__(self)
self.layer_id = layer_id
self.config = config
self.hidden_size = (
config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
)
self.__cache = cache
assert config.d_model % config.n_heads == 0

self._activation_checkpoint_fn = None

# Dropout.
self.dropout = Dropout(config.residual_dropout)

# Layer norms.
self.k_norm: Optional[LayerNormBase] = None
self.q_norm: Optional[LayerNormBase] = None
if config.attention_layer_norm:
assert config.effective_n_kv_heads is not None
self.k_norm = LayerNormBase.build(
config,
size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
elementwise_affine=config.attention_layer_norm_with_affine,
)
self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)

# Make sure QKV clip coefficient is positive, otherwise it's not well-defined.
if config.clip_qkv is not None:
assert config.clip_qkv > 0

# Activation function.
self.act = Activation.build(config)
assert (self.act.output_multiplier * self.hidden_size) % 1 == 0

        # Attention output projection.
        self.attn_out = nn.Linear(
            config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
        )

        # Layer norms for the attention and feed-forward inputs, plus the fused QKV input projection
        # (mirroring OLMoSequentialBlock). These are required by `reset_parameters` and `forward` below.
        self.attn_norm = LayerNorm.build(config)
        self.ff_norm = LayerNorm.build(config)
        head_dim = config.d_model // config.n_heads
        self.fused_dims = (
            config.d_model,
            config.effective_n_kv_heads * head_dim,
            config.effective_n_kv_heads * head_dim,
        )
        self.att_proj = nn.Linear(
            config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device
        )

# MoE Block
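        # megablocks' MoE replaces the dense feed-forward layer. Its forward returns a tuple whose
        # first element is the layer output, and it records per-layer router statistics in
        # module-level state so the trainer can aggregate them into a load-balancing loss via
        # batched_load_balancing_loss(). The hard-coded values below are placeholders for config
        # fields that do not exist yet (see the trailing TODO comments).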
        moe_args = MoEArgs(
            hidden_size=config.d_model,
            ffn_hidden_size=config.d_model * 4,  # TODO: int(self.act.output_multiplier * self.hidden_size)
            moe_num_experts=8,  # TODO: config.moe_num_experts
            moe_weight_parallelism=False,  # TODO: config.moe_weight_parallelism
            moe_expert_model_parallelism=False,  # TODO: config.moe_expert_model_parallelism
            moe_top_k=2,  # TODO: config.moe_top_k
            moe_capacity_factor=1.25,  # TODO: config.moe_capacity_factor
            moe_loss_weight=0.1,  # TODO: config.moe_loss_weight
            device=torch.cuda.current_device(),
            # Mixed precision is handled by FSDP.
            bf16=False,
            fp16=False,
        )
self.ffn = MoE(moe_args)

# Rotary embeddings.
if self.config.rope:
self.rotary_emb = RotaryEmbedding(config, self.__cache)

self.flash_attn_func = None
if config.flash_attention:
try:
from flash_attn import flash_attn_func # type: ignore

self.flash_attn_func = flash_attn_func
except ModuleNotFoundError:
pass

def reset_parameters(self):
if self.k_norm is not None:
self.k_norm.reset_parameters()
if self.q_norm is not None:
self.q_norm.reset_parameters()
init_weights(
self.config,
self.attn_out,
d=self.config.d_model,
layer_id=self.layer_id,
type_of_module=ModuleType.out_module,
)
self.attn_norm.reset_parameters()
self.ff_norm.reset_parameters()
# NOTE: the standard deviation for these weights does not depend on the layer.
init_weights(
self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
)
init_weights(
self.config, self.ffn, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
)

def forward(
self,
x: torch.Tensor,
attention_bias: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Get query, key, value projections.
# shape:
# - for regular attn q, k, v: (batch_size, seq_len, d_model)
# - for multi-query attn q: (batch_size, seq_len, d_model)
# k, v: (batch_size, seq_len, d_model // n_heads)
# - for group query attn q: (batch_size, seq_len, d_model)
# k, v: (batch_size, seq_len, d_model // n_kv_heads)
if self._activation_checkpoint_fn is not None:
qkv = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x))
else:
qkv = self.att_proj(self.attn_norm(x))

if self.config.clip_qkv is not None:
qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)

q, k, v = qkv.split(self.fused_dims, dim=-1)

# Get attention scores.
if self._activation_checkpoint_fn is not None:
att, cache = self._activation_checkpoint_fn( # type: ignore
self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
)
else:
att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)

# Add attention scores.
# shape: (B, T, C)
x = x + self.dropout(att)

# Add feed-forward projection.
# shape: (batch_size, seq_len, d_model)
og_x = x
if self._activation_checkpoint_fn is not None:
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
else:
x = self.ff_norm(x)

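        # The megablocks MoE layer returns a tuple; only the first element (the expert output) is used here.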
if self._activation_checkpoint_fn is not None:
x, _ = self._activation_checkpoint_fn(self.ffn, x) # type: ignore
else:
x, _ = self.ffn(x)

x = self.dropout(x)
x = og_x + x

return x, cache


class OLMoSequentialBlock(OLMoBlock):
"""
This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
44 changes: 40 additions & 4 deletions olmo/train.py
@@ -26,6 +26,7 @@
from .aliases import PathOrStr
from .checkpoint import Checkpointer, FullCheckpointer, build_sharded_checkpointer
from .config import (
BlockType,
CheckpointType,
SchedulerUnits,
ShardedCheckpointerType,
@@ -54,6 +55,12 @@

log = logging.getLogger(__name__)

try:
from megablocks.layers.moe import batched_load_balancing_loss, clear_load_balancing_loss
from megablocks.layers.arguments import Arguments as MoEArgs
except ImportError:
log.warning(f"Megablocks not installed. To train MoE, install with pip install megablocks.")


@dataclass
class SpeedMonitor:
@@ -186,6 +193,22 @@ def fused_loss_fn(

self.loss_fn = fused_loss_fn

        if self.cfg.block_type == BlockType.moe:  # TODO: self.cfg.moe_freq > 0
            # These MoEArgs are needed to aggregate and log the load-balancing loss; they must
            # match the arguments used to build the MoE blocks in the model.
            self.moe_args = MoEArgs(
                hidden_size=self.cfg.d_model,
                ffn_hidden_size=self.cfg.d_model * 4,
                moe_num_experts=8,  # TODO: self.cfg.moe_num_experts
                num_layers=self.cfg.n_layers,  # TODO: count only MoE layers (if moe_freq > 0 and layer_id % moe_freq == 0)
                moe_expert_model_parallelism=True,
                moe_top_k=2,  # TODO: self.cfg.moe_top_k
                device=torch.cuda.current_device(),
                moe_capacity_factor=1.25,  # TODO: self.cfg.moe_capacity_factor
                moe_loss_weight=0.1,  # TODO: self.cfg.moe_loss_weight
                fp16=False,
                bf16=False,
            )

@property
def dataset(self) -> IterableDataset:
assert isinstance(self.train_loader.dataset, IterableDataset)
@@ -643,6 +666,7 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor

ce_batch_loss = torch.tensor(0.0, device=self.device)
z_batch_loss = None if not self.cfg.softmax_auxiliary_loss else torch.tensor(0.0, device=self.device)
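        # When MoE blocks are used, also track the megablocks load-balancing loss for logging.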
lb_batch_loss = None if self.cfg.block_type != BlockType.moe else torch.tensor(0.0, device=self.device)
for micro_batch in micro_batches:
with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
# Run forward pass.
@@ -669,12 +693,17 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor
else:
loss = ce_loss

                if self.cfg.block_type == BlockType.moe:
                    # Aggregate the router statistics recorded by every MoE layer during this
                    # micro-batch's forward pass, then clear them for the next micro-batch.
                    lb_loss = batched_load_balancing_loss(self.moe_args)
                    clear_load_balancing_loss()
                    # Track a detached, micro-batch-averaged copy for logging; keep the live
                    # tensor in the loss so it contributes to the backward pass.
                    lb_batch_loss += lb_loss.detach() / len(micro_batches)
                    loss += lb_loss

del logits

# Run backward pass.
loss.backward()

return ce_batch_loss, z_batch_loss
return ce_batch_loss, z_batch_loss, lb_batch_loss

def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> Dict[str, float]:
metrics: Dict[str, float] = {}
@@ -691,7 +720,7 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
batch = move_to_device(batch, self.device)

# Run forward-backward pass.
ce_batch_loss, z_batch_loss = self.train_batch(batch)
ce_batch_loss, z_batch_loss, lb_batch_loss = self.train_batch(batch)

# Collect loss, potentially reducing over all ranks.
if reduce_global_loss:
@@ -700,6 +729,9 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
if z_batch_loss is not None:
dist.reduce(z_batch_loss, 0)
z_batch_loss.div_(get_world_size())
if lb_batch_loss is not None:
dist.reduce(lb_batch_loss, 0)
lb_batch_loss.div_(get_world_size())

# Clip gradient norms and collect param/gradient/optim metrics.
should_log_optim_metrics_this_step = self.should_log_optim_metrics_this_step()
@@ -728,9 +760,11 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
# Collect metrics and check for NaN loss.
# NOTE: this involves a bunch of host-device syncs so we wait until the last moment to do this.
if torch.isnan(ce_batch_loss):
raise ValueError("nan loss encountered")
raise ValueError("nan ce loss encountered")
if z_batch_loss is not None and torch.isnan(z_batch_loss):
raise ValueError("nan loss encountered")
raise ValueError("nan z loss encountered")
if lb_batch_loss is not None and torch.isnan(lb_batch_loss):
raise ValueError("nan lb loss encountered")
for key, value in optim_metrics.items():
metrics[f"optim/{key}"] = value.item()
self.cur_train_loss = ce_batch_loss.item()
@@ -739,6 +773,8 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
metrics["train/Perplexity"] = math.exp(self.cur_train_loss)
if z_batch_loss is not None:
metrics["train/ZLoss"] = z_batch_loss.item()
if lb_batch_loss is not None:
metrics["train/LoadBalancingLoss"] = lb_batch_loss.item()

# Maybe collect post-step optimizer-specific metrics.
if should_log_optim_metrics_this_step:
