A Technical Deep Dive into the Transformer Model

Introduction

Since its introduction in 2017, the Transformer has become the foundational architecture of natural language processing (NLP). This article walks through the Transformer's core principles and key components from an implementation perspective.

1. Overall Transformer Architecture

1.1 Architecture Overview

The Transformer uses an encoder-decoder architecture, but unlike traditional sequence models it is built entirely on attention mechanisms:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Transformer(nn.Module):
    def __init__(self,
                 src_vocab_size,
                 tgt_vocab_size,
                 d_model=512,
                 nhead=8,
                 num_encoder_layers=6,
                 num_decoder_layers=6,
                 dim_feedforward=2048,
                 dropout=0.1):
        super().__init__()

        # Encoder (TransformerEncoder and TransformerDecoder stand for the
        # encoder/decoder stacks built from the components in Section 2)
        self.encoder = TransformerEncoder(
            src_vocab_size,
            d_model,
            nhead,
            num_encoder_layers,
            dim_feedforward,
            dropout
        )

        # Decoder
        self.decoder = TransformerDecoder(
            tgt_vocab_size,
            d_model,
            nhead,
            num_decoder_layers,
            dim_feedforward,
            dropout
        )

        # Output projection onto the target vocabulary
        self.output_layer = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt):
        # src: [batch_size, src_len]
        # tgt: [batch_size, tgt_len]

        # Encoder forward pass
        encoder_output = self.encoder(src)  # [batch_size, src_len, d_model]

        # Decoder forward pass, attending to the encoder output
        decoder_output = self.decoder(tgt, encoder_output)

        # Project to vocabulary logits
        output = self.output_layer(decoder_output)

        return output
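
As a quick sanity check, here is a minimal usage sketch, assuming the TransformerEncoder and TransformerDecoder modules referenced above have been defined; the vocabulary sizes and sequence lengths are arbitrary placeholders:

# Minimal usage sketch; vocabulary sizes and sequence lengths are placeholders.
model = Transformer(src_vocab_size=10000, tgt_vocab_size=10000)

src = torch.randint(0, 10000, (2, 20))   # [batch_size=2, src_len=20]
tgt = torch.randint(0, 10000, (2, 15))   # [batch_size=2, tgt_len=15]

logits = model(src, tgt)
print(logits.shape)                       # expected: torch.Size([2, 15, 10000])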

2. Core Component Implementations

2.1 Self-Attention Mechanism

Self-attention is the Transformer's core innovation: it lets the model directly capture dependencies between any two positions in a sequence.
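
In the notation of the original paper, each head computes scaled dot-product attention, and the heads are then concatenated and projected back to d_model; this is exactly what the implementation below does:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^{O}, \qquad \mathrm{head}_i = \mathrm{Attention}(QW_i^{Q},\, KW_i^{K},\, VW_i^{V})$$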

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, nhead, dropout=0.1):
        super().__init__()
        assert d_model % nhead == 0

        self.d_model = d_model
        self.nhead = nhead
        self.d_k = d_model // nhead

        # Linear projections for queries, keys, values and the output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Q, K, V: [batch_size, nhead, seq_len, d_k]

        # Attention scores, scaled by sqrt(d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Apply the mask (if provided)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # Attention weights
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of the values
        output = torch.matmul(attention_weights, V)

        return output, attention_weights

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # Linear projections
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        # Reshape into multi-head form
        Q = Q.view(batch_size, -1, self.nhead, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.nhead, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.nhead, self.d_k).transpose(1, 2)

        # Attention over all heads
        output, attention = self.scaled_dot_product_attention(Q, K, V, mask)

        # Reshape back to [batch_size, seq_len, d_model]
        output = output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model)

        # Final output projection
        output = self.W_o(output)

        return output, attention
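
A brief shape check of the module above; the dimensions are placeholders:

# Shape sanity check for MultiHeadAttention; dimensions are placeholders.
mha = MultiHeadAttention(d_model=512, nhead=8)
x = torch.randn(2, 10, 512)      # [batch_size, seq_len, d_model]
out, attn = mha(x, x, x)         # self-attention: query = key = value
print(out.shape, attn.shape)     # torch.Size([2, 10, 512]) torch.Size([2, 8, 10, 10])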

2.2 Positional Encoding

Because self-attention itself carries no notion of position, an additional positional encoding is required:
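
The sinusoidal encoding used in the implementation below assigns sines to even dimensions and cosines to odd dimensions:

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right)$$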

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length=5000):
        super().__init__()

        # Precompute the positional encoding matrix
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() *
            (-math.log(10000.0) / d_model)
        )

        # Sine on even indices, cosine on odd indices
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Add a batch dimension and register as a (non-trainable) buffer
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [batch_size, seq_len, d_model]
        return x + self.pe[:, :x.size(1)]
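
In the original paper the token embeddings are multiplied by sqrt(d_model) before the positional encoding is added. A minimal sketch of that combination; the embedding layer and vocabulary size are placeholders:

# Combine a token embedding with the positional encoding; embeddings are
# scaled by sqrt(d_model) as in the original paper. Sizes are placeholders.
d_model = 512
embedding = nn.Embedding(10000, d_model)
pos_encoding = PositionalEncoding(d_model)

tokens = torch.randint(0, 10000, (2, 20))    # [batch_size, seq_len]
x = embedding(tokens) * math.sqrt(d_model)   # scale embeddings
x = pos_encoding(x)                          # add position information
print(x.shape)                               # torch.Size([2, 20, 512])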

2.3 Feed-Forward Network

Each encoder and decoder layer also contains a position-wise feed-forward network:
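
Applied independently at each position, it is two linear transformations with a ReLU in between:

$$\mathrm{FFN}(x) = \max(0,\; xW_1 + b_1)\,W_2 + b_2$$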

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()

        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: [batch_size, seq_len, d_model]

        # First linear transform, expanding to d_ff
        x = self.w_1(x)
        x = F.relu(x)
        x = self.dropout(x)

        # Second linear transform, projecting back to d_model
        x = self.w_2(x)

        return x

3. Training and Optimization

3.1 Loss Function

The Transformer is usually trained with a cross-entropy loss, taking care to exclude padding tokens:

def compute_loss(output, target, pad_idx):
    # Mask that zeroes out padding positions
    mask = (target != pad_idx).float()

    # Per-token cross-entropy loss
    criterion = nn.CrossEntropyLoss(reduction='none')
    loss = criterion(
        output.view(-1, output.size(-1)),
        target.view(-1)
    )

    # Apply the mask
    loss = loss * mask.view(-1)

    # Average over non-padding tokens only
    return loss.sum() / mask.sum()
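
An equivalent and slightly simpler formulation relies on the ignore_index argument of nn.CrossEntropyLoss, which excludes padding positions from the mean internally:

# Equivalent formulation: padded targets contribute nothing to the loss
# and are excluded from the averaging.
def compute_loss_ignore_index(output, target, pad_idx):
    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
    return criterion(output.view(-1, output.size(-1)), target.view(-1))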

3.2 Learning Rate Scheduling

The Transformer uses a distinctive learning rate schedule rather than a fixed learning rate:
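
The rate grows linearly during the first warmup_steps steps and then decays with the inverse square root of the step number, which is exactly what the scheduler below computes:

$$lr(step) = d_{\mathrm{model}}^{-0.5} \cdot \min\!\left(step^{-0.5},\; step \cdot warmup\_steps^{-1.5}\right)$$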

class TransformerLRScheduler:
    def __init__(self, optimizer, d_model, warmup_steps=4000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.step_num = 0

    def step(self):
        self.step_num += 1
        lr = self.compute_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def compute_lr(self):
        step = self.step_num
        arg1 = step ** (-0.5)
        arg2 = step * (self.warmup_steps ** -1.5)

        return self.d_model ** (-0.5) * min(arg1, arg2)
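
A typical way to wire this up, using the Adam hyperparameters reported in the original paper (β₁ = 0.9, β₂ = 0.98, ε = 1e-9); `model` is assumed to be the Transformer from Section 1, and the initial lr is irrelevant because the scheduler overwrites it on every step:

# Pair the scheduler with Adam; the initial lr is a placeholder that the
# scheduler overwrites at each step.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = TransformerLRScheduler(optimizer, d_model=512, warmup_steps=4000)

for _ in range(5):
    scheduler.step()
    print(optimizer.param_groups[0]['lr'])   # learning rate rises during warmup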

4. Implementation Tips and Optimizations

4.1 Attention Optimizations

  1. Sparse attention

    def sparse_attention(Q, K, V, sparsity_threshold=0.1):
        # Raw attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1))

        # Keep only the top-k attention scores per query (at least one)
        top_k = max(1, int(scores.size(-1) * sparsity_threshold))
        top_scores, _ = torch.topk(scores, top_k, dim=-1)
        threshold = top_scores[..., -1:]

        # Mask out everything below the k-th largest score
        mask = scores >= threshold
        scores = scores.masked_fill(~mask, -1e9)

        attention = torch.softmax(scores, dim=-1)
        return torch.matmul(attention, V)
  2. Local attention

    def local_attention(Q, K, V, window_size=16):
        batch_size, num_heads, seq_len, d_k = Q.size()

        # Band mask: each position may only attend within +/- window_size
        local_mask = torch.ones(seq_len, seq_len, device=Q.device).triu(-window_size).tril(window_size)

        # Mask out positions outside the window before the softmax
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        scores = scores.masked_fill(local_mask == 0, -1e9)
        attention = torch.softmax(scores, dim=-1)

        return torch.matmul(attention, V)

4.2 Memory Optimizations

  1. Gradient checkpointing

    from torch.utils import checkpoint

    class CheckpointedTransformerLayer(nn.Module):
        def __init__(self, d_model, nhead):
            super().__init__()
            self.attention = MultiHeadAttention(d_model, nhead)
            self.feed_forward = PositionwiseFeedForward(d_model, d_model * 4)

        def forward(self, x):
            def custom_forward(x):
                return self.attention(x, x, x)[0]

            # Recompute attention activations during the backward pass
            # instead of storing them, trading compute for memory
            x = checkpoint.checkpoint(custom_forward, x)
            x = self.feed_forward(x)
            return x
  2. Mixed-precision training

    def train_step(model, optimizer, scheduler, batch, scaler, pad_idx):
        optimizer.zero_grad()

        # Run the forward pass and loss in reduced precision where safe
        with torch.cuda.amp.autocast():
            output = model(batch.src, batch.tgt)
            loss = compute_loss(output, batch.tgt, pad_idx)

        # Scale the loss to avoid gradient underflow in float16
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        scheduler.step()
        return loss.item()
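
The scaler passed in above is a torch.cuda.amp.GradScaler. A minimal setup sketch; train_loader and pad_idx are placeholders, and model, optimizer, and scheduler come from the earlier sections:

    # Minimal mixed-precision training loop sketch; train_loader and pad_idx
    # are placeholders for your own data pipeline.
    scaler = torch.cuda.amp.GradScaler()

    for batch in train_loader:
        loss = train_step(model, optimizer, scheduler, batch, scaler, pad_idx)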

5. Performance Evaluation and Debugging

5.1 Attention Visualization

import matplotlib.pyplot as plt
import seaborn as sns


def visualize_attention(attention_weights, tokens, save_path=None):
    """Plot attention weights as a heatmap over the input tokens."""
    plt.figure(figsize=(10, 10))
    sns.heatmap(
        attention_weights,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap='viridis'
    )
    plt.title('Attention Weights Visualization')
    if save_path:
        plt.savefig(save_path)
    plt.show()

5.2 Performance Analysis

import numpy as np


def analyze_model_performance(model, test_loader, pad_idx):
    # Assumes the model's forward returns both logits and attention weights
    metrics = {
        'loss': [],
        'accuracy': [],
        'attention_entropy': []
    }

    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            output, attention = model(batch.src, batch.tgt)

            # Loss
            loss = compute_loss(output, batch.tgt, pad_idx)
            metrics['loss'].append(loss.item())

            # Token-level accuracy
            pred = output.argmax(dim=-1)
            acc = (pred == batch.tgt).float().mean()
            metrics['accuracy'].append(acc.item())

            # Attention entropy (how spread out the attention weights are)
            entropy = -(attention * torch.log(attention + 1e-9)).sum(-1).mean()
            metrics['attention_entropy'].append(entropy.item())

    return {k: np.mean(v) for k, v in metrics.items()}

Summary

With its novel self-attention mechanism and carefully designed architecture, the Transformer has achieved breakthrough progress on sequence-processing tasks. Understanding its implementation details not only helps you use the model more effectively, but also provides a valuable reference for designing new architectures.


This article will be updated over time; feel free to share your insights and experience in the comments!