These notes refer to d2l (Dive into Deep Learning).

The paper is: Attention Is All You Need.

All posts tagged LLM build on this one. The Transformer is the basic architecture of LLMs: stacking multiple Transformer blocks and adding a few other modules yields an LLM.

Background

If the data items have different sizes, processing them can become difficult.

Consider a database analogy (a toy lookup sketch follows the list below):

  • design a query q that operates on (k, v) pairs
  • depending on the contents of the database, the same query may return different answers
  • the operation executed over a large state space (the database) can stay simple
  • there is no need to compress or simplify the database
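As a toy illustration (a sketch, not from d2l): exact dictionary lookup returns the single matching value, while attention performs a soft lookup that mixes all values according to query-key similarity.

import torch

# A toy "database" of (key, value) pairs
keys = torch.tensor([1.0, 3.0, 5.0, 7.0])
values = torch.tensor([10.0, 30.0, 50.0, 70.0])
query = 3.0

# Exact lookup: only the matching key contributes
exact = values[keys == query]

# Soft (attention-weighted) lookup: every value contributes,
# weighted by a similarity kernel between the query and each key
weights = torch.softmax(-(keys - query) ** 2, dim=0)
soft = (weights * values).sum()
print(exact, soft)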

We implement and train everything with PyTorch.

Definition

For a database of key-value pairs $\mathcal{D} = \{(\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)\}$ and a query $\mathbf{q}$, attention over $\mathcal{D}$ is

$$\mathrm{Attention}(\mathbf{q}, \mathcal{D}) \stackrel{\mathrm{def}}{=} \sum_{i=1}^{m} \alpha(\mathbf{q}, \mathbf{k}_i)\, \mathbf{v}_i,$$

where $\alpha(\mathbf{q}, \mathbf{k}_i)$ ($i = 1, \ldots, m$) are scalar attention weights.

import torch
import utils
 
attention_weights = torch.eye(10).reshape((1, 1, 10, 10))
utils.show_heatmaps(attention_weights, x_label="Keys", y_label="Queries")

Some common kernels for attention pooling are:

  • Gaussian: $k(u) = \frac{1}{\sqrt{2\pi}} \exp\!\left(-\frac{u^2}{2}\right)$
  • Boxcar: $k(u) = 1$ if $|u| \le 1$, else $0$
  • Constant: $k(u) = 1$
  • Epanechikov: $k(u) = \max(0, 1 - |u|)$
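The plot calls further below reference kernels and names; a minimal definition of the four kernels above (following the d2l convention, and assumed rather than shown in the original snippet) is:

def gaussian(x):
    return torch.exp(-x**2 / 2)

def boxcar(x):
    return torch.abs(x) < 1.0

def constant(x):
    return 1.0 + 0 * x

def epanechikov(x):
    return torch.max(1 - torch.abs(x), torch.zeros_like(x))

kernels = (gaussian, boxcar, constant, epanechikov)
names = ('Gaussian', 'Boxcar', 'Constant', 'Epanechikov')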

Attention Pooling via Nadaraya-Watson Regression

We define some training data. In the following we use the dependency $f(x) = 2\sin(x) + x$, corrupted by additive Gaussian noise:
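For reference, the nadaraya_watson function below implements the classical Nadaraya-Watson estimator, a kernel-weighted average of the training targets:

$$\hat{f}(x) = \sum_{i=1}^{n} \frac{k(x - x_i)}{\sum_{j=1}^{n} k(x - x_j)}\, y_i.$$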

import matplotlib.pyplot as plt

def f(val):
    return 2 * torch.sin(val) + val
 
n = 40
x_train, _ = torch.sort(torch.rand(n) * 5)
y_train = f(x_train) + torch.randn(n)
x_val = torch.arange(0, 5, 0.1)
y_val = f(x_val)
 
def nadaraya_watson(x_train, y_train, x_val, kernel):
    dists = x_train.reshape((-1, 1)) - x_val.reshape((1, -1))
    # Each column/row corresponds to each query/key
    k = kernel(dists).type(torch.float32)
    attention_w = k / k.sum(0)
    y_hat = y_train @ attention_w
    return y_hat, attention_w
 
def plot(x_train, y_train, x_val, y_val, kernels, names, attention=False):
    fig, axes = plt.subplots(1, 4, sharey=True, figsize=(12, 3))
    for kernel, name, ax in zip(kernels, names, axes):
        y_hat, attention_w = nadaraya_watson(x_train, y_train, x_val, kernel)
        if attention:
            pcm = ax.imshow(attention_w.detach().numpy(), cmap='Reds')
        else:
            ax.plot(x_val, y_hat)
            ax.plot(x_val, y_val, 'm--')
            ax.plot(x_train, y_train, 'o', alpha=0.5);
        ax.set_xlabel(name)
        if not attention:
            ax.legend(['y_hat', 'y'])
    if attention:
        fig.colorbar(pcm, ax=axes, shrink=0.7)
 
plot(x_train, y_train, x_val, y_val, kernels, names)
 
plot(x_train, y_train, x_val, y_val, kernels, names, attention=True)

Adapting Attention Pooling

We can replace the Gaussian kernel with one of a different width $\sigma$:

$$k_\sigma(u) = \exp\!\left(-\frac{u^2}{2\sigma^2}\right)$$

sigmas = (0.1, 0.2, 0.5, 1)
names = ['Sigma ' + str(sigma) for sigma in sigmas]
 
def gaussian_with_width(sigma):
    return (lambda x: torch.exp(-x**2 / (2*sigma**2)))
 
kernels = [gaussian_with_width(sigma) for sigma in sigmas]
plot(x_train, y_train, x_val, y_val, kernels, names)

Dot Product Attention

Review the attention function (dropping the exponentiation) that arises from the Gaussian kernel:

$$a(\mathbf{q}, \mathbf{k}_i) = -\tfrac{1}{2}\|\mathbf{q} - \mathbf{k}_i\|^2 = \mathbf{q}^\top \mathbf{k}_i - \tfrac{1}{2}\|\mathbf{k}_i\|^2 - \tfrac{1}{2}\|\mathbf{q}\|^2.$$

The term $-\tfrac{1}{2}\|\mathbf{q}\|^2$ is the same for every key and can be dropped; if the keys also have (roughly) equal norm, only the dot product $\mathbf{q}^\top \mathbf{k}_i$ matters.

Assume that all elements of the query and the key are independent, identically distributed random variables with zero mean and unit variance. The dot product of the two vectors then has zero mean and a variance of $d$, the dimensionality of the vectors.

This leads to the first commonly used attention scoring function, the scaled dot product:

$$a(\mathbf{q}, \mathbf{k}_i) = \frac{\mathbf{q}^\top \mathbf{k}_i}{\sqrt{d}},$$

and we use the softmax to normalize the scores into attention weights:

$$\alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}\big(a(\mathbf{q}, \mathbf{k}_i)\big) = \frac{\exp\!\big(\mathbf{q}^\top \mathbf{k}_i / \sqrt{d}\big)}{\sum_{j=1}^{m} \exp\!\big(\mathbf{q}^\top \mathbf{k}_j / \sqrt{d}\big)}.$$
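A quick empirical check of the variance claim above (a sketch, not part of the original notes):

import torch

d = 256
q = torch.randn(10_000, d)
k = torch.randn(10_000, d)
dots = (q * k).sum(dim=-1)
# Mean is close to 0, variance close to d; dividing by sqrt(d) brings the variance back to about 1
print(dots.mean().item(), dots.var().item(), (dots / d**0.5).var().item())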

Masked Softmax Operation

Since we need to be able to deal with sequences of different lengths, shorter sequences are padded with dummy tokens up to a common length $n$.

Since we do not want these blanks to receive attention, we simply limit $\sum_{i=1}^{n} \alpha(\mathbf{q}, \mathbf{k}_i)\, \mathbf{v}_i$ to $\sum_{i=1}^{l} \alpha(\mathbf{q}, \mathbf{k}_i)\, \mathbf{v}_i$, for however long, $l \le n$, the actual sequence is. This is implemented as a masked softmax that assigns (effectively) zero weight to padded positions.

import math
 
import torch
from torch import nn
 
 
def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis."""
 
    # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor
    def _sequence_mask(X: torch.Tensor, valid_lens, value=0):
        maxlen = X.size(1)
        mask = torch.arange(maxlen, dtype=torch.float32, device=X.device)[None, :] < valid_lens[:, None]
        X[~mask] = value
        return X
 
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)
 
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
 
masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))
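As a quick sanity check (a sketch, not in the original notes), masked positions should receive (near-)zero probability while each row still sums to one:

out = masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
assert torch.allclose(out.sum(-1), torch.ones(2, 2))
assert torch.all(out[0, :, 2:] < 1e-6)  # first example: only the first 2 keys are valid
assert torch.all(out[1, :, 3:] < 1e-6)  # second example: only the first 3 keys are valid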

Batch Matrix Multiply

When we have mini-batches of queries and keys, we can compute the dot products of all queries and keys in the mini-batch simultaneously using batch matrix multiplication (torch.bmm):

Q = torch.ones((2, 3, 4))
K = torch.ones((2, 4, 5))
assert torch.bmm(Q, K).shape == (2, 3, 5)

Scaled Dot Product Attention

We use the scoring function below to compute attention. For a mini-batch with queries $\mathbf{Q} \in \mathbb{R}^{n \times d}$, keys $\mathbf{K} \in \mathbb{R}^{m \times d}$, and values $\mathbf{V} \in \mathbb{R}^{m \times v}$, scaled dot product attention is

$$\mathrm{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}}\right)\mathbf{V} \in \mathbb{R}^{n \times v}.$$

Dividing by $\sqrt{d}$ keeps the variance of the scores independent of the vector length; in practice this simply performs better.

class DotProductAttention(nn.Module):
    """Scaled dot product attention."""
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.attention_weights = None
    
    def forward(self, queries, keys, values, valid_lens=None):
        """
        :param queries: (batch_size, num_queries, d) 
        :param keys:  (batch_size, num_key_value_pairs, d)
        :param values: (batch_size, num_key_value_pairs, v)
        :param valid_lens: (batch_size, ) or (batch_size, num_queries)
        :return: attention output of shape (batch_size, num_queries, v)
        """
        d = queries.shape[-1]
        # Swap the last two dimensions of keys with keys.transpose(1, 2)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

Assume that we have a mini-batch size of 2, a total of 10 keys and values, and that the dimension of the values is 4.

queries = torch.normal(0, 1, (2, 1, 2))
keys = torch.normal(0, 1, (2, 10, 2))
values = torch.normal(0, 1, (2, 10, 4))
valid_lens = torch.tensor([4, 6])
attention = DotProductAttention(dropout=0.5)
attention.eval()
assert attention(queries, keys, values, valid_lens).shape == (2, 1, 4)

Check whether the attention weights actually vanish for anything beyond the fourth and the sixth key, respectively:

utils.show_heatmaps(attention.attention_weights.reshape(1, 1, 2, 10), x_label="Keys", y_label="Queries")

Additive Attention

When queries and keys are vectors of different dimensions, we can either use a matrix $\mathbf{M}$ to address the mismatch via $\mathbf{q}^\top \mathbf{M} \mathbf{k}$, or we can use additive attention as the scoring function:

$$a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_v^\top \tanh(\mathbf{W}_q \mathbf{q} + \mathbf{W}_k \mathbf{k}) \in \mathbb{R},$$

where $\mathbf{W}_q \in \mathbb{R}^{h \times q}$, $\mathbf{W}_k \in \mathbb{R}^{h \times k}$, and $\mathbf{w}_v \in \mathbb{R}^{h}$ are learnable parameters.

class AdditiveAttention(nn.Module):
    """Additive Attention"""
    def __init__(self, num_hiddens, dropout, **kwargs):
        super().__init__()
        self.W_k = nn.LazyLinear(num_hiddens, bias=False)
        self.W_q = nn.LazyLinear(num_hiddens, bias=False)
        self.w_v = nn.LazyLinear(1, bias=False)
        self.dropout = nn.Dropout(dropout)
        
        self.attention_weights = None
        
    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        
        return torch.bmm(self.dropout(self.attention_weights), values)

We pick queries, keys, and values of shape (2, 1, 20), (2, 10, 2), and (2, 10, 4) respectively.

queries = torch.normal(0, 1, (2, 1, 20))
keys = torch.normal(0, 1, (2, 10, 2))
values = torch.normal(0, 1, (2, 10, 4))
valid_lens = torch.tensor([2, 6])
 
attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
attention.eval()
assert attention(queries, keys, values, valid_lens).shape == (2, 1, 4)
 
utils.show_heatmaps(attention.attention_weights.reshape(1, 1, 2, 10), x_label="Keys", y_label="Queries")

Bahdanau Attention Mechanism

Motivation: the original sequence-to-sequence translation model encodes the entire input sequence into a single fixed-length vector, which the decoder then turns into the output sequence. The problem is that for long input sequences the encoder is likely to lose information. To address this, Bahdanau et al. proposed the attention mechanism.

The key idea is not to rely on one final summary of the sequence, but to dynamically recompute a context at every decoding step.

For example, at decoding time step $t'$ the context variable is

$$\mathbf{c}_{t'} = \sum_{t=1}^{T} \alpha(\mathbf{s}_{t'-1}, \mathbf{h}_t)\, \mathbf{h}_t,$$

where the already generated decoder state $\mathbf{s}_{t'-1}$ acts as the query, and the encoder hidden states $\mathbf{h}_t$ act as both keys and values; $\mathbf{c}_{t'}$ is then used to generate the next token.

import torch
from torch import nn

import utils
 
def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis."""
 
    # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor
    def _sequence_mask(X: torch.Tensor, valid_lens, value=0):
        maxlen = X.size(1)
        mask = torch.arange(maxlen, dtype=torch.float32, device=X.device)[None, :] < valid_lens[:, None]
        X[~mask] = value
        return X
 
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)
 
 
class AdditiveAttention(nn.Module):
    """Additive Attention"""
 
    def __init__(self, num_hiddens, dropout, **kwargs):
        super().__init__()
        self.W_k = nn.LazyLinear(num_hiddens, bias=False)
        self.W_q = nn.LazyLinear(num_hiddens, bias=False)
        self.w_v = nn.LazyLinear(1, bias=False)
        self.dropout = nn.Dropout(dropout)
 
        self.attention_weights = None
 
    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
 
        return torch.bmm(self.dropout(self.attention_weights), values)
 
class Encoder(nn.Module):
    """The base encoder interface for the encoder--decoder architecture."""
 
    def __init__(self):
        super().__init__()
 
    # Later there can be additional arguments (e.g., length excluding padding)
    def forward(self, X, *args):
        raise NotImplementedError
 
 
class Seq2SeqEncoder(Encoder):  #@save
    """The RNN encoder for sequence-to-sequence learning."""
 
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)
        self.apply(utils.init_seq2seq)
 
    def forward(self, X, *args):
        # X shape: (batch_size, num_steps)
        embs = self.embedding(X.t().type(torch.int64))
        # embs shape: (num_steps, batch_size, embed_size)
        outputs, state = self.rnn(embs)
        # outputs shape: (num_steps, batch_size, num_hiddens)
        # state shape: (num_layers, batch_size, num_hiddens)
        return outputs, state
 
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
 
    def init_state(self, enc_all_outputs, enc_valid_lens):
        raise NotImplementedError
 
    def forward(self, X, state):
        raise NotImplementedError
 
 
class AttentionDecoder(Decoder):
    def __init__(self):
        super().__init__()
 
    @property
    def attention_weights(self):
        raise NotImplementedError
 
 
class Seq2seqAttentionDecoder(AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0):
        super().__init__()
 
        self.attention = AdditiveAttention(num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.LazyLinear(vocab_size)
        self.apply(utils.init_seq2seq)
 
        self._attention_weights = []
 
    def init_state(self, enc_all_outputs, enc_valid_lens):
        outputs, hidden_state = enc_all_outputs
        return outputs.permute(1, 0, 2), hidden_state, enc_valid_lens
 
    def forward(self, X, state):
        enc_outputs, hidden_state, enc_valid_lens = state
        X = self.embedding(X).permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        for x in X:
            query = torch.unsqueeze(hidden_state[-1], dim=1)
            context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
 
        outputs = self.dense(torch.cat(outputs, dim=0))
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]
 
    @property
    def attention_weights(self):
        return self._attention_weights
 
# Test
vocab_size, embed_size, num_hiddens, num_layers = 10, 8, 16, 2
batch_size, num_steps = 4, 7
encoder = Seq2SeqEncoder(vocab_size, embed_size, num_hiddens, num_layers)
decoder = Seq2seqAttentionDecoder(vocab_size, embed_size, num_hiddens, num_layers)
 
X = torch.zeros((batch_size, num_steps), dtype=torch.long)
 
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
assert output.shape == (batch_size, num_steps, vocab_size)
assert state[0].shape == (batch_size, num_steps, num_hiddens)
assert state[1][0].shape == (batch_size, num_hiddens)

Multi-Head Attention

Sometimes we may want our model to combine knowledge from different behaviors of the same attention mechanism, such as capturing dependencies at various ranges (e.g., shorter-range vs. longer-range).

Model

Given a query $\mathbf{q} \in \mathbb{R}^{d_q}$, a key $\mathbf{k} \in \mathbb{R}^{d_k}$, and a value $\mathbf{v} \in \mathbb{R}^{d_v}$, each attention head $\mathbf{h}_i$ ($i = 1, \ldots, h$) is computed as

$$\mathbf{h}_i = f\big(\mathbf{W}_i^{(q)} \mathbf{q}, \mathbf{W}_i^{(k)} \mathbf{k}, \mathbf{W}_i^{(v)} \mathbf{v}\big) \in \mathbb{R}^{p_v},$$

where $\mathbf{W}_i^{(q)} \in \mathbb{R}^{p_q \times d_q}$, $\mathbf{W}_i^{(k)} \in \mathbb{R}^{p_k \times d_k}$, and $\mathbf{W}_i^{(v)} \in \mathbb{R}^{p_v \times d_v}$ are learnable parameters and $f$ is the attention pooling, such as additive attention or scaled dot product attention.

The output of multi-head attention is another linear transformation of the concatenated heads via learnable $\mathbf{W}_o \in \mathbb{R}^{p_o \times h p_v}$:

$$\mathbf{W}_o \begin{bmatrix} \mathbf{h}_1 \\ \vdots \\ \mathbf{h}_h \end{bmatrix} \in \mathbb{R}^{p_o}.$$

import math

import torch
from torch import nn

import utils
 
def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis."""
 
    # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor
    def _sequence_mask(X: torch.Tensor, valid_lens, value=0):
        maxlen = X.size(1)
        mask = torch.arange(maxlen, dtype=torch.float32, device=X.device)[None, :] < valid_lens[:, None]
        X[~mask] = value
        return X
 
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)
 
 
class DotProductAttention(nn.Module):
    """Scaled dot product attention."""
 
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.attention_weights = None
 
    def forward(self, queries, keys, values, valid_lens=None):
        """
        :param queries: (batch_size, num_queries, d) 
        :param keys:  (batch_size, num_key_value_pairs, d)
        :param values: (batch_size, num_key_value_pairs, v)
        :param valid_lens: (batch_size, ) or (batch_size, num_queries)
        :return: attention output of shape (batch_size, num_queries, v)
        """
        d = queries.shape[-1]
        # Swap the last two dimensions of keys with keys.transpose(1, 2)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)


class MultiHeadAttention(utils.Module):
    def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_k = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_v = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_o = nn.LazyLinear(num_hiddens, bias=bias)
 
    def transpose_qkv(self, X):
        """Transposition for parallel computation of multiple attention heads."""
        # Shape of input X: (batch_size, no. of queries or key-value pairs,
        # num_hiddens). Shape of output X: (batch_size, no. of queries or
        # key-value pairs, num_heads, num_hiddens / num_heads)
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
        # Shape of output X: (batch_size, num_heads, no. of queries or key-value
        # pairs, num_hiddens / num_heads)
        X = X.permute(0, 2, 1, 3)
        # Shape of output: (batch_size * num_heads, no. of queries or key-value
        # pairs, num_hiddens / num_heads)
        return X.reshape(-1, X.shape[2], X.shape[3])
 
    def transpose_output(self, X):
        """Reverse the operation of transpose_qkv."""
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
        X = X.permute(0, 2, 1, 3)
        return X.reshape(X.shape[0], X.shape[1], -1)
 
    def forward(self, queries, keys, values, valid_lens):
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))
 
        if valid_lens is not None:
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
 
        output = self.attention(queries, keys, values, valid_lens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)
 
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
assert attention(X, Y, Y, valid_lens).shape == (batch_size, num_queries, num_hiddens)

Self Attention

In self-attention, the queries, keys, and values all come from the same sequence: given input tokens $\mathbf{x}_1, \ldots, \mathbf{x}_n$ with $\mathbf{x}_i \in \mathbb{R}^d$, self-attention outputs $\mathbf{y}_i = f\big(\mathbf{x}_i, (\mathbf{x}_1, \mathbf{x}_1), \ldots, (\mathbf{x}_n, \mathbf{x}_n)\big) \in \mathbb{R}^d$ for the attention pooling $f$ defined above.

import math

import torch
from torch import nn

import utils
 
def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis."""
 
    # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor
    def _sequence_mask(X: torch.Tensor, valid_lens, value=0):
        maxlen = X.size(1)
        mask = torch.arange(maxlen, dtype=torch.float32, device=X.device)[None, :] < valid_lens[:, None]
        X[~mask] = value
        return X
 
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)
 
 
class DotProductAttention(nn.Module):
    """Scaled dot product attention."""
 
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.attention_weights = None
 
    def forward(self, queries, keys, values, valid_lens=None):
        """
        :param queries: (batch_size, num_queries, d) 
        :param keys:  (batch_size, num_key_value_pairs, d)
        :param values: (batch_size, num_key_value_pairs, v)
        :param valid_lens: (batch_size, ) or (batch_size, num_queries)
        :return: attention output of shape (batch_size, num_queries, v)
        """
        d = queries.shape[-1]
        # Swap the last two dimensions of keys with keys.transpose(1, 2)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
 
 
class MultiHeadAttention(utils.Module):
    def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_k = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_v = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_o = nn.LazyLinear(num_hiddens, bias=bias)
 
    def transpose_qkv(self, X):
        """Transposition for parallel computation of multiple attention heads."""
        # Shape of input X: (batch_size, no. of queries or key-value pairs,
        # num_hiddens). Shape of output X: (batch_size, no. of queries or
        # key-value pairs, num_heads, num_hiddens / num_heads)
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
        # Shape of output X: (batch_size, num_heads, no. of queries or key-value
        # pairs, num_hiddens / num_heads)
        X = X.permute(0, 2, 1, 3)
        # Shape of output: (batch_size * num_heads, no. of queries or key-value
        # pairs, num_hiddens / num_heads)
        return X.reshape(-1, X.shape[2], X.shape[3])
 
    def transpose_output(self, X):
        """Reverse the operation of transpose_qkv."""
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
        X = X.permute(0, 2, 1, 3)
        return X.reshape(X.shape[0], X.shape[1], -1)
 
    def forward(self, queries, keys, values, valid_lens):
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))
 
        if valid_lens is not None:
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
 
        output = self.attention(queries, keys, values, valid_lens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)
 
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones(batch_size, num_queries, num_hiddens)
assert attention(X, X, X, valid_lens).shape == (batch_size, num_queries, num_hiddens)
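For comparison, PyTorch's built-in nn.MultiheadAttention performs the same kind of self-attention (a sketch; it uses its own internal projections, so the numbers will differ from the MultiHeadAttention class above):

import torch
from torch import nn

embed_dim, num_heads = 100, 5
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
X = torch.ones(2, 4, embed_dim)     # (batch_size, seq_len, embed_dim)
out, weights = mha(X, X, X)         # queries = keys = values = X
assert out.shape == (2, 4, embed_dim)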

Positional Encoding

Self-attention processes all positions in parallel and therefore does not preserve the order of the sequence, so we inject positional information into the input representations. The sinusoidal positional encoding builds a matrix $\mathbf{P} \in \mathbb{R}^{n \times d}$ with

$$p_{i,2j} = \sin\!\left(\frac{i}{10000^{2j/d}}\right), \qquad p_{i,2j+1} = \cos\!\left(\frac{i}{10000^{2j/d}}\right).$$

Reasons for this design:

  • each position receives a distinct encoding
  • the values are smooth and bounded in $[-1, 1]$
  • the frequency decreases monotonically along the encoding dimension

class PositionalEncoding(nn.Module):
    """positional encoding"""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
        self.P = torch.zeros(1, max_len, num_hiddens)
        X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)
        
    def forward(self, X: torch.Tensor):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
 
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
X = pos_encoding(torch.zeros(1, num_steps, encoding_dim))
P = pos_encoding.P[:, :X.shape[1], :]
utils.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel="Row (position)", figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])

Relative Positional Encoding

for fixed offset

denoting , then

where the projection matrix doesn’t depend on any position index
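A quick numerical check of this property (a sketch, reusing the PositionalEncoding class defined above):

import math
import torch

d = 32
pe = PositionalEncoding(d, dropout=0)
P = pe.P[0]                      # (max_len, d)
i, delta, j = 5, 3, 4            # arbitrary position, offset, and frequency index
omega = 1.0 / 10000 ** (2 * j / d)
rot = torch.tensor([[math.cos(delta * omega), math.sin(delta * omega)],
                    [-math.sin(delta * omega), math.cos(delta * omega)]])
assert torch.allclose(rot @ P[i, 2 * j:2 * j + 2], P[i + delta, 2 * j:2 * j + 2], atol=1e-5)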