Commit

[wip] init
dtunai committed May 27, 2024
1 parent cd75d25 commit 57646a4
Showing 1 changed file with 29 additions and 2 deletions.
31 changes: 29 additions & 2 deletions x_lstm_jax/lstm_blocks.py
@@ -4,6 +4,33 @@
from typing import Tuple, Union, List


class CausalConv1D(nn.Module):
    """1D convolution padded only on the left of the time axis, so the
    output at step t depends on inputs at steps <= t (no lookahead)."""

    features: int
    kernel_size: int
    dilation: int = 1

    @nn.compact
    def __call__(self, x):
        # x: (batch, length, channels)
        # Left padding needed to keep the receptive field causal.
        padding = (self.kernel_size - 1) * self.dilation
        return nn.Conv(
            features=self.features,
            kernel_size=(self.kernel_size,),
            kernel_dilation=(self.dilation,),
            padding=[(padding, 0)],
        )(x)


def block_diag(*arrs):
shapes = jnp.array([a.shape for a in arrs])
out = jnp.zeros(jnp.sum(shapes, axis=0))
@@ -51,7 +78,7 @@ def setup(self):
self.W_i = nn.Dense(features=self.head_num * self.head_dim)
self.W_o = nn.Dense(features=self.head_num * self.head_dim)
self.W_f = nn.Dense(features=self.head_num * self.head_dim)

# TODO: setup BlockLinear
self.R_z = nn.Dense(features=self.head_dim)
self.R_i = nn.Dense(features=self.head_dim)
@@ -83,7 +110,7 @@ def __call__(
x_t = self.inp_norm(seq)

if use_conv:
x_c = nn.Conv(features=1, kernel_size=(self.ker_size,))(x_t)
x_c = CausalConv1D(features=1, kernel_size=self.ker_size)(x_t)
x_c = nn.silu(x_c).squeeze()
else:
x_c = x_t
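For reference, a minimal sketch of how the new CausalConv1D block could be exercised on its own as a Flax module; the batch size, sequence length, channel count, and PRNG seed below are illustrative assumptions, not values taken from the repository:

import jax
import jax.numpy as jnp

from x_lstm_jax.lstm_blocks import CausalConv1D

x = jnp.ones((2, 16, 8))                        # assumed (batch, length, channels)
conv = CausalConv1D(features=1, kernel_size=4)
params = conv.init(jax.random.PRNGKey(0), x)    # initialise the convolution kernel
y = conv.apply(params, x)                       # causal output, shape (2, 16, 1)

Because the padding is applied only on the left of the time axis, the output keeps the input's sequence length while never reading future steps.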
