Deltanet test (IBM#26)
Signed-off-by: Mayank Mishra <[email protected]>
mayank31398 authored Sep 26, 2024
1 parent 9add88a commit f49d632
Showing 13 changed files with 412 additions and 58 deletions.
5 changes: 5 additions & 0 deletions Makefile
@@ -4,6 +4,11 @@ install:
install-dev:
pip install --extra-index-url https://download.pytorch.org/whl/nightly/cpu -e .
pip install -r requirements-dev.txt

git clone https://github.com/sustcsonglin/flash-linear-attention
cd flash-linear-attention
pip install .
cd ..

test:
pytest tests
142 changes: 142 additions & 0 deletions dolomite_engine/hf_models/models/rnn_dolomite/attention/convolution.py
@@ -0,0 +1,142 @@
import warnings
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .....utils import is_causal_conv1d_available, is_einops_available


if is_einops_available():
from einops import rearrange


if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update


class ParameterizedShortConvolution(nn.Conv1d):
def __init__(
self,
hidden_size: int,
kernel_size: int,
bias: bool = False,
activation: nn.Module = nn.Identity(),
use_fast_conv1d: bool = True,
std: float | None = None,
) -> None:
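# note: `std` must be set before `super().__init__()`, since `nn.Conv1d.__init__` calls `reset_parameters()`, which reads it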
self.std = std

super().__init__(
in_channels=hidden_size,
out_channels=hidden_size,
kernel_size=kernel_size,
groups=hidden_size,
bias=bias,
padding=kernel_size - 1,
)

self.hidden_size = hidden_size
self.activation = activation

if not is_causal_conv1d_available():
if use_fast_conv1d:
raise RuntimeError(
"Please either install `causal-conv1d>=1.4.0` to enable fast causal short convolution CUDA kernel "
"or set `use_fast_conv1d` to False"
)
else:
warnings.warn(
"The naive Pytorch verison is very slow in practice, "
"please run `pip install causal-conv1d>=1.4.0` to install fast causal short convolution CUDA kernel",
category=ImportWarning,
)
self.use_fast_conv1d = use_fast_conv1d

def extra_repr(self):
s = "{in_channels}, {out_channels}, kernel_size={kernel_size}" ", stride={stride}"
if self.padding != (0,) * len(self.padding):
s += ", padding={padding}"
if self.dilation != (1,) * len(self.dilation):
s += ", dilation={dilation}"
if self.output_padding != (0,) * len(self.output_padding):
s += ", output_padding={output_padding}"
if self.groups != 1:
s += ", groups={groups}"
if self.bias is None:
s += ", bias=False"
if self.padding_mode != "zeros":
s += ", padding_mode={padding_mode}"
if not self.use_fast_conv1d:
s += ", use_fast_conv1d={use_fast_conv1d}"
return s.format(**self.__dict__)

def forward(
self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, cache: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Args:
x (`torch.Tensor`):
Tensor of shape `[batch_size, seq_len, hidden_size]`
mask (`Optional[torch.Tensor]`):
Attention mask dealing with padded positions.
cache (`Optional[torch.Tensor]`):
Previous cache tensor of shape `[batch_size, hidden_size, kernel_size]`.
Returns:
Tensor of shape `[batch_size, seq_len, hidden_size]`. The `cache` (if provided) is updated in place.
"""

if mask is not None:
x = x.mul_(mask.unsqueeze(-1))
if cache is not None and x.shape[1] == 1:
return self.step(x, cache)
x = rearrange(x, "b l d -> b d l")
# Update state (B D W)
if cache is not None:
cache.copy_(F.pad(x, (self.kernel_size[0] - x.shape[-1], 0)))
if self.use_fast_conv1d:
# causal_conv1d_fn accepts only a string activation ("silu"/"swish") or None, so the
# activation module is applied after the kernel call rather than passed into it
x = causal_conv1d_fn(
x=x,
weight=rearrange(self.weight, "d 1 w -> d w"),
bias=self.bias,
)
x = self.activation(x)
else:
x = self._conv_forward(x, self.weight, self.bias)[..., : x.shape[-1]]
x = self.activation(x)
return rearrange(x, "b d l -> b l d")

def step(self, x: torch.Tensor, cache: torch.Tensor):
assert x.shape[1] == 1, "Only support decoding with 1 token at a time for now"

x = x.squeeze(1)
if self.use_fast_conv1d:
# as in forward, causal_conv1d_update expects a string activation or None, so the
# activation module is applied after the kernel call
x = causal_conv1d_update(
x=x,
conv_state=cache,
weight=rearrange(self.weight, "d 1 w -> d w"),
bias=self.bias,
)
x = self.activation(x)
else:
dtype = x.dtype
cache.copy_(torch.roll(cache, shifts=-1, dims=-1))
cache[:, :, -1] = x
x = torch.sum(cache * rearrange(self.weight, "d 1 w -> d w"), dim=-1)
if self.bias is not None:
x = x + self.bias
x = self.activation(x).to(dtype)
return x.unsqueeze(1)

@property
def state_size(self) -> int:
# nn.Conv1d stores kernel_size as a 1-tuple
return self.hidden_size * self.kernel_size[0]

def reset_parameters(self) -> None:
if self.std is None:
super().reset_parameters()
else:
nn.init.normal_(self.weight, mean=0, std=self.std)
if self.bias is not None:
self.bias.zero_()
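
A minimal usage sketch of the new module (not part of the diff), assuming `einops` is installed and using the pure-PyTorch fallback (`use_fast_conv1d=False`) so it runs without the `causal-conv1d` CUDA kernel; shapes follow the docstring above:

import torch
import torch.nn as nn

from dolomite_engine.hf_models.models.rnn_dolomite.attention.convolution import ParameterizedShortConvolution

conv = ParameterizedShortConvolution(hidden_size=64, kernel_size=4, activation=nn.SiLU(), use_fast_conv1d=False)

x = torch.randn(2, 16, 64)     # [batch_size, seq_len, hidden_size]
cache = torch.zeros(2, 64, 4)  # [batch_size, hidden_size, kernel_size]
y = conv(x, cache=cache)       # prefill: [2, 16, 64], cache is updated in place
y_next = conv(torch.randn(2, 1, 64), cache=cache)  # single-token decode goes through step()
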
42 changes: 11 additions & 31 deletions dolomite_engine/hf_models/models/rnn_dolomite/attention/deltanet.py
@@ -9,41 +9,17 @@
from ....config import CommonConfig
from ....enums import InitMethod
from ....modeling_utils import ParameterizedLinear, get_normalization_function
from .convolution import ParameterizedShortConvolution


if is_einops_available():
from einops import rearrange

if is_fla_available():
from fla.models.utils import Cache as FLACache
from fla.modules import ShortConvolution
from fla.ops.delta_rule import chunk_delta_rule, fused_chunk_delta_rule, fused_recurrent_delta_rule


if is_fla_available():

class ParameterizedShortConvolution(ShortConvolution):
def __init__(
self,
hidden_size: int,
kernel_size: int,
bias: bool = False,
activation: str = "silu",
use_causal_conv: bool = True,
std: float | None = None,
) -> None:
self.std = std
super().__init__(hidden_size, kernel_size, bias, activation, use_causal_conv)

def reset_parameters(self) -> None:
if self.std is None:
super().reset_parameters()
else:
nn.init.normal_(self.weight, mean=0, std=self.std)
if self.bias is not None:
self.bias.zero_()


# https://github.com/IDSIA/recurrent-fwp/blob/master/algorithmic/layers.py#L86C1-L146C1
class DeltaNet(nn.Module):
def __init__(
@@ -100,18 +76,22 @@ def __init__(
std_conv = initializer_range
self.conv_size = conv_size
if share_conv_kernel:
self.h_conv1d = ParameterizedShortConvolution(
self.hidden_size, conv_size, activation=None, std=std_conv
)
self.h_conv1d = ParameterizedShortConvolution(self.hidden_size, conv_size, std=std_conv)
else:
self.q_conv1d = ParameterizedShortConvolution(
self.key_dim, conv_size, activation="silu" if qk_activation == "silu" else None, std=std_conv
self.key_dim,
conv_size,
activation=nn.SiLU() if qk_activation == "silu" else nn.Identity(),
std=std_conv,
)
self.k_conv1d = ParameterizedShortConvolution(
self.key_dim, conv_size, activation="silu" if qk_activation == "silu" else None, std=std_conv
self.key_dim,
conv_size,
activation=nn.SiLU() if qk_activation == "silu" else nn.Identity(),
std=std_conv,
)
self.v_conv1d = ParameterizedShortConvolution(
self.value_dim, conv_size, activation="silu", std=std_conv
self.value_dim, conv_size, activation=nn.SiLU(), std=std_conv
)

self.use_beta = use_beta
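
For context, the delta rule that `chunk_delta_rule`, `fused_chunk_delta_rule`, and `fused_recurrent_delta_rule` accelerate can be written as a naive recurrence. The sketch below is an illustrative reference only, following the usual fast-weight formulation; it ignores details the fla kernels handle, such as chunking, key normalization, and the short convolutions above:

import torch

def naive_delta_rule(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    # q, k, v: [batch, heads, seq_len, dim]; beta: [batch, heads, seq_len]
    # fast-weight update: S_t = S_{t-1} + beta_t * (v_t - S_{t-1} k_t) k_t^T, output o_t = S_t q_t
    b, h, t, d = q.shape
    S = q.new_zeros(b, h, d, d)
    outputs = []
    for i in range(t):
        q_i, k_i, v_i = q[:, :, i], k[:, :, i], v[:, :, i]
        pred = torch.einsum("bhij,bhj->bhi", S, k_i)  # S_{t-1} k_t
        S = S + torch.einsum("bhi,bhj->bhij", beta[:, :, i, None] * (v_i - pred), k_i)
        outputs.append(torch.einsum("bhij,bhj->bhi", S, q_i))  # S_t q_t
    return torch.stack(outputs, dim=2)  # [batch, heads, seq_len, dim]
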
9 changes: 4 additions & 5 deletions dolomite_engine/hf_models/models/rnn_dolomite/base.py
@@ -20,7 +20,6 @@ class RNNDolomitePreTrainedModel(PreTrainedModelMixin):
config_class = RNNDolomiteConfig
layer_class = RNNDolomiteBlock
_no_split_modules = ["RNNDolomiteBlock"]
_supports_sdpa = False

def __init__(self, config: RNNDolomiteConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)
@@ -36,7 +35,7 @@ def _init_model(self, config: RNNDolomiteConfig, **kwargs) -> None:
self.m_emb = config.m_emb
self.initializer_range = config.initializer_range

self.attention_patterns = self.mapping_attention_patterns(config.attention_patterns)
self.attention_pattern = self.mapping_attention_pattern(config.attention_pattern)

self.head_dim = divide_if_divisible(
self.embed_dim,
@@ -52,7 +51,7 @@ def _init_model(self, config: RNNDolomiteConfig, **kwargs) -> None:
self.layer_class(
config,
normalization_implementation=self.normalization_implementation,
attention_implementation=self.attention_patterns[i],
attention_pattern=self.attention_pattern[i],
use_padding_free_transformer=self._use_padding_free_transformer,
layer_idx=i,
)
@@ -72,9 +71,9 @@ def _init_model(self, config: RNNDolomiteConfig, **kwargs) -> None:
# Initialize weights and apply final processing
self.post_init()

def mapping_attention_patterns(self, attention_patterns: str) -> list[str]:
def mapping_attention_pattern(self, attention_pattern: str) -> list[str]:
attention_implementation_list = []
for pattern in attention_patterns:
for pattern in attention_pattern:
if pattern == "a":
attention_implementation_list.append(self.attention_implementation)
elif pattern == "d":
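
Illustratively, the pattern string is expanded into one implementation name per layer. The sketch below is a standalone stand-in for `mapping_attention_pattern`; it assumes "a" selects the configured attention implementation and "d" selects DeltaNet (the "d" branch is truncated in the diff above, so that mapping is inferred from layer.py below):

def expand_attention_pattern(attention_pattern: str, attention_implementation: str = "flash_attention_2") -> list[str]:
    mapping = {"a": attention_implementation, "d": "DeltaNet"}
    return [mapping[p] for p in attention_pattern]

assert expand_attention_pattern("adda") == ["flash_attention_2", "DeltaNet", "DeltaNet", "flash_attention_2"]
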
6 changes: 3 additions & 3 deletions dolomite_engine/hf_models/models/rnn_dolomite/config.py
@@ -4,8 +4,8 @@
class RNNDolomiteConfig(CommonConfig):
model_type = "rnn_dolomite"

def __init__(self, attention_patterns: str | None = None, **kwargs) -> None:
def __init__(self, attention_pattern: str | None = None, **kwargs) -> None:
super().__init__(**kwargs)

assert len(attention_patterns) == self.n_layer, "Attention patterns must be specified for each layer"
self.attention_patterns = attention_patterns
assert len(attention_pattern) == self.n_layer, "Attention patterns must be specified for each layer"
self.attention_pattern = attention_pattern
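
A hypothetical construction sketch (not part of the diff): the pattern must contain exactly one character per layer or the assert above fires. Constructor kwargs other than `attention_pattern`, such as `n_layer`, are assumed to be forwarded to `CommonConfig`:

from dolomite_engine.hf_models.models.rnn_dolomite.config import RNNDolomiteConfig

config = RNNDolomiteConfig(attention_pattern="ad" * 12, n_layer=24)  # 24 layers alternating attention / DeltaNet
assert len(config.attention_pattern) == config.n_layer
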
8 changes: 4 additions & 4 deletions dolomite_engine/hf_models/models/rnn_dolomite/layer.py
@@ -17,7 +17,7 @@ def __init__(
self,
config: RNNDolomiteConfig,
normalization_implementation: str,
attention_implementation: str,
attention_pattern: str,
use_padding_free_transformer: bool,
layer_idx: int | None = None,
) -> None:
@@ -38,12 +38,12 @@ def __init__(
normalization_implementation=normalization_implementation,
)

if attention_implementation == "DeltaNet":
if attention_pattern == "DeltaNet":
self.attn = DeltaNet(config=config, layer_idx=layer_idx)
elif attention_implementation == "flash_attention_2":
elif attention_pattern == "flash_attention_2":
self.attn = RNNFlashAttention2(config, True, layer_idx)
else:
raise ValueError(f"Attention implementation {attention_implementation} not supported.")
raise ValueError(f"Attention pattern {attention_pattern} not supported.")

self.ln_2 = get_normalization_function(
config.normalization_function,
7 changes: 4 additions & 3 deletions dolomite_engine/optimization/params_group.py
@@ -12,6 +12,7 @@
from ..hf_models.modeling_utils import Attention
from ..hf_models.models.gpt_dolomite.layer import MLP
from ..hf_models.models.moe_dolomite.moe import SparseMoE
from ..hf_models.models.rnn_dolomite.attention import DeltaNet
from ..model_wrapper import ModelWrapper
from ..utils import log_rank_0

@@ -80,10 +81,10 @@ def get_mup_group_with_names(model: ModelWrapper, optimizer_class_args: dict) ->

# collect parameters with mup learning rate
for module_name, module in model.named_modules():
if isinstance(module, (Attention, MLP, SparseMoE)):
if isinstance(module, (Attention, MLP, SparseMoE, DeltaNet)):
for param_name, param in module.named_parameters():
# we don't add bias to mup group
if not param_name.endswith(".bias"):
# we don't add bias or norms to mup group
if not (param_name.endswith(".bias") or "norm" in param_name):
# add name of module to name of subparam
mup_params[f"{module_name}.{param_name}"] = param
elif isinstance(module, (nn.LayerNorm, nn.RMSNorm)) or module.__class__.__name__.lower().endswith("norm"):
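
To make the updated filter concrete, a small illustrative predicate is shown below (the parameter names are hypothetical); inside Attention, MLP, SparseMoE, and now DeltaNet modules, biases and any parameter whose name contains "norm" stay out of the muP group:

def in_mup_group(param_name: str) -> bool:
    # mirrors the updated condition in get_mup_group_with_names
    return not (param_name.endswith(".bias") or "norm" in param_name)

assert in_mup_group("q_proj.weight")
assert not in_mup_group("q_proj.bias")
assert not in_mup_group("g_norm.weight")  # e.g. a DeltaNet output-gate norm (hypothetical name)
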
1 change: 1 addition & 0 deletions dolomite_engine/utils/__init__.py
@@ -9,6 +9,7 @@
from .mixed_precision import normalize_dtype_string, string_to_torch_dtype, torch_dtype_to_string
from .packages import (
is_apex_available,
is_causal_conv1d_available,
is_deepspeed_available,
is_einops_available,
is_fla_available,
14 changes: 14 additions & 0 deletions dolomite_engine/utils/packages.py
@@ -176,6 +176,20 @@ def is_kernel_hyperdrive_available() -> bool:
return _IS_KHD_AVAILABLE


try:
import causal_conv1d

_IS_CAUSAL_CONV1D_AVAILABLE = True
except ImportError:
_IS_CAUSAL_CONV1D_AVAILABLE = False

warn_rank_0("causal-conv1d is not installed")


def is_causal_conv1d_available() -> bool:
return _IS_CAUSAL_CONV1D_AVAILABLE


@run_rank_n
def log_environment() -> None:
packages = sorted(["{}=={}".format(d.metadata["Name"], d.version) for d in distributions()])