import torch
from torch import nn


class _NonLocalBlockND(nn.Module):
    def __init__(self, in_channels, inter_channels=None, dimension=2,
                 sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        # Default the embedding width to half the input channels, floored at 1.
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        # Pick the conv / pooling / batch-norm layers that match the input
        # rank promised by the `dimension` argument.
        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=2)
            bn = nn.BatchNorm1d

        # g: the "value" embedding.
        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        # W projects the aggregated features back to in_channels. It is
        # zero-initialised so the whole block starts out as an identity (z = x).
        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        # theta and phi embed the query and key positions; each pairwise
        # concatenation [theta(x_i), phi(x_j)] is scored by concat_project,
        # i.e. f(x_i, x_j) = ReLU(w_f^T [theta(x_i), phi(x_j)]).
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        # The pairwise affinity map is always 2D (N x N'), so this stays a Conv2d
        # regardless of `dimension`.
        self.concat_project = nn.Sequential(
            nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
            nn.ReLU()
        )

        # Optionally sub-sample phi and g to shrink the affinity matrix.
        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x, return_nl_map=False):
        '''
        :param x: (b, c, t, h, w) for dimension=3, (b, c, h, w) for dimension=2,
                  (b, c, t) for dimension=1
        :param return_nl_map: if True, return (z, nl_map); otherwise return z only.
        :return:
        '''
        batch_size = x.size(0)

        # (b, N', C'), where N' is the (possibly sub-sampled) number of positions.
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # Tile theta along columns and phi along rows to enumerate all N x N' pairs.
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)

        h = theta_x.size(2)
        w = phi_x.size(3)
        theta_x = theta_x.repeat(1, 1, 1, w)
        phi_x = phi_x.repeat(1, 1, h, 1)

        # Score every (i, j) pair: (b, 2C', N, N') -> (b, 1, N, N') -> (b, N, N').
        concat_feature = torch.cat([theta_x, phi_x], dim=1)
        f = self.concat_project(concat_feature)
        b, _, h, w = f.size()
        f = f.view(b, h, w)

        # Normalise by the number of positions rather than with a softmax.
        N = f.size(-1)
        f_div_C = f / N

        # Aggregate the values and reshape back to the input's spatial layout.
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x  # residual connection

        if return_nl_map:
            return z, f_div_C
        return z
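For reference, a minimal smoke test of the 2D case. The channel counts and spatial size below are illustrative assumptions, not values from the block itself:

import torch

block = _NonLocalBlockND(in_channels=64, dimension=2, sub_sample=True)
x = torch.randn(2, 64, 20, 20)           # (batch, channels, H, W)
z, nl_map = block(x, return_nl_map=True)
print(z.shape)       # torch.Size([2, 64, 20, 20]) -- same shape as the input
print(nl_map.shape)  # torch.Size([2, 400, 100]) -- N x N' pairwise weights
                     # (phi and g are sub-sampled 20x20 -> 10x10, so N' = 100)

Because W is zero-initialised, z equals x right after construction, so the block can be inserted into a pretrained network without initially disturbing its behaviour.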