之前详细讲解过transformer中的attention机制,这里在使用MultiHeadAttention函数时发现pytorch已经实现了这个库,但是文档对此函数并没有很好的解释,这里从底层逻辑剖析一下此函数的使用
1. 自定义MultiHeadAttention函数
在介绍pytorch官方库之前,我们先手写一个相同的函数大概了解一下其运行逻辑
这里我们以NLP的输入为例,输入为[seq_length, batch_size, d_model],即第一维是词的个数,第二维是batch,第三维度是每个词的维度。例如我们有一个句子I love you,那么其输入为[3, 1, 6],即三个词,每个词的维度为6。对于CV也是相同的道理,第一维度是图像patch的数量,第二维度是batch,第三维度是每个patch的维度。
1 | import torch |
上述代码中,MultiHeadAttention一共有四个参数,三个Linear获取query, key, value,最后一个Linear将attention后的结果映射到与输入相同的维度。这里我们的输入以及qkv的维度都相同,所以Linear只是增加了增加了网络的复杂度
2. nn.MultiHeadAttention函数
在了解MultiHeadAttention函数的底层逻辑之后,我们再来看看pytorch官方库是如何实现的
首先给出官方文档的解释
1 | import torch |
首先nn.MultiHeadAttention中有四个参数,in_proj_weight,in_proj_bias,out_proj.weight,out_proj.bias,其中前两个参数是用来获取query, key, value的,后两个参数是在attention之后再通过线性层将output映射到与输入相同维度的,这几点和第一部分完全相同,可以对照理解,然后我们讲解上述代码做了什么事情
- 修改了
in_proj_weight和in_proj_bias参数,我们将query, key, value的转换矩阵分别全部初始化为1,2,3,同时将bias全部初始化为0 - 修改了
out_proj.weight和out_proj.bias参数,我们将output的转换矩阵全部初始化为1,bias全部初始化为0,表示在attention之后的结果如果和输入维度不一样,会将其映射到和输入维度相同的维度
其中第一步参考下图理解
第二步参考下图理解,在第一步结束之后,把不同的head concat起来,然后通过线性层映射到和输入维度相同的维度
从上面对比可以看出nn.MultiHeadAttention并没有进行LayerNorm操作,只是进行多头注意力机制