<aside> <img src="/icons/snake_brown.svg" alt="/icons/snake_brown.svg" width="40px" /> This takes writing from multiple sources, so please bear with me on the notation.

</aside>

Problem statements

The quadratic bottleneck

Given an input $x_t$, the attention mechanism calculates the hidden state $h_t$ as a function of the current token and all previous tokens. The attention mechanism involves a pairwise computation, specifically $\mathrm{softmax}(\mathbf q \mathbf k^T)\mathbf v$: every token is scored against every other token. This pairwise computation means one forward pass has $O(n^2)$ time complexity, where $n$ is the sequence length (the number of tokens). As the context gets longer, the model gets slower.
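To make the pairwise cost concrete, here is a minimal NumPy sketch of single-head causal attention. It follows the unscaled $\mathrm{softmax}(\mathbf q \mathbf k^T)\mathbf v$ form written above; the function name `causal_attention`, the random inputs and the sizes `n` and `d` are assumptions chosen purely for illustration.

```python
import numpy as np

def causal_attention(q, k, v):
    """Single-head causal attention over n tokens (unscaled, as in the formula above)."""
    n = q.shape[0]
    scores = q @ k.T                                   # (n, n): one score per token pair
    future = np.triu(np.ones((n, n), dtype=bool), k=1) # True strictly above the diagonal
    scores = np.where(future, -np.inf, scores)         # token t must not see later tokens
    scores = scores - scores.max(axis=-1, keepdims=True)  # subtract row max for stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ v, weights

rng = np.random.default_rng(0)
n, d = 8, 4                                            # illustrative sizes
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
h, weights = causal_attention(q, k, v)
print(weights.shape)                                   # the score matrix is n x n
```

The `scores` matrix has $n \times n$ entries, so doubling the sequence length quadruples the work and memory for this step.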

Unable to model unbounded context

While Transformers are popular for language modelling problems, they are limited when used to solve other, more generic sequence-to-sequence problems. These problems arise in medical signals (EEG / ECG), speech and forecasting.

Say we want a model to learn a particular feature of the input signal that has unbounded context, like the (exponential) moving average. Because a Transformer only sees a fixed-size context window, it cannot learn this behaviour exactly: any input that falls outside the window is simply dropped. RNNs / LSTMs also have a limited capability to model this behaviour, because information from further back in the sequence tends to get lost.
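As a small illustration of why the exponential moving average has unbounded context, the sketch below computes it as a recurrence and also lists the weight it places on each past input. The helper names `ema` and `ema_weights` and the smoothing factor `alpha = 0.1` are assumptions for this example only.

```python
def ema(xs, alpha=0.1):
    """Exponential moving average: h_t = (1 - alpha) * h_{t-1} + alpha * x_t."""
    h = 0.0
    out = []
    for x in xs:
        h = (1 - alpha) * h + alpha * x
        out.append(h)
    return out

def ema_weights(t, alpha=0.1):
    """Weight that the EMA at step t places on each input x_0 .. x_t."""
    return [alpha * (1 - alpha) ** (t - s) for s in range(t + 1)]

xs = [1.0] + [0.0] * 99      # a single impulse at step 0, then silence
hs = ema(xs)
w = ema_weights(99)

# The impulse from 99 steps ago still contributes: its weight is small but non-zero.
print(hs[-1], w[0])
```

Every past input keeps a non-zero (geometrically decaying) weight, so a model restricted to the last $L$ inputs can only approximate the result by truncating the tail.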

S4

SSMs (State Space Models)

The SSM consists of two equations:

$$ \begin{align} h'(t) &= \mathbf A h(t) + \mathbf B x(t) \\ y(t) &= \mathbf Ch(t) + \mathbf D x(t) \end{align} $$

The first equation maps the input sequence $x(t)$ to a hidden state $h(t)$, and the second equation maps the hidden state to an output.
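To get a feel for how the two equations interact, here is a rough simulation sketch using a crude forward-Euler step with a small `dt`. This is only an illustrative numerical approximation, not the discretisation used by S4. The function `simulate_ssm`, the single-input single-output shapes and the one-state system below are all assumptions made for the example.

```python
import numpy as np

def simulate_ssm(A, B, C, D, x, dt):
    """Forward-Euler simulation of h'(t) = A h(t) + B x(t), y(t) = C h(t) + D x(t)."""
    h = np.zeros(A.shape[0])              # hidden state starts at zero
    ys = []
    for x_t in x:
        ys.append(C @ h + D * x_t)        # second equation: read the output from the state
        h = h + dt * (A @ h + B * x_t)    # first equation: step the state forward by dt
    return np.array(ys)

# Hypothetical one-state system: h' = -h + x, y = h (a leaky integrator)
A = np.array([[-1.0]])
B = np.array([1.0])
C = np.array([1.0])
D = 0.0
dt = 0.001
x = np.ones(10_000)                       # unit step input over 10 time units
y = simulate_ssm(A, B, C, D, x, dt)
print(y[1000], y[-1])                     # the output rises towards 1
```

For this toy system the exact response to a unit step is $1 - e^{-t}$, so the simulated output should sit close to that curve when `dt` is small.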

Discretisation

The equations we saw above are in continuous time. What we need to do is convert them to discrete time. That simply means that, instead of working with a smooth curve, we observe / sample it at small, regular time intervals. With that, we can now find an expression for $h(t)$.

Starting with:

$$ \begin{align*} h'(t) &= \mathbf A h(t) + \mathbf B x(t) \\ y(t) &= \mathbf Ch(t) + \mathbf D x(t) \end{align*} $$

The discretised form of the system is (dropping the $\mathbf D x_t$ term, which acts as a skip connection):

$$ \begin{align} h_t &= \overline{\mathbf A}h_{t-1} + \overline{\mathbf B} x_{t} \\ y_t &= \mathbf Ch_t \end{align} $$

where: