@[TOC](Swin Transformer详解)
论文地址:https://arxiv.org/pdf/2103.14030.pdf
代码地址:https://github.com/microsoft/Swin-Transformer
本文一共分为三个部分,首先介绍Swin Transformer的整体架构,随后会介绍每个模块的作用,中间会穿插部分代码。本文的主要目的还是希望能够将Swin Transformer解释清楚,然后结合官方代码来理解
一、Overall Architecture
首先给出论文中的Swin Transformer架构图

左边是Swin Transformer的全局架构,它包含Patch Partition、Linear Embedding、Swin Transformer Block、Patch Merging四大部分,这四大部分我们之后会进行详细的介绍
右边是Swin Transformer Block结构图,这是两个连续的Swin Transformer Block块,一个是W-MSA,一个是SW-MSA,也就是说根据Swin的Tiny版本,图中的Swin Transformer Block块为[2, 2, 6, 2],相对应的attention为:stage1 W-MSA-->SW-MSA
– stage2 W-MSA-->SW-MSA
– stage3 W-MSA-->SW-MSA-->W-MSA-->SW-MSA-->W-MSA-->SW-MSA
– stage4 W-MSA-->SW-MSA
二、Swin Transformer
下面的维度等均是基于Swin-T
版本
1. Patch Partition & Linear Embedding
输入为(B, 3, 224, 224)
输出为(B, 96, 56, 56) —> (B, 96, 224/4=56, 224/4=56)
这两步在论文中其实就是一步实现,我们先来看paper中的解释:
- Patch Partition,这一步是将输入的(H, W, 3)的图片分成(4, 4)的小块,分块后的图片大小为(H/4, W/4, 48)也就是文中所给的维度
- Linear Embedding,在Tiny版本中,将分块后的图像映射到96维
在真正实现的时候paper使用了PatchEmbed函数将这两步结合起来,实际上也就是用了一个卷积的操作,卷积核大小为(4, 4),步长为4:nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
下面图示此过程

2. Basic Layer
在官方的代码库中,将Swin Transformer Block和Patch Merging合并成了一个,叫做Basic Layer,下面我们分别介绍这两者
Swin Transformer Block
输入为(B, 3136, 96)
输出为(B, 3136, 96)
就是把上一步的(4, 96, 56, 56)后两维度合并变为(4, 96, 3136),然后后两维互换(4, 3136, 96)
Swin Transformer Block的输入输出不变,每两个连续Block为一组,即一个Window Multi-head Self-Attention和一个Shifted Window Multi-head Self-Attention
下面是paper中的Swin Transformer Block
示例图

从图中我们可以看出每两个连续Block块有四小步:
1. 第一个Block
- 首先特征图经过
Layer Norm
层,经过W-MSA
,然后进行跳跃连接 - 连接后的特征图再次经过
Layer Norm
层,经过全连接层MLP
,然后进行跳跃连接
2. 第二个Block
- 首先特征图经过
Layer Norm
层,经过SW-MSA
,然后进行跳跃连接 - 连接后的特征图再次经过
Layer Norm
层,经过全连接层MLP
,然后进行跳跃连接
从上面四步可以看出Swin Transformer Block
清晰的执行步骤,其中比较难理解的是W-MSA
和SW-MSA
,下面我们详细介绍二者,并介绍由二者引出的一些细节
(1)first block
包含两个主要模块,W-MSA和MLP
输入为(B, 3136, 96)
输出为(B, 3136, 96)
W-MSA
window partition
W-MSA在第一个block中,这一步没有滑动窗,输入为(B, 3136, 96),为了后面的sefl-attention操作,需要将特征图划分为一个个窗口的形式,首先经历了一个window partition操作,变为(64B, 7, 7, 96)
怎么计算的呢?输入为batch=B,3136=56*56,特征图有96个,将每个特征图56*56分为7*7的窗口,一共能分8*8=64个,乘上之前的B就是64B了,就是说将特征图分为(7, 7)的小窗,然后把所有的小窗拿出来一共有64B个,示例图如下

==为什么要进行window partition?在Vision Transformer中,我们将图片分成了一个个patch(也就是左边的图),在进行MSA时,任何一个patch都要与其他所有的patch都进行attention,当patch的大小固定时,计算量与图片的大小成平方增长。Swin Transformer中采用了W-MSA,也就是window的形式,不同的window包含了相同数量的patch,只对window内部进行MSA,当图片大小增大时,计算量仅仅是呈线性增加(只增加了图片多余部分的计算量,比如之前是224的图像,现在是256的图像,只多了256-224=32像素的计算部分),下面详细介绍window attention部分==
window attention
将窗口分配完成后就可以执行attention操作了,首先我们将维度变为(64B, 49, 96),进行attention操作时,我们需要qkv三个变量,transformer是通过linear函数来实现的:nn.Linear(dim, dim * 3, bias=qkv_bias)
,通过这个函数后,维度变为(64B, 49, 288),qkv分别占三分之一,也就是说qkv分别为(64B, 49, 96),第一个阶段的head为3,维度划分为(64B, 3, 49, 32)
此时qkv的值如下所示,这就是进行attention时qkv的维度
- q: (64B, 3, 49, 32)
- k: (64B, 3, 49, 32)
- v: (64B, 3, 49, 32)
接下来就是进行attention操作,熟悉transformer的同学肯定很容易理解
注意这里加了一个偏置B,在最后会详细介绍相对位置偏置(Relative Position Bias)的原理
window reverse
所有attention步骤执行完之后就可以回到attention之前的维度(64B, 7, 7, 96),然后我们经过一个window reverse操作就可以回到window partition之前的状态了,即(B, 56, 56, 96)。window reverse就是window partition的逆过程
总结:这里总结一下W-MSA所做的事情,首先进行window partition操作,维度从(B, 3136, 96)也就是(B, 56, 56, 96)变为(64B, 7, 7, 96);随后进行attention操作,先经过一个线性层维度变为三倍来为qkv分别赋值(64B, 49, 96*3): qkv(64B, 49, 96),随后根据multi-head操作在将qkv分别分成三份,(64B, 3, 49, 32),最后进行attention操作(即上面的公式),然后通过window reverse回到最初的状态(B, 56, 56, 96),也就是(B, 3136, 96),下面图示了这一阶段的过程

MLP
输入为(4, 3136, 96)
输出为(4, 3136, 96)
再经过第二个Block之前要先经过一个MLP,其中结构为
Linear(96, 96*4)
——GELU()
——Linear(96*4, 96)
——Dropout
最终维度并不发生变化
(2)second block
包含两个主要模块,SW-MSA和MLP
输入为(4, 3136, 96)
输出为(4, 3136, 96)
与第一个Block唯一不同的地方就是SW-MSA模块,所以这里仅讲解此模块
SW-MSA
与W-MSA不同的地方在于这个模块存在滑动,所以叫做shifted window,滑动的距离为win_size//2
在这里也就是7//2=3
,这里用image(4, 4)
win(2, 2)
shift=1
来图示他的shift以及mask机制
这里先给出Github上有助于理解此机制的提问:链接

为什么要用mask机制呢,Swin Transformer与Vision Transformer相比虽然降低了计算量,但缺点是同一个window里面的patch可以交互,window与window之间无法交互,所以考虑滑动窗的方法,如上图所示,滑动过后为了保证图片的完整性,我们将上面和左边的图补齐到右边,这又带来了一个缺点:图片的右端和补齐的图片本身并不是相邻的,所以无法交互,解决办法就是mask
Swin Transformer的mask机制是说,如果相互交互的patch属于同一个区域(对应于上图的颜色),那么就可以正常交互,如果不是同一个区域(对应于上图的不同颜色),那么他们交互之后就需要加上一个很大的负值,这样通过softmax层之后本来不能交互的那个像素就变成0了,这就是mask机制
这里附上Github上讨论的一个源码,由此可以直接看到mask是如何运行的,这个代码与我上述的图是对应的
1 | import torch |
Patch Merging
在每个Stage结束的阶段都有一个Patch Merging
的过程,这个过程会让输入进行降维,同时通道变为原来的二倍,用一个图来清晰的展示此过程,图示如下

上面说到过Swin的作用是使得patch交互的区域变大,另一种使其变大的方法就是这里提到的Patch Merging,在每个阶段结束之后,将特征图的维度减半,channel加倍,在保持patch和window不变的情况下相当于变相提高了patch和window的感受野,使其效果更好
到这里Swin Transformer的一个stage就已经讲完了,其余的Stage和上面讲述的完全一致,为了再次强化Swin Transformer的整个流程,下面是整个流程展示,其中加粗部分为我们已经走过的流程(这里依然是Swin-Tiny版本)
input-->patch partition-->linear embedding
stage1 W-MSA-->MLP-->SW-MSA-->MLP
stage2 W-MSA-->MLP-->SW-MSA-->MLP
stage3 W-MSA-->MLP-->SW-MSA-->MLP
*3
stage4 W-MSA-->MLP-->SW-MSA-->MLP-->tail process
三、Supplement
Relative Position Bias
到这里整个Swin Transformer
就已经讲完了,还记得attention中加了一个bias B
吗,这里对其进行讲解,依旧取win=2
,如下所示

这里的相对位置偏置这样理解,在窗口中任意选定一个坐标,遵循左+右-上+下-
的原则,可以发现当我们将左上角的值为(0, 0)
时,他右边的位置为(0, -1)
减了1,下面的位置为(-1, 0)
也减了1,同理将其他位置设为(0, 0)
时,结果分别如图所示
然后我们将其展开,执行:行列分别加M-1=2-1=1
,行标乘2M-1=3
,最终可以得到下图,然后需要注意的是最大值为8,也就是说一共有9个索引,为什么有四个像素,按理来说为4*4=16
个位置,只有9个索引呢?这是因为是相对位置编码位置有重复,又因为win=2
,所以行和列的索引均为[-1, 1]
,一共3*3=9
种组合,即九个相对位置索引,因此相对位置索引表一共有9个数字,如下图所示
其中上面是索引表(9个数),下面是索引后的结果

为了更清晰的认识相对位置偏置,这里给出一个简单的example
1 | # relative_position_bias_table (1, 9) |
到这里Swin Transformer
就讲完啦,但是因为写的比较仓促有一些地方讲的不够细致,还有关于FLOPs运算的细节没有讲到,后面有时间会再补充~