Mega attention breakdown

Dmytro Kuzmenko
8 min read · Oct 8, 2022

Greetings! I recently gave a talk at the AI House Paper Club on Mega attention, a simple, efficient, and effective neural architecture used as a drop-in replacement for regular multi-head attention. The recording will be available soon; in the meantime, I am sharing my detailed breakdown here on Medium.

Intro

Attention provides the key mechanism that captures contextual information from the entire sequence by modeling pairwise interactions between the inputs at every timestep.

However, there are two common drawbacks in the design of the attention mechanism:

1) weak inductive bias;

2) quadratic computational complexity.

First, the attention mechanism does not assume prior knowledge of the patterns of dependencies between tokens (e.g. positional inductive bias), instead learning to predict the pairwise attention weights directly from data.

Second, the cost to compute and store the attention weights is quadratic in the length of the input sequence. Recent studies have shown the limitations of applying Transformers to long-sequence tasks in terms of both accuracy and efficiency. Other recent works have already tackled this issue and reduced the complexity to linear, but the authors of Mega take a different approach.

TL;DR

In this work, the authors propose a moving average equipped gated attention mechanism (Mega) to solve the two weaknesses simultaneously. The key idea is to incorporate inductive biases into the attention mechanism across the timestep dimension, by leveraging the classic exponential moving average (EMA) approach (Hunter, 1986).

EMA captures local dependencies that exponentially decay over time, and has been widely used in time series data modeling.

Self-attention

A quick reminder of the SA mechanism concept:

Self-attention in detail. Vaswani et al., “Attention is all you need”, 2017.

Self-attention is a key component of modern sequence-modeling architectures, most notably Transformers.
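Since the figure is not reproduced here, the following is a minimal pure-Python sketch of single-head scaled dot-product attention, softmax(QKᵀ/√d)V, in the spirit of Vaswani et al. The helper names (`matmul`, `softmax`, `self_attention`) are our own, and feeding the raw inputs directly in as Q, K, and V is a simplification for illustration; real layers first apply learned projections.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Single-head scaled dot-product attention: softmax(Q K^T / sqrt(d)) V."""
    d = len(q[0])
    scores = matmul(q, transpose(k))  # n x n matrix of pairwise interactions
    weights = [softmax([s / math.sqrt(d) for s in row]) for row in scores]
    return matmul(weights, v)  # each output row is a weighted average of the rows of V


# Toy sequence: 3 tokens with d = 2, reused as Q, K and V for simplicity.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = self_attention(x, x, x)
print(out)
```

Every output row is a convex combination of the rows of V, and building the n × n score matrix is exactly where the quadratic cost mentioned in the intro comes from.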

EMA — Exponential moving average

A moving average is a classic approach for sequential data modeling, which has been widely used in time series data to smooth out short-term fluctuations and highlight long-term trends or cycles.

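The Exponential Moving Average (EMA), a special case of the moving average, applies weighting factors that decrease exponentially. Formally, an EMA recursively calculates the output sequence Y:

$$ y_t = \alpha \odot x_t + (1 - \alpha) \odot y_{t-1} $$

where α ∈ (0, 1)^d is the EMA coefficient and ⊙ denotes element-wise multiplication; α and (1 − α) act as discount factors for exponential decay.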
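A minimal sketch of this recursion for a single scalar channel. The zero initial state y₀ = 0 is our assumption, and α is a scalar here for readability, whereas in Mega it is a vector applied element-wise.

```python
from typing import List


def ema(xs: List[float], alpha: float) -> List[float]:
    """Recursive EMA: y_t = alpha * x_t + (1 - alpha) * y_{t-1}."""
    ys = []
    prev = 0.0  # assumption: zero initial state y_0 = 0
    for x in xs:
        prev = alpha * x + (1.0 - alpha) * prev
        ys.append(prev)
    return ys


# A series that alternates between 1 and 5: the EMA damps the oscillation.
series = [1.0, 5.0, 1.0, 5.0, 1.0, 5.0]
print(ema(series, alpha=0.5))  # [0.5, 2.75, 1.875, 3.4375, 2.21875, 3.609375]
```

A larger α tracks the input more closely; a smaller α smooths more aggressively and forgets the past more slowly.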

Using an EMA places a strong inductive bias on the learning of pairwise dependencies: the dependency weight between two tokens decreases exponentially over time with an input-agnostic decay factor α.

This property favors local dependencies and limits long-distance ones. Despite its recurrent formulation, the computation of EMA can be represented as n individual convolutions, which can be computed efficiently using fast Fourier transforms (FFTs).
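Unrolling the recursion (with a zero initial state, which we assume here) gives y_t = Σ_{j ≤ t} α(1 − α)^(t − j) x_j, i.e. a causal convolution of the input with the kernel k_i = α(1 − α)^i. The sketch below computes that convolution with a small textbook radix-2 FFT written in pure Python and compares it with the recursive form. It only illustrates the idea in O(n log n); it is not Mega's actual implementation, which works on batched GPU tensors.

```python
import cmath
from typing import List


def fft(a: List[complex], invert: bool = False) -> List[complex]:
    """Recursive radix-2 Cooley-Tukey FFT; len(a) must be a power of two.
    With invert=True it returns the unnormalized inverse transform."""
    n = len(a)
    if n == 1:
        return list(a)
    even = fft(a[0::2], invert)
    odd = fft(a[1::2], invert)
    sign = 1.0 if invert else -1.0
    out = [0j] * n
    for k in range(n // 2):
        twiddle = cmath.exp(sign * 2j * cmath.pi * k / n) * odd[k]
        out[k] = even[k] + twiddle
        out[k + n // 2] = even[k] - twiddle
    return out


def ema_via_fft(xs: List[float], alpha: float) -> List[float]:
    """EMA as a causal convolution with the kernel k_i = alpha * (1 - alpha)^i."""
    n = len(xs)
    kernel = [alpha * (1.0 - alpha) ** i for i in range(n)]
    size = 1
    while size < 2 * n:  # zero-pad so the circular convolution equals the linear one
        size *= 2
    fx = fft([complex(v) for v in xs] + [0j] * (size - n))
    fk = fft([complex(v) for v in kernel] + [0j] * (size - n))
    conv = fft([a * b for a, b in zip(fx, fk)], invert=True)
    return [(c / size).real for c in conv[:n]]  # keep the first n (causal) outputs


def ema_recursive(xs: List[float], alpha: float) -> List[float]:
    """Reference recursion y_t = alpha * x_t + (1 - alpha) * y_{t-1}, with y_0 = 0."""
    ys, prev = [], 0.0
    for x in xs:
        prev = alpha * x + (1.0 - alpha) * prev
        ys.append(prev)
    return ys


xs = [1.0, 5.0, 1.0, 5.0, 1.0, 5.0, 2.0]
print(ema_via_fft(xs, 0.3))
print(ema_recursive(xs, 0.3))
```

Both functions produce the same sequence up to floating-point error; the convolutional form is what allows the EMA to be computed for all timesteps in parallel instead of step by step.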

Example of EMA usage in time series analysis (source)

Merging both concepts

EMA and attention mechanisms each have their own limitations, despite their wide applications and impressive successes in sequence modeling.

By leveraging their properties to complement each other, the authors propose to embed an EMA into the calculation of the attention matrix A. The resulting model enjoys the benefit of strong inductive bias while maintaining the capacity to learn complex dependency patterns.

Moreover, this integration enables the design of a computationally efficient chunk-wise attention mechanism with linear complexity w.r.t sequence length.

Mega in detail

Multi-dimensional EMA

Mega introduces a modification of the standard EMA, named multi-dimensional damped EMA, to improve its flexibility and capacity.

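Damped EMA, a key concept inspired by previous works, adds a damping factor δ ∈ (0, 1)^d that relaxes the coupling between the weights of the current and previous timesteps:

$$ y_t = \alpha \odot x_t + (1 - \alpha \odot \delta) \odot y_{t-1} $$

Multi-dimensional damped EMA is then introduced to further improve the expressiveness of EMA. Specifically, the authors expand the parameters α and δ to h dimensions, and they expand each dimension j of the input sequence X individually into h dimensions via an expansion matrix β. The resulting h-dimensional hidden state is then projected back to a single dimension via a projection matrix η:

$$ u_t^{(j)} = \beta_j \, x_{t,j} $$
$$ h_t^{(j)} = \alpha_j \odot u_t^{(j)} + (1 - \alpha_j \odot \delta_j) \odot h_{t-1}^{(j)} $$
$$ y_{t,j} = \eta_j^{\top} h_t^{(j)} $$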
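A minimal sketch of the multi-dimensional damped EMA for a single input channel: the scalar input is expanded to h dimensions with β, each of the h dimensions runs its own damped EMA with its own α and δ, and the h-dimensional state is projected back to a scalar with a vector we call η. In Mega every one of the d input channels has its own parameters and α, δ are kept in (0, 1); the zero initial state and the parameter values in the usage line are assumptions for illustration.

```python
from typing import List


def damped_ema_channel(xs: List[float], alpha: List[float], delta: List[float],
                       beta: List[float], eta: List[float]) -> List[float]:
    """Multi-dimensional damped EMA for ONE input channel, expanded to h = len(alpha) dims.

    u_t = beta * x_t                                   (expand the scalar to h dims)
    h_t = alpha * u_t + (1 - alpha * delta) * h_{t-1}  (element-wise damped EMA)
    y_t = sum_i eta_i * h_t[i]                         (project back to a scalar)
    """
    h = len(alpha)
    state = [0.0] * h  # assumption: zero initial state
    ys = []
    for x in xs:
        for i in range(h):
            u = beta[i] * x
            state[i] = alpha[i] * u + (1.0 - alpha[i] * delta[i]) * state[i]
        ys.append(sum(e * s for e, s in zip(eta, state)))
    return ys


# One input channel expanded to h = 2 dimensions (all parameter values are made up).
ys = damped_ema_channel([1.0, 1.0, 0.0], alpha=[0.5, 0.9], delta=[1.0, 0.5],
                        beta=[1.0, -1.0], eta=[0.5, 0.5])
print(ys)
```

With h = 1 and β = δ = η = 1 the function reduces to the plain EMA, which is a handy sanity check; setting δ < 1 makes a dimension forget more slowly.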

Moving Average Equipped Gated Attention — Mega

The gated attention mechanism in Mega adopts the Gated Recurrent Unit (GRU; Cho et al. (2014)) and Gated Attention Unit (GAU; Hua et al. (2022)) as the backbone architectures, with an EMA-based sub-layer embedded into the calculation of the attention matrix.

[1] Formally, EMA is first applied to the input, and its output is then used to compute the shared representation in GAU fashion:

$$ X' = \mathrm{EMA}(X) \in \mathbb{R}^{n \times d} $$
$$ Z = \phi_{\mathrm{silu}}(X' W_z + b_z) \in \mathbb{R}^{n \times z} $$

where X' can be regarded as the updated, or contextual, input, since it encodes contextual information through EMA. Z is the shared representation with z dimensions, computed with projection matrix W_z and bias term b_z, and φ_silu is the self-gated activation function (SiLU).

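[2] Following GAU, the query and key sequences are computed by applying per-dimension scalars (κ) and offsets (μ) to Z, while the value sequence is computed from the original X:

$$ Q = \kappa_q \odot Z + \mu_q \in \mathbb{R}^{n \times z} $$
$$ K = \kappa_k \odot Z + \mu_k \in \mathbb{R}^{n \times z} $$
$$ V = \phi_{\mathrm{silu}}(X W_v + b_v) \in \mathbb{R}^{n \times v} $$
$$ O = f\left(\frac{Q K^{\top}}{\tau(X)} + b_{\mathrm{rel}}\right) V \in \mathbb{R}^{n \times v} $$

where f is the attention function, τ(X) is a scaling term, and b_rel is a relative positional bias. Attention (O) is thus calculated in the traditional fashion; only Q and K are derived from the shared representation via scalars and offsets, while V is a SiLU projection of the original input X.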

The SHGA unit is rather elegant: the post-EMA input (X') is used to compute the queries and keys (Q, K), while the raw input (X) forms the values (V).
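Steps [1] and [2] can be sketched in pure Python as follows. This is a hypothetical, simplified illustration rather than the official implementation: it uses plain softmax with 1/√z scaling and omits the relative positional bias, and the parameter names (`w_z`, `kappa_q`, `mu_q`, and so on) are our own. The usage line feeds in X' = EMA(X) computed with α = 0.5 and a zero initial state.

```python
import math
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def linear(x: Matrix, w: Matrix, b: Vector) -> Matrix:
    """x @ w + b, with the bias vector b broadcast over rows."""
    return [[val + bias for val, bias in zip(row, b)] for row in matmul(x, w)]


def silu(t: float) -> float:
    return t / (1.0 + math.exp(-t))


def softmax(row: Vector) -> Vector:
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


def single_head_gated_attention(x, x_ema, w_z, b_z, kappa_q, mu_q, kappa_k, mu_k, w_v, b_v):
    """Attention part of a Mega-style gated attention layer (sketch).

    x:     original input X, shape n x d
    x_ema: EMA output X', shape n x d
    """
    # Shared representation Z = silu(X' W_z + b_z), shape n x z
    z_mat = [[silu(t) for t in row] for row in linear(x_ema, w_z, b_z)]
    # Queries and keys: per-dimension scalars (kappa) and offsets (mu) applied to Z
    q_mat = [[kq * t + mq for t, kq, mq in zip(row, kappa_q, mu_q)] for row in z_mat]
    k_mat = [[kk * t + mk for t, kk, mk in zip(row, kappa_k, mu_k)] for row in z_mat]
    # Values come from the original input: V = silu(X W_v + b_v)
    v_mat = [[silu(t) for t in row] for row in linear(x, w_v, b_v)]
    # Simplification: softmax with 1/sqrt(z) scaling, no relative positional bias
    scale = 1.0 / math.sqrt(len(z_mat[0]))
    scores = [[scale * sum(a * b for a, b in zip(qi, kj)) for kj in k_mat] for qi in q_mat]
    return matmul([softmax(row) for row in scores], v_mat)


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
x_ema = [[0.5, 0.0], [0.25, 0.5], [0.625, 0.75]]  # EMA(X) with alpha = 0.5
eye = [[1.0, 0.0], [0.0, 1.0]]
o = single_head_gated_attention(x, x_ema, w_z=eye, b_z=[0.0, 0.0],
                                kappa_q=[1.0, 1.0], mu_q=[0.0, 0.0],
                                kappa_k=[1.0, 1.0], mu_k=[0.0, 0.0],
                                w_v=eye, b_v=[0.0, 0.0])
print(o)
```

Note that Q and K share a single projection (Z) and differ only by cheap per-dimension scalars and offsets, which is part of what keeps the single-head design light.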

Final stages of attention

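The concept of reset and update gates is already well known from the GRU and is reintroduced in the paper. Both gates are computed from X': the reset gate γ modulates the attention output O when forming the candidate activation output (Ĥ), and the update gate φ then combines Ĥ with the original input to produce the output (Y) in an EMA-like fashion:

$$ \gamma = \phi_{\mathrm{silu}}(X' W_\gamma + b_\gamma) $$
$$ \varphi = \phi_{\mathrm{sigmoid}}(X' W_\varphi + b_\varphi) $$
$$ \hat{H} = \phi_{\mathrm{silu}}(X' W_h + (\gamma \odot O) U_h + b_h) $$
$$ Y = \varphi \odot \hat{H} + (1 - \varphi) \odot X $$

Gated attention sub-layer equipped with EMA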
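The final gating stage can be sketched the same way. As we read the paper, the reset gate is γ = SiLU(X'W_γ + b_γ), the update gate is φ = sigmoid(X'W_φ + b_φ), the candidate is Ĥ = SiLU(X'W_h + (γ ⊙ O)U_h + b_h), and the output is Y = φ ⊙ Ĥ + (1 − φ) ⊙ X. The function below is a hypothetical, unoptimized sketch of those four lines; the parameter names are our own, and the exact parameterization should be checked against the official implementation.

```python
import math
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def linear(x: Matrix, w: Matrix, b: Vector) -> Matrix:
    """x @ w + b, with the bias vector b broadcast over rows."""
    return [[val + bias for val, bias in zip(row, b)] for row in matmul(x, w)]


def silu(t: float) -> float:
    return t / (1.0 + math.exp(-t))


def sigmoid(t: float) -> float:
    return 1.0 / (1.0 + math.exp(-t))


def gated_output(x, x_ema, o, w_gamma, b_gamma, w_phi, b_phi, w_h, u_h, b_h):
    """Turn the attention output O into the layer output Y via reset/update gates."""
    gamma = [[silu(t) for t in row] for row in linear(x_ema, w_gamma, b_gamma)]  # reset gate
    phi = [[sigmoid(t) for t in row] for row in linear(x_ema, w_phi, b_phi)]    # update gate
    gated_o = [[g * t for g, t in zip(g_row, o_row)] for g_row, o_row in zip(gamma, o)]
    pre = linear(x_ema, w_h, b_h)
    mixed = matmul(gated_o, u_h)
    h_hat = [[silu(a + b) for a, b in zip(r1, r2)] for r1, r2 in zip(pre, mixed)]  # candidate
    # Y = phi * H_hat + (1 - phi) * X: an EMA-like interpolation with the raw input
    return [[p * h + (1.0 - p) * xi for p, h, xi in zip(p_row, h_row, x_row)]
            for p_row, h_row, x_row in zip(phi, h_hat, x)]


x = [[1.0, 0.0], [0.0, 1.0]]
x_ema = [[0.5, 0.0], [0.25, 0.5]]
o = [[0.3, 0.1], [0.2, 0.4]]  # made-up attention output
eye = [[1.0, 0.0], [0.0, 1.0]]
zero = [0.0, 0.0]
y = gated_output(x, x_ema, o, eye, zero, eye, zero, eye, eye, zero)
print(y)
```

When the update gate saturates near 0 the layer simply passes X through, and near 1 it outputs the candidate Ĥ; this interpolation between the new and the raw input is what makes the output stage EMA-like.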

Single-head or multi-head?

The paper supports replacing the popular multi-head attention (MHA) with single-head gated attention (SHGA) with an elaborate theorem and proposition.

In short, the authors argue, both theoretically and experimentally, that single-head attention equipped with gates performs just as well as MHA, which is quite a find. The computational cost is substantially reduced, while the dependency-capturing power remains the same, if not better.

This is exactly the kind of progress the field is moving towards: improving existing modules not by scaling them up, but by maintaining performance while reducing memory and compute requirements.

Mega Blocks

The Mega layer (moving average equipped gated attention) is used as a drop-in replacement for regular attention in the Transformer. It is followed by position-wise feed-forward networks (FFNs) and normalization layers to compose one Mega block.

The feedforward and the expanded dimensions are set to 2d (instead of 4d as in previous works) to retain a similar model size.

Mega-chunk: Mega with Linear Complexity

The authors address the second problem of the attention mechanism, its quadratic computational complexity, by proposing Mega-chunk: a variant of Mega with linear complexity that simply applies attention to each local chunk of fixed length.

Specifically, they first split the sequences of Q, K, and V into chunks of length c, e.g. Q = {Q_1, ..., Q_k}, where k = n/c is the number of chunks (keys and values are split in the same way). The attention operation (O) is applied to each chunk individually, yielding a complexity of O(kc²) = O(nc), which is linear w.r.t. n.
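A minimal sketch of chunk-wise attention, which also counts how many query-key scores are computed so the O(n²) versus O(nc) difference is visible. It uses plain softmax attention and skips the EMA sub-layer entirely; the function names are our own, and requiring n to be a multiple of c is a simplification of this sketch.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def attention(q: Matrix, k: Matrix, v: Matrix) -> Tuple[Matrix, int]:
    """Softmax attention; also returns how many query-key scores were computed."""
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        total = sum(w)
        out.append([sum(wj * vj[col] for wj, vj in zip(w, v)) / total
                    for col in range(len(v[0]))])
    return out, len(q) * len(k)


def chunked_attention(q: Matrix, k: Matrix, v: Matrix, c: int) -> Tuple[Matrix, int]:
    """Split Q, K, V into n / c chunks of length c and attend only within each chunk."""
    n = len(q)
    assert n % c == 0, "this sketch assumes n is a multiple of the chunk size c"
    out, n_scores = [], 0
    for start in range(0, n, c):
        o, s = attention(q[start:start + c], k[start:start + c], v[start:start + c])
        out.extend(o)
        n_scores += s
    return out, n_scores


n, c = 8, 2
x = [[math.sin(i), math.cos(i)] for i in range(n)]
full, full_cost = attention(x, x, x)
chunked, chunk_cost = chunked_attention(x, x, x, c)
print(full_cost, chunk_cost)  # 64 16, i.e. n^2 versus k * c^2 = n * c
```

With c = n the chunked version reduces to full attention. Without an EMA in front, the output for a token is completely unaffected by tokens in other chunks, which is precisely the loss of cross-chunk context that Mega's EMA sub-layer is meant to mitigate.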

However, this method suffers from the critical limitation of losing contextual information from other chunks. Fortunately, the EMA sub-layer in Mega mitigates this problem by capturing local contextual information near each token, whose outputs are used as the inputs to the attention sub-layer. As a result, the effective context being exploited by chunk-wise attention can go beyond the chunk boundary.

Benchmarks and miscellaneous

Attention functions

The softmax function is the most common choice for the attention function. So et al. (2021) recently introduced the squared ReLU function via architecture search techniques, which has shown faster convergence speed and competitive generalization performance on language tasks.

However, one issue of relu² is that neither its range nor its gradient is bounded, which can lead to unstable model training.

To address this issue, the authors propose a new attention function based on the Laplace function:

$$ f_{\mathrm{laplace}}(x; \mu, \sigma) = 0.5 \times \left[ 1 + \mathrm{erf}\left( \frac{x - \mu}{\sigma \sqrt{2}} \right) \right] $$

where erf is the error function, with μ = √(1/2) and σ = √(1/(4π)).

Besides performance, the authors also compared the stability of the two attention functions by training Mega models with each of them on the LRA Pathfinder task. Laplace was observed to be much more stable than relu².
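A small sketch comparing the Laplace attention function, 0.5 × [1 + erf((x − μ)/(σ√2))], with squared ReLU, using the constants μ = √(1/2) and σ = √(1/(4π)) as we recall them from the paper (treat them as an assumption). With these constants one can check that at x = μ both functions equal 1/2 and both have slope √2, so the Laplace function follows relu² around that point while staying bounded in [0, 1].

```python
import math

# Constants as we recall them from the Mega paper (an assumption, not verified here).
MU = math.sqrt(0.5)
SIGMA = math.sqrt(1.0 / (4.0 * math.pi))


def laplace(x: float, mu: float = MU, sigma: float = SIGMA) -> float:
    """Laplace attention function: 0.5 * (1 + erf((x - mu) / (sigma * sqrt(2))))."""
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))


def relu2(x: float) -> float:
    """Squared ReLU (So et al., 2021): unbounded range and gradient."""
    return max(x, 0.0) ** 2


for x in [-1.0, 0.0, MU, 1.0, 3.0, 10.0]:
    print(f"x={x:7.3f}  laplace={laplace(x):.4f}  relu2={relu2(x):.4f}")
```

For large inputs relu² keeps growing quadratically (and so does its gradient), while the Laplace function saturates at 1, which is the intuition behind its better training stability.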

Benchmarks — Long-range Arena (LRA)

Mega results on LRA. Decent speed as well as performance.

On all six tasks, Mega substantially outperforms all the baselines.

The authors evaluate Mega-chunk on each task, setting the chunk size c = 128 for all tasks except Path-X, where c = 4096. Mega-chunk consistently performs well, particularly on the three language tasks.

They also examine the speed and memory efficiency of Mega on the byte-level classification task with an input length of 4K. Mega-chunk is highly efficient: it is about 5.5 times faster and consumes only 13% of the memory of the vanilla Transformer. Interestingly, Mega with a full attention field is also much more efficient than the Transformer, thanks to single-head gated attention.

SC-Raw and WikiText-103

XFM, S4, and Mega comparison

Mega consistently outperforms S4 and XFM on all tasks except raw speech classification.

Image classification

Mega obtains about a 0.6% accuracy improvement over DeiT-B.

The training recipe is very similar to that of DeiT-B, differing only in more training epochs and more warmup. Interestingly, gradient clipping is used for Mega, as relu² may lead to exploding gradients.

Summary

Thank you for making it this far!

Mega is a simple, efficient, and effective neural architecture used as a drop-in replacement for regular multi-head attention. It has a lot of potential to generate new SOTA papers and further enhance bleeding-edge models.

By leveraging the classic exponential moving average (EMA) approach, Mega is capable of incorporating stronger inductive biases into the attention mechanism. Moreover, the EMA approach enables the design of Mega-chunk, an efficient variant of Mega with linear complexity.

I hope you enjoyed the read and learned something new. Thank you for your time.

References

  1. “Attention is all you need”
  2. Investopedia on EMA
  3. Mega paper
  4. Mega official implementation
  5. SiLU activation function
  6. Gated attention unit
  7. Gated recurrent unit
  8. ReLU squared
  9. EMA