Commit

Make MoE configurable
Muennighoff committed Apr 16, 2024
1 parent f907988 commit 761d36a
Showing 5 changed files with 96 additions and 46 deletions.
20 changes: 20 additions & 0 deletions olmo/config.py
@@ -435,6 +435,26 @@ class ModelConfig(BaseConfig):
See :data:`TrainConfig.precision` instead.
"""

moe_num_experts: Optional[int] = 8
"""
The number of experts to use in the MoE block.
"""

moe_top_k: Optional[int] = 2
"""
The number of top experts to use in the MoE block.
"""

moe_capacity_factor: Optional[float] = 1.25
"""
The capacity factor to use in the MoE block.
"""

moe_loss_weight: Optional[float] = 0.1
"""
The weight to use for the MoE loss.
"""

@property
def effective_n_kv_heads(self) -> int:
if self.n_kv_heads is None:
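For context, the sketch below (not part of the commit) shows how the new fields might be set when building a model configuration. The concrete values and the use of BlockType.moe to select the MoE block are illustrative assumptions based on the rest of this diff; the defaults added above are 8 experts, top-2 routing, a 1.25 capacity factor, and a 0.1 load-balancing loss weight.

    from olmo.config import BlockType, ModelConfig

    # Illustrative only: pick the MoE block and override the new knobs.
    cfg = ModelConfig(
        d_model=2048,
        n_layers=16,
        n_heads=16,
        block_type=BlockType.moe,   # route the feed-forward through the MoE block
        moe_num_experts=8,          # experts per MoE layer
        moe_top_k=2,                # experts activated per token
        moe_capacity_factor=1.25,   # per-expert token capacity multiplier
        moe_loss_weight=0.1,        # weight on the auxiliary load-balancing loss
    )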
7 changes: 5 additions & 2 deletions olmo/initialization.py
@@ -18,8 +18,8 @@ class ModuleType(StrEnum):


def init_weights(
config: ModelConfig,
module: Union[nn.Linear, nn.Embedding],
config: ModelConfig,
d: Optional[int] = None,
layer_id: Optional[int] = None,
std_factor: float = 1.0,
@@ -47,7 +47,10 @@ def init_weights(
std = std_factor / math.sqrt(d)
if layer_id is not None:
std = std / math.sqrt(2 * (layer_id + 1))
nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
if hasattr(module, "weight"):
nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
else:
nn.init.trunc_normal_(module, mean=0.0, std=std, a=-3 * std, b=3 * std)
elif config.init_fn == InitFnType.kaiming_normal:
nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
elif config.init_fn == InitFnType.fan_in:
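The swapped argument order and the hasattr check let init_weights double as megablocks' init_method, which is called with a bare weight tensor rather than a module (see the partial(...) passed to MoEArgs in olmo/model.py below). A hedged sketch of the two call patterns, assuming the mitchell-style init branch that contains the new check; the config values and the expert weight shape are made up for illustration:

    from functools import partial

    import torch
    import torch.nn as nn

    from olmo.config import InitFnType, ModelConfig
    from olmo.initialization import init_weights

    cfg = ModelConfig(init_fn=InitFnType.mitchell, d_model=1024)  # illustrative config

    # Usual path: a module that owns a .weight attribute.
    linear = nn.Linear(cfg.d_model, cfg.d_model)
    init_weights(linear, cfg, d=cfg.d_model, layer_id=None)

    # megablocks path: init_method receives a bare tensor, so everything else is
    # bound by keyword and the new else-branch initializes the tensor directly.
    init_method = partial(init_weights, config=cfg, d=cfg.d_model, layer_id=None)
    expert_weight = torch.empty(8 * 4 * cfg.d_model, cfg.d_model)  # hypothetical expert weight
    init_method(expert_weight)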
80 changes: 48 additions & 32 deletions olmo/model.py
@@ -53,6 +53,12 @@
else:
raise SystemExit("This script supports Python 3.8 or higher")

try:
from megablocks.layers.moe import MoE
from megablocks.layers.arguments import Arguments as MoEArgs
except ImportError:
log.warning("megablocks not installed, MoE layers will not be available.")

__all__ = [
"LayerNormBase",
"LayerNorm",
@@ -72,12 +78,6 @@

log = logging.getLogger(__name__)

try:
from megablocks.layers.moe import MoE
from megablocks.layers.arguments import Arguments as MoEArgs
except ImportError:
log.warning("megablocks not installed, MoE layers will not be available.")


def activation_checkpoint_function(cfg: ModelConfig):
preserve_rng_state = (
@@ -480,15 +480,15 @@ def reset_parameters(self):
if self.q_norm is not None:
self.q_norm.reset_parameters()
init_weights(
self.config,
self.attn_out,
self.config,
d=self.config.d_model,
layer_id=self.layer_id,
type_of_module=ModuleType.out_module,
)
init_weights(
self.config,
self.ff_out,
self.config,
d=self.ff_out.in_features,
layer_id=self.layer_id,
type_of_module=ModuleType.out_module,
@@ -633,7 +633,7 @@ def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OLMoBl
elif config.block_type == BlockType.llama:
return OLMoLlamaBlock(layer_id, config, cache)
elif config.block_type == BlockType.moe:
return OLMoEBlock(layer_id, config, cache)
return OLMoEBlock(layer_id, config, cache)
else:
raise NotImplementedError(f"Unknown block type: '{config.block_type}'")

Expand All @@ -644,7 +644,7 @@ class OLMoEBlock(OLMoBlock):
(plus another skip connection).
"""
def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
super().__init__()
nn.Module.__init__(self)
self.layer_id = layer_id
self.config = config
self.hidden_size = (
@@ -685,18 +685,23 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):

# MoE Block
moe_args = MoEArgs(
activation_fn=F.silu if 'glu' in config.activation_type.lower() else self.act,
mlp_type='glu' if 'glu' in config.activation_type.lower() else 'mlp',
hidden_size=config.d_model,
ffn_hidden_size=config.d_model*4,#int(self.act.output_multiplier * self.hidden_size),
moe_num_experts=8,#config.moe_num_experts,
moe_weight_parallelism=False,#config.moe_weight_parallelism,
moe_expert_model_parallelism=False,#config.moe_expert_model_parallelism,
moe_top_k=2,#config.moe_top_k,
moe_capacity_factor=1.25,#config.moe_capacity_factor,
moe_loss_weight=0.1,#config.moe_loss_weight,
device=torch.cuda.current_device(),
ffn_hidden_size=int(self.act.output_multiplier * self.hidden_size),
moe_num_experts=config.moe_num_experts,
# Handled by FSDP (https://github.com/databricks/megablocks/issues/57#issuecomment-1854594483)
moe_weight_parallelism=False,
# Not tested for now
moe_expert_model_parallelism=False,
moe_top_k=config.moe_top_k,
moe_capacity_factor=config.moe_capacity_factor,
moe_loss_weight=config.moe_loss_weight,
device=config.init_device,
# Handled by FSDP
bf16=False,
fp16=False,
init_method=partial(init_weights, config=config, d=config.d_model, layer_id=None, type_of_module=ModuleType.in_module),
)
self.ffn = MoE(moe_args)

Expand All @@ -713,14 +718,28 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
except ModuleNotFoundError:
pass

self.attn_norm = LayerNorm.build(config)
self.ff_norm = LayerNorm.build(config)

# Attention input projection. Projects x -> (q, k, v)
head_dim = config.d_model // config.n_heads
self.fused_dims = (
config.d_model,
config.effective_n_kv_heads * head_dim,
config.effective_n_kv_heads * head_dim,
)
self.att_proj = nn.Linear(
config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device
)

def reset_parameters(self):
if self.k_norm is not None:
self.k_norm.reset_parameters()
if self.q_norm is not None:
self.q_norm.reset_parameters()
init_weights(
self.config,
self.attn_out,
self.config,
d=self.config.d_model,
layer_id=self.layer_id,
type_of_module=ModuleType.out_module,
@@ -729,10 +748,7 @@ def reset_parameters(self):
self.ff_norm.reset_parameters()
# NOTE: the standard deviation for these weights does not depend on the layer.
init_weights(
self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
)
init_weights(
self.config, self.ffn, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
self.att_proj, self.config, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
)

def forward(
@@ -823,10 +839,10 @@ def reset_parameters(self):
self.ff_norm.reset_parameters()
# NOTE: the standard deviation for these weights does not depend on the layer.
init_weights(
self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
self.att_proj, self.config, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
)
init_weights(
self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
self.ff_proj, self.config, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
)

def forward(
@@ -928,10 +944,10 @@ def reset_parameters(self):
self.attn_norm.reset_parameters()
self.ff_norm.reset_parameters()
# NOTE: the standard deviation for these weights does not depend on the layer.
init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None)
init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None)
init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None)
init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None)
init_weights(self.q_proj, self.config, d=self.config.d_model, layer_id=None)
init_weights(self.k_proj, self.config, d=self.config.d_model, layer_id=None)
init_weights(self.v_proj, self.config, d=self.config.d_model, layer_id=None)
init_weights(self.ff_proj, self.config, d=self.config.d_model, layer_id=None)

def _scaled_dot_product_attention(
self,
@@ -1181,20 +1197,20 @@ def reset_parameters(self):
log.info("Initializing model parameters...")
# Top-level embeddings / linear layers.
init_weights(
self.config,
self.transformer.wte, # type: ignore
self.config,
std_factor=(0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0,
type_of_module=ModuleType.emb,
)
if hasattr(self.transformer, "wpe"):
init_weights(self.config, self.transformer.wpe, type_of_module=ModuleType.emb) # type: ignore
init_weights(self.transformer.wpe, self.config, type_of_module=ModuleType.emb) # type: ignore

# Top-level layer norm.
self.transformer.ln_f.reset_parameters() # type: ignore

# Output weights.
if hasattr(self.transformer, "ff_out"):
init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out) # type: ignore
init_weights(self.transformer.ff_out, self.config, type_of_module=ModuleType.final_out) # type: ignore

# Let the blocks handle themselves.
if self.config.block_group_size == 1:
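For reference, a condensed sketch of what the configurable feed-forward construction above amounts to, assuming megablocks is installed. build_moe_ffn is a hypothetical helper, and the activation and init arguments used in the real block are omitted for brevity:

    from megablocks.layers.arguments import Arguments as MoEArgs
    from megablocks.layers.moe import MoE

    def build_moe_ffn(config, ffn_hidden_size: int) -> MoE:
        # Every MoE hyperparameter now comes from ModelConfig instead of a literal.
        moe_args = MoEArgs(
            hidden_size=config.d_model,
            ffn_hidden_size=ffn_hidden_size,
            moe_num_experts=config.moe_num_experts,
            moe_top_k=config.moe_top_k,
            moe_capacity_factor=config.moe_capacity_factor,
            moe_loss_weight=config.moe_loss_weight,
            moe_weight_parallelism=False,        # sharding is left to FSDP
            moe_expert_model_parallelism=False,  # not exercised by this commit
            bf16=False,                          # mixed precision is also handled by FSDP
            fp16=False,
            device=config.init_device,
        )
        return MoE(moe_args)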
9 changes: 9 additions & 0 deletions olmo/optim.py
@@ -14,6 +14,12 @@
from .config import OptimizerType, SchedulerConfig, SchedulerType, TrainConfig
from .torch_util import get_default_device, is_distributed

try:
from megablocks.layers.mlp import MLP, SparseMLP
megablocks_available = True
except ImportError:
megablocks_available = False

__all__ = [
"Optimizer",
"LionW",
@@ -588,6 +594,7 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]]
"""
Separate parameters into weight decay and non weight decay groups.
"""
from megablocks.layers.mlp import MLP
param_groups: List[Dict[str, Any]]
param_group_defaults = {
"sharded": isinstance(model, FullyShardedDataParallel),
@@ -627,6 +634,8 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]]
decay.add(fpn)
else:
no_decay.add(fpn)
elif megablocks_available and pn.endswith(("w1", "w2")) and (isinstance(m, MLP) or isinstance(m, SparseMLP)):
decay.add(fpn)

# Validate that we've considered every parameter
inter_params = decay & no_decay
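The new branch is needed because megablocks stores expert weights as bare parameters named w1/w2 on its MLP/SparseMLP modules, so they never match the nn.Linear rule that normally selects weights for decay. A simplified sketch of the grouping idea, not the exact logic of get_param_groups:

    import torch.nn as nn

    try:
        from megablocks.layers.mlp import MLP, SparseMLP
        megablocks_available = True
    except ImportError:
        megablocks_available = False

    def split_decay_groups(model: nn.Module):
        decay, no_decay = set(), set()
        for mn, m in model.named_modules():
            for pn, _ in m.named_parameters(recurse=False):
                fpn = f"{mn}.{pn}" if mn else pn
                if pn.endswith("bias") or isinstance(m, (nn.LayerNorm, nn.Embedding)):
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, nn.Linear):
                    decay.add(fpn)
                elif megablocks_available and pn.endswith(("w1", "w2")) and isinstance(m, (MLP, SparseMLP)):
                    # Expert weights are plain nn.Parameter tensors, so they need
                    # this explicit case to receive weight decay like other linears.
                    decay.add(fpn)
        return decay, no_decay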
26 changes: 14 additions & 12 deletions olmo/train.py
@@ -193,18 +193,20 @@ def fused_loss_fn(

self.loss_fn = fused_loss_fn

if self.cfg.block_type == BlockType.moe:#self.cfg.moe_freq > 0:
print(self.cfg)
if self.model.config.block_type == BlockType.moe:
# these MoEArgs are necessary for logging load balancing.
self.moe_args = MoEArgs(
hidden_size=self.cfg.d_model,
ffn_hidden_size=self.cfg.d_model * 4,
moe_num_experts=8,#self.cfg.moe_num_experts,
num_layers=self.cfg.n_layers,#if params.moe_freq > 0 and layer_id % params.moe_freq == 0:
moe_expert_model_parallelism=True,
moe_top_k=2,#self.cfg.moe_top_k,
device=torch.cuda.current_device(),
moe_capacity_factor=1.25,#self.cfg.moe_capacity_factor,
moe_loss_weight=0.1,#self.cfg.moe_loss_weight,
hidden_size=self.model.config.d_model,
ffn_hidden_size=self.model.config.d_model * 4,
moe_num_experts=self.model.config.moe_num_experts,
num_layers=self.model.config.n_layers,
# Not tested for now
moe_expert_model_parallelism=False,
moe_top_k=self.model.config.moe_top_k,
device=self.model.config.init_device,
moe_capacity_factor=self.model.config.moe_capacity_factor,
moe_loss_weight=self.model.config.moe_loss_weight,
fp16=False,
bf16=False,
)
@@ -666,7 +668,7 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor

ce_batch_loss = torch.tensor(0.0, device=self.device)
z_batch_loss = None if not self.cfg.softmax_auxiliary_loss else torch.tensor(0.0, device=self.device)
lb_batch_loss = None if self.cfg.block_type != BlockType.moe else torch.tensor(0.0, device=self.device)
lb_batch_loss = None if self.model.config.block_type != BlockType.moe else torch.tensor(0.0, device=self.device)
for micro_batch in micro_batches:
with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
# Run forward pass.
@@ -693,7 +695,7 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor
else:
loss = ce_loss

if self.cfg.block_type == BlockType.moe:
if self.model.config.block_type == BlockType.moe:
lb_batch_loss = batched_load_balancing_loss(self.moe_args)
clear_load_balancing_loss()
loss += lb_batch_loss
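A condensed sketch of how the load-balancing term is folded into the micro-batch loss above. The import path from megablocks.layers.moe is an assumption, moe_args is the MoEArgs object built in the trainer, and the helper name is made up for illustration:

    from megablocks.layers.moe import batched_load_balancing_loss, clear_load_balancing_loss

    def add_load_balancing_loss(ce_loss, moe_args):
        # megablocks accumulates per-layer router statistics during the forward
        # pass; this reduces them to one scalar weighted by moe_loss_weight.
        lb_loss = batched_load_balancing_loss(moe_args)
        # Reset the accumulator so the next micro-batch starts from a clean slate.
        clear_load_balancing_loss()
        return ce_loss + lb_loss, lb_loss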
