diff --git a/_config.yml b/_config.yml
index 930a8fa..2f19ada 100644
--- a/_config.yml
+++ b/_config.yml
@@ -2,11 +2,11 @@
# Site settings
# -----------------------------------------------------------------------------
-title: blank # the website title (if blank, full name will be used instead)
-first_name: You
-middle_name: R.
-last_name: Name
-email: you@example.com
+title: Goomba AI Lab # the website title (if blank, full name will be used instead)
+first_name: Goomba
+middle_name: AI
+last_name: Lab
+email:
description: > # the ">" symbol means to ignore newlines until "footer_text:"
A simple, whitespace theme for academics. Based on [*folio](https://github.com/bogoli/-folio) design.
footer_text: >
@@ -17,8 +17,8 @@ keywords: jekyll, jekyll-theme, academic-website, portfolio-website # add your o
lang: en # the language of your site (for example: en, fr, cn, ru, etc.)
icon: ⚛️ # the emoji used as the favicon (alternatively, provide image name in /assets/img/)
-url: https://alshedivat.github.io # the base hostname & protocol for your site
-baseurl: /al-folio # the subpath of your site, e.g. /blog/. Leave blank for root
+url: https://gu-group.github.io # the base hostname & protocol for your site
+baseurl: # the subpath of your site, e.g. /blog/. Leave blank for root
last_updated: false # set to true if you want to display last updated in the footer
impressum_path: # set to path to include impressum link in the footer, use the same path as permalink in a page, helps to conform with EU GDPR
back_to_top: true # set to false to disable the back to top button
@@ -130,7 +130,7 @@ bing_site_verification: # out your bing-site-verification ID (Bing Webmaster)
# Blog
# -----------------------------------------------------------------------------
-blog_name: al-folio # blog_name will be displayed in your blog page
+blog_name: Goomblog # blog_name will be displayed in your blog page
blog_description: a simple whitespace theme for academics
permalink: /blog/:year/:title/
lsi: true # produce an index for related posts
diff --git a/_posts/2024-05-30-mamba2.md b/_posts/2024-05-30-mamba2.md
new file mode 100644
index 0000000..3413e36
--- /dev/null
+++ b/_posts/2024-05-30-mamba2.md
@@ -0,0 +1,419 @@
+---
+layout: distill
+title: State Space Duality (Mamba-2) Part 1 - The Model
+description:
+tags:
+giscus_comments: true
+date: 2024-05-27
+featured: true
+
+authors:
+ - name: Albert Gu
+ url:
+ affiliations:
+ name: Carnegie Mellon University
+ - name: Tri Dao
+ url:
+ affiliations:
+ name: Princeton
+
+bibliography: 2018-12-22-distill.bib
+
+# Optionally, you can add a table of contents to your post.
+# NOTES:
+# - make sure that TOC names match the actual section names
+# for hyperlinks within the post to work correctly.
+# - we may want to automate TOC generation in the future using
+# jekyll-toc plugin (https://github.com/toshimaru/jekyll-toc).
+toc:
+  - name: Outline
+    # if a section has subsections, you can add them as follows:
+    # subsections:
+    #   - name: Example Child Subsection 1
+    #   - name: Example Child Subsection 2
+  - name: The State Space Dual Model
+  - name: SSD Viewpoint 1 (structured matrix transformations)
+  - name: SSD Viewpoint 2 (structured attention)
+  - name: Code
+
+---
+
+
+Since the release of [Mamba], we've been overwhelmed by the community response.
+
+(give list of examples of applications and understanding papers)
+
+(link to Aviv's compilation)
+
+Despite its ..., there were two aspects of Mamba that we weren't satisfied with.
+
+### Problem 1 (Understanding):
+From a conceptual standpoint, one of the reasons we found SSMs so fascinating is how they just feel _fundamental_. One way this is exemplified is how they have rich ties to many major paradigms of sequence models.
+As developed in our earlier works on structured SSMs [cite LSSL and thesis], they seem to capture the essence of continuous, convolutional, and recurrent sequence models -- all wrapped up in a simple and elegant model.
+
+But of course, aside from these, there is another major sequence model paradigm: the ubiquitous **attention** mechanism (and variants).
+
+> Question 1: **What are the conceptual connections between SSMs and attention?**
+
+### Problem 2 (Efficiency):
+From a computational standpoint,
+despite the work that went into making Mamba fast -- in particular, its hardware-aware selective scan implementation -- it is still much less hardware-efficient than mechanisms such as attention.
+The missing piece is that modern accelerators such as GPUs and TPUs are highly specialized for matrix multiplications (matmuls), which the selective scan does not take advantage of.
+While this is not a problem for inference, which is bottlenecked by different types of considerations, it can be a big deal during training.
+For example, an end-to-end Mamba-1 model is XX times slower than an equivalent Transformer.
+
+> Question 2: **Can we speed up the training of Mamba models by recasting them as matrix multiplications?**
+
+These are the main questions that SSD (a.k.a. Mamba-2) tries to address.
+
+## Outline
+Other Topics:
+- How does it relate to SSMs
+- How does it relate to attention
+- future work from each viewpoint
+
+## The State Space Dual Model
+
+SSD refers both to a general framework and to a specific model.
+The **state space dual model** or SSD model
+itself really isn't so scary -- we'll first provide a self-contained description of the SSD layer in isolation here, before elaborating on some of the theoretical connections!
+
+### The Linear (SSM) Mode
+
+SSD starts from the same selective state space model as Mamba:
+
+$$
+\begin{aligned}
+h_{t} &= A_t h_{t-1} + B_t x_t \\
+y_t &= C_t^{\top} h_t
+\end{aligned}
+$$
+
+To recap, a structured state space model (SSM) defines a map from $x \in \mathbb{R}^\mathtt{T} \to y \in \mathbb{R}^\mathtt{T}$ :warning:.
+Think of $x_t$ and $y_t$ as being scalars, and the hidden state $h_t$ as an $\mathtt{N}$-dimensional vector, where $\mathtt{N}$ is an independent hyperparameter called the state size, state dimension, or state expansion factor.
+A *selective* state space model allows the $A, B, C$ SSM parameters to vary across time.
+We'll think of them as tensors with shapes $\mathtt{(T, N, N)}$, $\mathtt{(T, N)}$, and $\mathtt{(T, N)}$ respectively.
+
+Structured SSMs require $A$ to have structure to be efficiently computable, such as the most commonly used diagonal structure :warning:. In this case $A$ has shape $\mathtt{(T, N)}$ where only the diagonal elements of the $\mathtt{N} \times \mathtt{N}$ matrices are stored.
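+
+To make these shapes concrete, here is a minimal reference recurrence for a diagonal selective SSM (a sketch; the helper name `ssm_diag_ref` is ours, for illustration only), which simply walks through the recurrence above for a single channel.
+
+{% highlight python %}
+import torch
+
+def ssm_diag_ref(A, B, C, x):
+    """Reference recurrence for a diagonal selective SSM (single channel).
+    A, B, C: (T, N) tensors (A stores the diagonals); x: (T,). Returns y: (T,)."""
+    T, N = A.shape
+    h = torch.zeros(N, dtype=x.dtype)
+    y = torch.empty(T, dtype=x.dtype)
+    for t in range(T):
+        h = A[t] * h + B[t] * x[t]   # diagonal A_t acts elementwise on the state
+        y[t] = (C[t] * h).sum()      # y_t = C_t^T h_t
+    return y
+{% endhighlight %}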
+
+#### SSD: Scalar Structured SSM
+The original Mamba (or more precisely its core "S6" layer) is exactly a selective SSM with diagonal structure.
+**The SSD layer of Mamba-2 makes only one simple modification**: it restricts the diagonal $A$ even further to a *scalar times identity* structure; in other words, the diagonal elements of $A$ must all be the same.
+In this case $A$ can be represented with shape just $\mathtt{(T)}$ and one can also identify $A_t$ as just a scalar (and thus we will sometimes denote it $a_t$).
+
+#### Multihead SSMs
+
+Here, we think of $X$ as a tensor of shape $\mathtt{(T, P)}$, where $\mathtt{T}$ is the sequence (time) dimension and $\mathtt{P}$ is the "head dimension". We will ignore the batch dimension throughout this presentation.
+
+We can notate the general (selective) state space model as
+\begin{equation}
+\label{eq:ssm}
+Y^\mathtt{(T,P)} = \mathsf{SSM}(A^\mathtt{(T,...)}, B^\mathtt{(T,N)}, C^\mathtt{(T,N)})(X^\mathtt{(T,P)})
+\end{equation}
+
+Axes of variation include the structure on $A$, which affects its parameter shape, as well as the state dimension $\mathtt{N}=\mathtt{d\_state}$ and the head dimension $\mathtt{P}=\mathtt{d\_head}$.
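+
+As a sketch of how these shapes fit together in the scalar-$A$ (SSD) case, here is a naive single-head recurrence (the helper name `ssd_recurrence_ref` is ours); note that the state for one head is a $\mathtt{(P, N)}$ matrix, updated with a single shared scalar $a_t$.
+
+{% highlight python %}
+import torch
+
+def ssd_recurrence_ref(a, B, C, X):
+    """Naive linear (SSM) mode of SSD for a single head.
+    a: (T,) scalars, B: (T, N), C: (T, N), X: (T, P). Returns Y: (T, P)."""
+    T, N = B.shape
+    P = X.shape[-1]
+    H = torch.zeros(P, N, dtype=X.dtype)             # per-head state of size P x N
+    Y = torch.empty(T, P, dtype=X.dtype)
+    for t in range(T):
+        H = a[t] * H + X[t, :, None] * B[t, None, :] # shared scalar decay + outer product
+        Y[t] = H @ C[t]                              # y_t = H_t C_t
+    return Y
+{% endhighlight %}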
+
+#### Efficiency
+
+SSMs are interesting because computing them as a recurrence requires maintaining only a *constant-size state* (size $\mathtt{N}$) and scales *linearly in the sequence length* $\mathtt{T}$.
+But as mentioned above, the raw FLOPs don't reflect actual speed in practice...
+
+
+### The Quadratic (Attention) Mode
+
+Let's switch tacks and forget about state space models for a moment.
+Given the same $(A, B, C)$ tensors above with the same shapes $\mathtt{(T)}$, $\mathtt{(T, N)}$, and $\mathtt{(T, N)}$,
+let's define a different object.
+
+First, we'll define
+
+$$
+ L =
+ \mathsf{1SS}(a_{0:T})
+ =
+ \begin{bmatrix}
+ 1 & \\
+ a_1 & 1 & \\
+ a_2a_1 & a_2 & 1 \\
+ \vdots & \vdots & \ddots & \ddots \\
+ a_{T-1}\dots a_1 & a_{T-1}\dots a_2 & \dots & a_{T-1} & 1 \\
+ \end{bmatrix}
+ .
+$$
+
+Then, let's define the following matrix
+
+\begin{equation}
+\label{eq:linear-attention}
+M = L \circ C B^\top \in \mathbb{R}^{\mathtt{(T,T)}}
+\end{equation}
+
+Finally, $M$ encodes a sequence transformation
+$x \in \mathbb{R}^\mathtt{T} \to y \in \mathbb{R}^\mathtt{T}$
+through plain matrix multiplication, $y = Mx$, just like the [SSM](#the-linear-ssm-mode) defined above.
+
+What's special about this?
+Well, you may notice that it looks very similar to an attention computation.
+In fact, if all $a_t = 1$, then $L$ is simply the lower-triangular *causal mask*, and \eqref{eq:linear-attention} is exactly **causal linear attention** :warning: :
+$$
+Y = (L \circ Q K^\top) V
+$$
+
+This is exactly the same as equation \eqref{eq:linear-attention} if we rename $(C, B, X) \mapsto (Q, K, V)$!
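+
+As a sketch (again with a hypothetical helper name, `ssd_quadratic_ref`), this quadratic mode can be materialized directly from the definitions above: build $L = \mathsf{1SS}(a)$ from cumulative products, mask $C B^\top$ with it, and multiply by the input.
+
+{% highlight python %}
+import torch
+
+def ssd_quadratic_ref(a, B, C, X):
+    """Naive quadratic (attention) mode of SSD for a single head.
+    a: (T,) scalars, B: (T, N), C: (T, N), X: (T, P). Returns Y: (T, P)."""
+    # L[i, j] = a_i * ... * a_{j+1} for i >= j, and 0 above the diagonal.
+    # Computed here via ratios of cumulative products; the actual implementation
+    # below works with cumulative sums of log(a) for numerical stability.
+    p = torch.cumprod(a, dim=0)
+    L = torch.tril(p[:, None] / p[None, :])
+    M = L * (C @ B.T)   # M = L \circ C B^T, shape (T, T)
+    return M @ X        # Y = M X, shape (T, P)
+{% endhighlight %}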
+
+#### Efficiency
+Computing \eqref{eq:linear-attention} requires materializing the $\mathtt{(T, T)}$ matrix $M$, so this mode takes time and memory quadratic in the sequence length $\mathtt{T}$ -- but it consists almost entirely of matrix multiplications, which accelerators are very good at.
+
+### Best of Both Worlds: the Hybrid Mode
+Computationally, one can use either formulation to compute the model. Loosely speaking, the attention form is faster during training because it's dominated by matrix multiplications, while the SSM form is preferred during autoregressive inference.
+
+In the next two sections [LINK], we'll present two broad frameworks with which to understand the state space dual model.
+Each of them proves the equivalence of these two formulations, but each is also much more general, and we'll discuss other consequences of these frameworks.
+
+If you just want to use the model, stop here!
+In the rest of this post, we'll give an overview of the theoretical aspects of the SSD framework.
+
+### State Space Duality
+
+The so-called "duality" refers to the fact that the two models defined in XX and XX are in fact *exactly the same model*, which is a particular function $(A, B, C, X) \mapsto Y$ with tensor shapes specified above.
+We'll show this fact in two completely different ways, both of which are actually much more general and each quite illuminating.
+
+If you take our word for it, though, then SSD is relatively simple to understand in contrast to either SSMs or attention.
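+
+The duality is also easy to check numerically. Using the two reference sketches from earlier in this post (the hypothetical `ssd_recurrence_ref` and `ssd_quadratic_ref` helpers), the linear and quadratic modes should agree up to floating-point error:
+
+{% highlight python %}
+import torch
+
+torch.manual_seed(0)
+T, N, P = 32, 16, 4
+a = 0.1 + 0.9 * torch.rand(T)                 # scalar decays a_t in (0.1, 1.0)
+B, C, X = torch.randn(T, N), torch.randn(T, N), torch.randn(T, P)
+
+Y_linear = ssd_recurrence_ref(a, B, C, X)     # recurrent (SSM) mode
+Y_quadratic = ssd_quadratic_ref(a, B, C, X)   # attention-like mode
+assert torch.allclose(Y_linear, Y_quadratic, rtol=1e-4, atol=1e-4)
+{% endhighlight %}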
+
+#### SSD vs. State Space Models
+Compared to previous SSMs, SSD is pretty much the same as the core layer of Mamba but with even more structure on the recurrent $A$ matrices.
+- Mamba (S6) uses diagonal structure on $A$ with a head dimension of $\mathtt{P}=1$.
+- Mamba-2 (SSD) uses scalar-times-identity structure on $A$ with a head dimension of $\mathtt{P}>1$ (something like $\mathtt{P}=64$ by default).
+
+In particular, this can be viewed as weight tying in two ways:
+- By restricting the diagonal structure of $A$ to scalar-times-identity, the scalar recurrence dynamics are tied across all $\mathtt{N}$ elements of the state space.
+- These dynamics are also shared across all $\mathtt{P}$ channels of a given head.
+
+In other words, a single SSM head has total state size $\mathtt{P} \times \mathtt{N}$,
+whose elements are each governed by separate scalar recurrences in Mamba but are controlled by a single shared recurrence in Mamba-2.
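+
+For example, with the default head dimension $\mathtt{P}=64$ and a state size of $\mathtt{N}=64$, each head carries $64 \times 64 = 4096$ state elements, all evolving under one shared scalar recurrence in Mamba-2.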
+
+Why make these restrictions? The main motivation is efficiency: these changes are necessary to be able to view the model in its [dual attention form](#the-quadratic-attention-mode), which allows matrix multiplications to be used.
+
+> ##### The Bottom Line: Mamba-2 vs. Mamba-1
+>
+> Compared to Mamba-1, Mamba-2 allows **much larger state dimensions** (from $\mathtt{N}=16$ in Mamba-1 to $\mathtt{N}=64$, $\mathtt{N}=256$, or even higher in Mamba-2) while simultaneously being **much faster during training**.
+{: .block-tip}
+
+But can this hurt us? There's some intuition to believe that it shouldn't.
+One of the main reasons for the selectivity (e.g. $A$ depending on the input $X$) introduced in Mamba
+is to let the SSM control whether to remember or ignore particular pieces of information --
+for example, a filler "um" encountered in a text transcript.
+But if such information should be ignored, then the entire state can ignore it together, and so it should be okay if the state's dynamics are shared across all features.
+
+Empirically, we haven't found evidence that the restricted expressivity of Mamba-2 might hurt, but the jury's still out!
+From one perspective, Mamba-2 isn't *strictly* better than Mamba-1: while it's a dramatic improvement from a *training* perspective, Mamba-1 might be better from a pure *inference* perspective.
+Since inference speed of SSMs is entirely governed by the state dimension, if one wants to maximize performance for a target inference efficiency (i.e. for a particular state size $\mathtt{N}$), then the increased expressivity of Mamba-1 might be better.
+We haven't fully analyzed the (theoretical or empirical) tradeoffs here, and think this would be a cool direction for the community to dig in more.
+
+#### SSD vs. Attention
+
+Compared to standard (self-)attention, SSD also has only two differences:
+1. The softmax normalization is dropped.
+2. A separate elementwise mask matrix is applied multiplicatively.
+
+The first difference can be interpreted as what reduces the effective state size of the model from infinite to finite, and improves its efficiency from quadratic to linear.
+
+The second difference is what distinguishes SSD from standard linear attention.
+One way to think of the mask is as **input-dependent relative positional encodings**.
+Because of the mask (definition :warning:), the standard attention score $Q_i^\top K_j$ is attenuated by a factor $a_{i:j}^\times = a_i \cdots a_{j+1}$, which can be interpreted as a discount factor based on how far apart the positions $i$ and $j$ are. This interpretation was concurrently espoused by Tobias Katsch's GateLoop paper :warning:.
+This is the key factor that encodes the "selectivity" of Mamba.
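+
+Concretely, combining the definitions above, every entry of $M$ on or below the diagonal ($i \ge j$) is
+
+$$
+M_{ij} = C_i^\top B_j \cdot a_i a_{i-1} \cdots a_{j+1} ,
+$$
+
+so the attention-like score $C_i^\top B_j$ is modulated by the accumulated, input-dependent decay between positions $j$ and $i$.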
+
+
+## SSD Viewpoint 1 (structured matrix transformations)
+
+This viewpoint interprets the SSD model as a sequence transformation by what is known in the literature as a (triangular) **semiseparable matrix**.
+
+## SSD Viewpoint 2 (structured attention)
+
+
+
+## Code
+
+{% highlight python %}
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+
+def segsum(x):
+ """Naive segment sum calculation. exp(segsum(A)) produces a 1-SS matrix,
+ which is equivalent to a scalar SSM."""
+ T = x.size(-1)
+ x_cumsum = torch.cumsum(x, dim=-1)
+ x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
+ mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
+ x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
+ return x_segsum
+
+def ssd(X, A, B, C, block_len=64, initial_states=None):
+ """
+ Arguments:
+ X: (batch, length, n_heads, d_head)
+ A: (batch, length, n_heads)
+ B: (batch, length, n_heads, d_state)
+ C: (batch, length, n_heads, d_state)
+ Return:
+ Y: (batch, length, n_heads, d_head)
+ """
+ assert X.dtype == A.dtype == B.dtype == C.dtype
+ assert X.shape[1] % block_len == 0
+
+ # Rearrange into blocks/chunks
+ X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]
+
+ A = rearrange(A, "b c l h -> b h c l")
+ A_cumsum = torch.cumsum(A, dim=-1)
+
+ # 1. Compute the output for each intra-chunk (diagonal blocks)
+ L = torch.exp(segsum(A))
+ Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
+
+ # 2. Compute the state for each intra-chunk
+ # (right term of low-rank factorization of off-diagonal blocks; B terms)
+ decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
+ states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
+
+ # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
+ # (middle term of factorization of off-diag blocks; A terms)
+ if initial_states is None:
+ initial_states = torch.zeros_like(states[:, :1])
+ states = torch.cat([initial_states, states], dim=1)
+ decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
+ new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
+ states, final_state = new_states[:, :-1], new_states[:, -1]
+
+ # 4. Compute state -> output conversion per chunk
+ # (left term of low-rank factorization of off-diagonal blocks; C terms)
+ state_decay_out = torch.exp(A_cumsum)
+ Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
+
+ # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
+ Y = rearrange(Y_diag+Y_off, "b c l h p -> b (c l) h p")
+ return Y, final_state
+
+{% endhighlight %}
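+
+As a usage sketch (the shapes are the ones documented in the docstring; the tensor values are arbitrary, and $A$ here holds the *logarithms* of the scalar decays, since the code exponentiates `segsum(A)`):
+
+{% highlight python %}
+import torch
+
+batch, length, n_heads, d_head, d_state = 2, 256, 8, 64, 64
+X = torch.randn(batch, length, n_heads, d_head)
+A = -torch.rand(batch, length, n_heads)           # negative log-decays, so exp(A) <= 1
+B = torch.randn(batch, length, n_heads, d_state)
+C = torch.randn(batch, length, n_heads, d_state)
+
+Y, final_state = ssd(X, A, B, C, block_len=64)
+print(Y.shape)            # torch.Size([2, 256, 8, 64])
+print(final_state.shape)  # torch.Size([2, 8, 64, 64])
+{% endhighlight %}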
+