之前详细讲解过transformer中的attention机制,这里在使用MultiHeadAttention
函数时发现pytorch已经实现了这个库,但是文档对此函数并没有很好的解释,这里从底层逻辑剖析一下此函数的使用
1. 自定义MultiHeadAttention函数
在介绍pytorch官方库之前,我们先手写一个相同的函数大概了解一下其运行逻辑
这里我们以NLP的输入为例,输入为[seq_length, batch_size, d_model]
,即第一维是词的个数,第二维是batch,第三维度是每个词的维度。例如我们有一个句子I love you
,那么其输入为[3, 1, 6]
,即三个词,每个词的维度为6。对于CV也是相同的道理,第一维度是图像patch的数量,第二维度是batch,第三维度是每个patch的维度。
1 | import torch |
上述代码中,MultiHeadAttention
一共有四个参数,三个Linear
获取query, key, value
,最后一个Linear
将attention后的结果映射到与输入相同的维度。这里我们的输入以及qkv的维度都相同,所以Linear
只是增加了增加了网络的复杂度
2. nn.MultiHeadAttention函数
在了解MultiHeadAttention
函数的底层逻辑之后,我们再来看看pytorch官方库是如何实现的
首先给出官方文档的解释
1 | import torch |
首先nn.MultiHeadAttention
中有四个参数,in_proj_weight
,in_proj_bias
,out_proj.weight
,out_proj.bias
,其中前两个参数是用来获取query, key, value
的,后两个参数是在attention之后再通过线性层将output映射到与输入相同维度的,这几点和第一部分完全相同,可以对照理解,然后我们讲解上述代码做了什么事情
- 修改了
in_proj_weight
和in_proj_bias
参数,我们将query, key, value
的转换矩阵分别全部初始化为1,2,3,同时将bias全部初始化为0 - 修改了
out_proj.weight
和out_proj.bias
参数,我们将output
的转换矩阵全部初始化为1,bias全部初始化为0,表示在attention之后的结果如果和输入维度不一样,会将其映射到和输入维度相同的维度
其中第一步参考下图理解

第二步参考下图理解,在第一步结束之后,把不同的head concat起来,然后通过线性层映射到和输入维度相同的维度

从上面对比可以看出nn.MultiHeadAttention
并没有进行LayerNorm
操作,只是进行多头注意力机制