pytorch的几种conv方法

@TOC

Conv

首先是常规卷积,假设我们有一张的特征图,现在想得到一张的图,如果直接使用卷积操作,大卷积核(包含channel,3维)一共有10个,每个大小为。代码及计算过程如下图所示

1
2
3
4
5
6
7
conv = nn.Conv2d(6, 10, kernel_size=2, stride=1, padding=0, bias=False, groups=1)
input = torch.ones((1, 6, 4, 4))
output = conv(input)
print(output.size())
'''
torch.Size([1, 10, 3, 3])
'''

Normal Convolution

参数和计算量

参数:10个输出channel,6个输入channel,卷积核大小为2 * 2
计算量:10个输出channel,6个输入channel,卷积核大小为2 * 2,输出图为3 * 3

参数:
计算量:

Group Conv

分组卷积可降低参数量,我们假设一张的特征图,现在想得到一张的图,设分组卷积数为2,因此每个大卷积核的大小为,一共有个大卷积核。代码及计算过程如下图所示

1
2
3
4
5
6
7
group_conv = nn.Conv2d(6, 10, kernel_size=2, stride=1, padding=0, bias=False, groups=2)
input = torch.ones((1, 6, 4, 4))
output = group_conv(input)
print(output.size())
'''
torch.Size([1, 10, 3, 3])
'''

Group Convolution

参数和计算量

参数:10个输出channel(分组为2,每组5 channel),6个输入channel(分组为2,每组3 channel),大卷积核一共分为两组,每组大小为3 * 2 * 2,卷积核大小为2 * 2
计算量:10个输出channel,6个输入channel,卷积核大小为2 * 2,输出图为3 * 3

参数:
计算量:

Depthwise Separable Convolution

depthwise separable convolution是Google在2017年提出的arXiv

这个模型为MobileNet,主要是在显著降低参数和计算量的情况下保证性能,depthwise separable convolution一共分为两步:depthwise conv以及pointwise conv

1. Depthwise Conv

depthwise中每个卷积核只负责一个通道,卷积只能在二维平面内进行,因此他没有办法增加通道数

继续上面的例子,我们假设一张的特征图,因为depthwise没办法增加通道数,所以我们只能得到一张的图

Depthwise Convolution

参数和计算量

参数:6个输出channel,6个输入channel,大卷积核大小为1 * 2 * 2,卷积核大小为2 * 2
计算量:6个输出channel,6个输入channel,卷积核大小为2 * 2,输出图为3 * 3

参数:
计算量:

2. Pointwise Conv

因为我们想获得的特征图,但是目前经过depthwise我们得到了的特征图,现在我们用的核来进行pointwise操作,每个卷积核的大小为,一共有10个

Pointwise Convolution

参数和计算量

参数:10个输出channel,6个输入channel,大卷积核大小为6 * 1 * 1,卷积核大小为1 * 1
计算量:10个输出channel,6个输入channel,卷积核大小为2 * 2,输出图为3 * 3

参数:
计算量:

上述两步的代码如下所示

1
2
3
4
5
6
7
8
9
10
depthwise = nn.Conv2d(6, 6, kernel_size=2, stride=1, padding=0, bias=False, groups=6)
pointwise = nn.Conv2d(6, 10, kernel_size=1, stride=1, padding=0, bias=False, groups=1)
input = torch.ones((1, 6, 4, 4))
output = depthwise(input)
print(output.size())
output = pointwise(output)
print(output.size())
'''
torch.Size([1, 6, 3, 3])
torch.Size([1, 10, 3, 3])'''

总结:特征图

model params flops
Normal Conv 240 2160
Group Conv 120 1080
Separable Conv 24+60 216+540
Error: API rate limit exceeded for 52.45.52.34. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)