Non-local Neural Networks

Paper: https://arxiv.org/pdf/1711.07971.pdf
Official code: https://github.com/facebookresearch/video-nonlocal-net
PyTorch implementation: https://github.com/AlexHex7/Non-local_pytorch

Principle

The idea is quite similar to self-attention in a Transformer. As shown in the figure, the input x has shape H×W×C. Through 1×1 convolutions we obtain θ(x) and φ(x), both of shape H×W×C/2; flattening the spatial dimensions and multiplying the two matrices gives an HW×HW similarity matrix. Isn't this very similar to self-attention? I will discuss that in the Transformer section.
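For reference, this is the non-local operation defined in the paper, with the Embedded Gaussian choice of pairwise function:

y_i = \frac{1}{C(x)} \sum_{j} f(x_i, x_j) \, g(x_j), \quad
f(x_i, x_j) = e^{\theta(x_i)^\top \phi(x_j)}, \quad
C(x) = \sum_{j} f(x_i, x_j)

With this choice, 1/C(x) · f(x_i, x_j) is exactly a softmax over all positions j, so the operation can be written as y = softmax(θ(x)^T φ(x)) g(x), which is the self-attention form.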

Let's walk through how it runs. For an RGB image, we first map it to a high-dimensional feature map with convolutions; call the number of channels C. We then apply 1×1 convolutions to reduce the channels to half, i.e. C/2; these 1×1 convolutions play much the same role as splitting into patches and doing patch embedding in a Transformer, and they give us θ(x), φ(x) and g(x). Next we flatten the spatial dimensions of each feature map, multiply θ and φ to obtain the attention map, and then multiply the attention map with g. As you can see, this is almost identical to single-head self-attention in a Transformer, except that the feature maps are produced with 1×1 convolutions rather than linear layers.
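As a rough, shape-level sketch of this flow (the sizes and names here are made up only for illustration; this is not code from the repository):

import torch
import torch.nn as nn

# Toy sizes for illustration only: C = 64 channels on a 32x32 feature map.
x = torch.randn(1, 64, 32, 32)                       # (b, C, H, W)

theta = nn.Conv2d(64, 32, kernel_size=1)             # 1x1 convs reduce C -> C/2,
phi   = nn.Conv2d(64, 32, kernel_size=1)             # playing the role of the linear
g     = nn.Conv2d(64, 32, kernel_size=1)             # projections in self-attention

# Flatten the spatial dimensions: N = H*W = 1024 positions.
q = theta(x).flatten(2).permute(0, 2, 1)             # (b, N, C/2)
k = phi(x).flatten(2)                                 # (b, C/2, N)
v = g(x).flatten(2).permute(0, 2, 1)                 # (b, N, C/2)

attn = torch.softmax(q @ k, dim=-1)                   # (b, N, N) attention map
y = (attn @ v).permute(0, 2, 1).reshape(1, 32, 32, 32)  # back to (b, C/2, H, W)
print(y.shape)                                        # torch.Size([1, 32, 32, 32])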

Below is a more detailed view of the overall execution flow.

Code walkthrough

The PyTorch implementation of Non-local has four versions (Gaussian, Embedded Gaussian, Dot product, and Concatenation); here we go through the Embedded Gaussian version. Since we only deal with 2D images, the 1D and 3D branches have been removed from the code, as shown below.

import torch
import torch.nn as nn
import torch.nn.functional as F


class _NonLocalBlockND(nn.Module):
    def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        # Bottleneck: by default the intermediate channel count is half of the input.
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        # 2D case only (1D/3D branches removed): 1x1 convs, 2x2 max pooling, BatchNorm2d.
        conv_nd = nn.Conv2d
        max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
        bn = nn.BatchNorm2d

        # g: the "value" branch, a 1x1 conv that halves the channels.
        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        # W: a 1x1 conv that restores the channel count. Its (BN or conv) weights are
        # initialized to zero so the whole block starts out as an identity mapping.
        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        # theta and phi: the "query" and "key" branches, also 1x1 convs that halve the channels.
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        # Optional sub-sampling (max pooling) on phi and g to reduce the cost
        # of the pairwise computation.
        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x, return_nl_map=False):
        '''
        :param x: (b, c, h, w)
        :param return_nl_map: if True return z, nl_map, else only return z.
        :return:
        '''
        batch_size = x.size(0)

        # g(x): (b, c', h*w) -> (b, N, c'), where c' = inter_channels and N = h*w.
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # theta(x): (b, N, c'); phi(x): (b, c', N').
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        # Pairwise similarities f = theta * phi: (b, N, N'); softmax gives the attention map.
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        # Weighted sum of g over all positions, then reshape back to the spatial layout.
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])

        # Restore the channel count and add the residual connection.
        W_y = self.W(y)
        z = W_y + x

        if return_nl_map:
            return z, f_div_C
        return z
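A quick sanity check of the block above (the batch size and spatial size here are arbitrary):

import torch

block = _NonLocalBlockND(in_channels=64, dimension=2, sub_sample=True, bn_layer=True)
x = torch.randn(2, 64, 20, 20)           # (b, c, h, w)
z, nl_map = block(x, return_nl_map=True)

print(z.shape)       # torch.Size([2, 64, 20, 20]) -- same shape as the input (residual output)
print(nl_map.shape)  # torch.Size([2, 400, 100])   -- N x N' attention map (phi/g are sub-sampled 2x)

Because W is zero-initialized, the block behaves as an identity mapping at the start of training, so it can be inserted into a pretrained backbone without disturbing its initial behavior.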