diff --git a/CHANGELOG.md b/CHANGELOG.md
index d3aa6caec..25c2b253f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,8 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added MMLU downstream evaluation tasks, with prompt variations.
 - Added support for PyTorch v2.2.
 - Added ability to show logs from all ranks
-
-
+- Added option for QKV clipping.
 
 ### Changed
 
diff --git a/olmo/config.py b/olmo/config.py
index c0f26b08b..66e176a7e 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -243,6 +243,11 @@ class ModelConfig(BaseConfig):
     The number of self-attention heads.
     """
 
+    clip_qkv: Optional[float] = None
+    """
+    Clip QKV to this value when set.
+    """
+
     n_layers: int = 12
     """
     The number of layers/blocks.
diff --git a/olmo/model.py b/olmo/model.py
index a11eceb71..bd7be3097 100644
--- a/olmo/model.py
+++ b/olmo/model.py
@@ -452,6 +452,10 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
             )
             self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
 
+        # Make sure QKV clip coefficient is positive, otherwise it's not well-defined.
+        if config.clip_qkv is not None:
+            assert config.clip_qkv > 0
+
         # Activation function.
         self.act = Activation.build(config)
         assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
@@ -680,11 +684,14 @@ def forward(
         # - for multi-query attn q: (batch_size, seq_len, d_model)
         #                      k, v: (batch_size, seq_len, d_model // n_heads)
         if self._activation_checkpoint_fn is not None:
-            q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(
-                self.fused_dims, dim=-1
-            )
+            qkv = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x))
         else:
-            q, k, v = self.att_proj(self.attn_norm(x)).split(self.fused_dims, dim=-1)
+            qkv = self.att_proj(self.attn_norm(x))
+
+        if self.config.clip_qkv is not None:
+            qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
+        q, k, v = qkv.split(self.fused_dims, dim=-1)
 
         # Get attention scores.
         if self._activation_checkpoint_fn is not None:
@@ -780,6 +787,11 @@ def forward(
         else:
             q, k, v, ff = self.fused_attn_ff_proj(self.norm(x)).split(self.fused_dims, dim=-1)
 
+        if self.config.clip_qkv is not None:
+            q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+            k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+            v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
         # Get attention scores.
         # shape: (B, T, C)
         if self._activation_checkpoint_fn is not None:
@@ -896,6 +908,11 @@ def forward(
         k = self.k_proj(x_normed)
         v = self.v_proj(x_normed)
 
+        if self.config.clip_qkv is not None:
+            q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+            k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+            v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
         # Get attention scores.
         att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
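
For reference, a minimal sketch of what the new `clip_qkv` option does at runtime (illustrative only: the tensor shapes, the split sizes, and the value `8.0` below are made-up stand-ins, not taken from this patch):

```python
import torch

# Hypothetical values for illustration; in OLMo these come from ModelConfig
# (clip_qkv) and the block's fused projection dimensions (fused_dims).
clip_qkv = 8.0
fused_dims = (64, 64, 64)

# Made-up fused QKV projection output: (batch_size, seq_len, sum(fused_dims)).
qkv = torch.randn(2, 16, sum(fused_dims))

if clip_qkv is not None:
    # Elementwise in-place clamp to [-clip_qkv, clip_qkv], mirroring the
    # qkv.clamp_(...) call the patch adds before splitting into q, k, v.
    qkv.clamp_(min=-clip_qkv, max=clip_qkv)

q, k, v = qkv.split(fused_dims, dim=-1)
```

Because `clamp_` operates in place, the clipping adds no extra activation memory on top of the existing QKV projection.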