Автор: Александр Лекомцев Редактура: Андрей Шадриков

Введение

В статье «MoH: Multi-Head Attention as Mixture-of-Head Attention» переносят идею Mixture-of-Experts (MoE) с архитектуры моделей на архитектуру блока Multi-Head Attention (MHA). Авторы пришли к этой идее после того, как выяснили: большая часть голов может быть убрана из MHA без значительного снижения перформанса. Но вместо прунинга голов авторы предлагают учить Router для каждого attention-блока, который будет выбирать, какие головы обработают данные.

Давайте разберёмся подробнее с архитектурой MoH, затем обсудим сценарии применения и метрики.

Архитектура Mixture-of-Head Attention

Рисунок 1. Схемы для стандартного Multi-Head Attention (a) и предложенного MoH Attention (b)

Рисунок 1. Схемы для стандартного Multi-Head Attention (a) и предложенного MoH Attention (b)

Multi-Head Attention обычно записывается через конкатенацию attention каждой отдельной головы, назовём такую запись concatenation form:

$$ \begin{align*} \text{MultiHead}(X, X') &= \text{Concat}(H^1, H^2, \dots, H^h) W_O, \\ H^i &= \text{Attention}(X W_Q^i, X' W_K^i, X' W_V^i), \\ \text{Attention}(Q, K, V) &= \text{Softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right)V \end{align*} $$

Размерности матриц в уравнении выше:

$$ \begin{align*} Q &= X W_Q^i, & W_Q^i &\in \mathbb{R}^{d_\text{in} \times \frac{d_k}{h}}, \\ K &= X' W_K^i, & W_K^i &\in \mathbb{R}^{d_\text{in} \times \frac{d_k}{h}}, \\ V &= X' W_V^i, & W_V^i &\in \mathbb{R}^{d_\text{in} \times \frac{d_v}{h}}, \\ X &\in \mathbb{R}^{n \times d_{\mathrm{in}}}, & H^i &\in \mathbb{R}^{n \times \frac{d_v}{h}}.

\end{align*} \\ d_\text{in}: \text{input feature dimension}, \\ d_k: \text{query/key projection dimension}, \\ d_v: \text{value projection dimension}, \\ n: \text{num input tokens}, \\ h: \text{num heads}.

$$

Но при этом мы можем записать матрицу $W_O$ как конкатенацию h строк:

$$ \begin{align*} \begin{bmatrix} W_O^1 \\ W_O^2 \\ \vdots \\ W_O^h \end{bmatrix} = W_O, \quad W_O \in \mathbb{R}^{d_v \times d_\text{out}}, \quad W_O^i \in \mathbb{R}^{\frac{d_v}{h} \times d_\text{out}} \\ \end{align*}

$$

И в итоге получить новую форму для записи Multi-Head Attention через сумму, назовём её summation form:

$$ \begin{align*} \text{MultiHead}(X, X') = \text{Concat}(H^1, H^2, \dots, H^h) W_O = \\ = [H^1, H^2, \dots, H^h] W_O = [H^1, H^2, \dots, H^h]\begin{bmatrix} W_O^1 \\ W_O^2 \\ \vdots \\ W_O^h \end{bmatrix} = \\ = H^1W_O^1+H^2W_O^2 + \dots + H^hW_O^h = \sum_{i=1}^h H^i W_O^i \end{align*} $$

Таким образом, мы представили Multi-Head Attention как сумму матриц, а дополнительных ограничений на свойства $W_O$ не появилось — это просто две разные записи одного и того же выражения.

Теперь давайте перейдём к Mixture-of-Head Attention. Для этого заменим сумму в выражении выше на взвешенную сумму, то есть каждую матрицу будем умножать на скаляр $g_i$, который равен нулю, только если i-я голова не выбрана в качестве «эксперта»:

$$ \begin{align*} \text{MoH}(X, X') = \sum_{i=1}^h g_iH^i W_O^i \end{align*} $$

Кроме того, разделим головы на два типа — shared и routed. Shared будут использоваться всегда (то есть $g_i$ ≠ 0 всегда для shared), и мы ожидаем, что они будут отвечать за какие-то «общие знания», которые нужны независимо от конкретной задачи и токена. Routed будут выбираться под каждый токен. То есть при прогоне через модель будут использоваться все shared и только часть routed heads. Для простоты пусть первые $h_s$ по нумерации голов — это shared, а остальные от $h_s + 1$ до $h$ — routed.