A Detailed Explanation of the MultiHeadAttention Function

The attention mechanism in the Transformer was explained in detail earlier. When it came to using multi-head attention, I found that PyTorch already ships an implementation of it, but the documentation does not explain the function very well, so this post dissects how to use it from its underlying logic.

1. A custom MultiHeadAttention module

Before introducing the official PyTorch module, let's hand-write an equivalent module to get a rough idea of how it works.

Here we take an NLP input as the example. The input has shape [seq_length, batch_size, d_model]: the first dimension is the number of tokens, the second is the batch, and the third is the embedding dimension of each token. For example, for the sentence "I love you" the input is [3, 1, 6]: three tokens, each of dimension 6. The same holds for CV: the first dimension is the number of image patches, the second is the batch, and the third is the dimension of each patch.
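As a minimal illustration of this layout (random numbers stand in for real word embeddings):

import torch

seq_length, batch_size, d_model = 3, 1, 6         # "I love you": 3 tokens, 1 sentence, 6-dim embeddings
x = torch.randn(seq_length, batch_size, d_model)
print(x.size())                                   # torch.Size([3, 1, 6])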

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

class MultiHeadAttention(nn.Module):
    """
    Multi-head attention module

    Args:
        h_head: number of heads
        d_model: dimension of the input
        dropout: dropout rate
    """
    def __init__(self, h_head, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % h_head == 0
        self.h_head = h_head
        self.d_k = d_model // h_head

        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)

        self.fc = nn.Linear(d_model, d_model)

        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # Same mask applied to all h_head heads.
            mask = mask.unsqueeze(1)
        batch_size = query.size(1)

        print('Before transform query: ' + str(query.size()))  # (seq_length, batch_size, d_model)

        # Project, then split into heads. Note: this reshape is only valid here
        # because batch_size (dim 1) is 1 with the seq-first input used below.
        query = self.wq(query).view(batch_size, -1, self.h_head, self.d_k).transpose(1, 2)
        key = self.wk(key).view(batch_size, -1, self.h_head, self.d_k).transpose(1, 2)
        value = self.wv(value).view(batch_size, -1, self.h_head, self.d_k).transpose(1, 2)

        print('After transform query: ' + str(query.size()))  # (batch_size, h_head, seq_length, d_k)
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h_head * self.d_k)
        x = self.fc(x)
        return x, self.attn

n_head = 2
d_model = 6
batch_size = 1
seq_length = 3
model = MultiHeadAttention(n_head, d_model)

query = torch.randn([seq_length, batch_size, d_model])  # [3, 1, 6]
key = query
value = query
print('Input size: ' + str(query.size()))
output, attn = model(query, key, value)
print('Output size: ' + str(output.size()))

"""
Input size: torch.Size([3, 1, 6])
Before transform query: torch.Size([3, 1, 6])
After transform query: torch.Size([1, 2, 3, 3])
Output size: torch.Size([1, 3, 6])
"""

In the code above, MultiHeadAttention holds four linear layers: three Linear layers produce the query, key, and value, and the last Linear maps the attention result back to the same dimension as the input. Since the input and q/k/v all share the same dimension here, these Linear layers essentially just add complexity (capacity) to the network.
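To see these four linear layers concretely, we can print their weight shapes (reusing the MultiHeadAttention class defined above). Note that the three q/k/v projections together give 3 x [6, 6] weights, which is exactly the [18, 6] in_proj_weight we will meet in the official module below:

model = MultiHeadAttention(h_head=2, d_model=6)
print(model.wq.weight.size())   # torch.Size([6, 6])
print(model.wk.weight.size())   # torch.Size([6, 6])
print(model.wv.weight.size())   # torch.Size([6, 6])
print(model.fc.weight.size())   # torch.Size([6, 6])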

2. The nn.MultiheadAttention module

Having seen the underlying logic of multi-head attention, let's now look at how the official PyTorch module implements it.

The official documentation describes the interface only briefly, so the example below inspects the module's parameters and behavior directly.

import torch
import torch.nn as nn
import numpy as np


n_head = 2
d_model = 6
batch_size = 1
seq_length = 3

attention = nn.MultiheadAttention(d_model, n_head)
print(attention.in_proj_weight.size())
print(attention.in_proj_bias.size())
print(attention.out_proj.weight.size())
print(attention.out_proj.bias.size())

'''
torch.Size([18, 6])
torch.Size([18])
torch.Size([6, 6])
torch.Size([6])
'''

# Modify the in_proj parameters of the MultiheadAttention module
wq = torch.Tensor(np.ones((6, 6)))
wk = torch.Tensor(np.ones((6, 6))) * 2
wv = torch.Tensor(np.ones((6, 6))) * 3
weight = torch.nn.Parameter(torch.cat([wq, wk, wv], dim=0))
attention.in_proj_weight.data = weight
attention.in_proj_bias.data = torch.nn.Parameter(torch.Tensor(np.zeros((18,))))

# Modify the out_proj parameters of the MultiheadAttention module
fc_weight = torch.nn.Parameter(torch.Tensor(np.ones((6, 6))))
fc_bias = torch.nn.Parameter(torch.Tensor(np.zeros((6,))))
attention.out_proj.weight = fc_weight
attention.out_proj.bias = fc_bias


# Define the input
x = torch.ones([seq_length, batch_size, d_model])

output, attn = attention(x, x, x, average_attn_weights=False)  # defaults to True
print(output)
print(output.size())

print(attn)
print(attn.size())

'''
# output
tensor([[[108., 108., 108., 108., 108., 108.]],

[[108., 108., 108., 108., 108., 108.]],

[[108., 108., 108., 108., 108., 108.]]], grad_fn=<ViewBackward0>)
torch.Size([3, 1, 6])

# attention
tensor([[[[0.3333, 0.3333, 0.3333],
[0.3333, 0.3333, 0.3333],
[0.3333, 0.3333, 0.3333]],

[[0.3333, 0.3333, 0.3333],
[0.3333, 0.3333, 0.3333],
[0.3333, 0.3333, 0.3333]]]], grad_fn=<ViewBackward0>)
torch.Size([1, 2, 3, 3])
'''

nn.MultiheadAttention stores four parameter tensors: in_proj_weight, in_proj_bias, out_proj.weight and out_proj.bias. The first two are used to produce the query, key, and value, while the last two form the linear layer that maps the attention output back to the same dimension as the input. This mirrors the hand-written module from Part 1 exactly, so the two can be read side by side. Now let's walk through what the code above does:

  1. Modify the in_proj_weight and in_proj_bias parameters: the projection matrices for query, key, and value are filled entirely with 1, 2, and 3 respectively, and the bias is set to all zeros.
  2. Modify the out_proj.weight and out_proj.bias parameters: the output projection matrix is filled with 1 and its bias with 0. This is the linear layer that maps the result after attention back to the same dimension as the input.

The first step (computing q, k, and v through in_proj) can be understood as follows.
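A small hand computation makes it concrete, reusing the `attention` module configured above (a sketch; in_proj_weight stacks W_q, W_k and W_v along dim 0, i.e. rows 0-5, 6-11 and 12-17 here):

import torch
import torch.nn.functional as F

wq, wk, wv = attention.in_proj_weight.detach().chunk(3, dim=0)
print(wq[0, 0].item(), wk[0, 0].item(), wv[0, 0].item())   # 1.0 2.0 3.0

x2d = torch.ones(3, 6)          # the 3 tokens of the all-ones input
q = x2d @ wq.T                  # every entry is 1 * 6 = 6
k = x2d @ wk.T                  # every entry is 2 * 6 = 12
v = x2d @ wv.T                  # every entry is 3 * 6 = 18

# With d_k = 3 per head, all scores are equal, so softmax gives uniform weights of 1/3,
# which is exactly the 0.3333 attention matrix printed above.
scores = (q[:, :3] @ k[:, :3].T) / (3 ** 0.5)
print(F.softmax(scores, dim=-1))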

In the second step, after step 1 the outputs of the different heads are concatenated, and a linear layer maps the result back to the same dimension as the input.
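Again by hand: each head's output is just v (all 18s), the heads are concatenated back to d_model = 6, and the all-ones out_proj weight with zero bias sums the six values, giving the 108 seen in the output above:

import torch

head_out = 18 * torch.ones(3, 6)            # concatenated multi-head output (v passed through unchanged)
fc_w = attention.out_proj.weight.detach()   # all ones; the bias is zero
output = head_out @ fc_w.T
print(output[0])                            # tensor([108., 108., 108., 108., 108., 108.])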

Comparing the two implementations, we can see that nn.MultiheadAttention does not perform any LayerNorm; it only implements the multi-head attention itself.
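So if the residual connection and LayerNorm of a Transformer block are needed, they have to be added around the module yourself. A minimal sketch of such a wrapper (post-norm, as in the original Transformer; the class name SelfAttentionBlock is just for illustration):

import torch
import torch.nn as nn

class SelfAttentionBlock(nn.Module):
    def __init__(self, d_model, n_head, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention, then residual connection and LayerNorm.
        out, _ = self.attn(x, x, x, attn_mask=mask)
        return self.norm(x + self.dropout(out))

block = SelfAttentionBlock(d_model=6, n_head=2)
y = block(torch.randn(3, 1, 6))   # seq-first input, as in the examples above
print(y.size())                   # torch.Size([3, 1, 6])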
