diff --git a/x_lstm_jax/lstm_blocks.py b/x_lstm_jax/lstm_blocks.py
index 1515704..addf792 100644
--- a/x_lstm_jax/lstm_blocks.py
+++ b/x_lstm_jax/lstm_blocks.py
@@ -4,6 +4,43 @@
 from typing import Tuple, Union, List
 
 
+class CausalConv1D:
+    """Causal 1-D convolution over the time axis of a (batch, time, channels) input.
+
+    The input is left-padded by ``(kernel_size - 1) * dilation`` steps, so the
+    output at step t depends only on inputs at steps <= t (no future leakage).
+    """
+
+    def __init__(self, features, kernel_size, dilation=1):
+        self.features = features
+        self.kernel_size = kernel_size
+        self.dilation = dilation
+        # Amount of left padding that keeps the convolution causal.
+        self.padding = (kernel_size - 1) * dilation
+
+    def __call__(self, x, kernel=None):
+        batch_size, length, channels = x.shape
+        if kernel is None:
+            # Deterministic averaging kernel as a fallback; pass ``kernel``
+            # (shape (kernel_size, channels, features)) to use learned weights.
+            kernel = jnp.full(
+                (self.kernel_size, channels, self.features),
+                1.0 / (self.kernel_size * channels),
+            )
+        # Pad only the *start* of the time axis (causal), never the channels.
+        padded = jnp.pad(x, ((0, 0), (self.padding, 0), (0, 0)))
+        # VALID convolution on the left-padded input yields exactly `length`
+        # output steps: (length + padding) - padding.
+        return lax.conv_general_dilated(
+            padded,
+            kernel,
+            window_strides=(1,),
+            padding="VALID",
+            rhs_dilation=(self.dilation,),
+            dimension_numbers=("NWC", "WIO", "NWC"),
+        )
+
+
 def block_diag(*arrs):
     shapes = jnp.array([a.shape for a in arrs])
     out = jnp.zeros(jnp.sum(shapes, axis=0))
@@ -51,7 +88,7 @@ def setup(self):
         self.W_i = nn.Dense(features=self.head_num * self.head_dim)
         self.W_o = nn.Dense(features=self.head_num * self.head_dim)
         self.W_f = nn.Dense(features=self.head_num * self.head_dim)
-
+        # TODO: setup BlockLinear
         self.R_z = nn.Dense(features=self.head_dim)
         self.R_i = nn.Dense(features=self.head_dim)
         self.R_o = nn.Dense(features=self.head_dim)
@@ -83,7 +120,7 @@ def __call__(
         x_t = self.inp_norm(seq)
 
         if use_conv:
-            x_c = nn.Conv(features=1, kernel_size=(self.ker_size,))(x_t)
+            x_c = CausalConv1D(features=1, kernel_size=self.ker_size)(x_t)
             x_c = nn.silu(x_c).squeeze()
         else:
             x_c = x_t