nn.LayerNorm实现及原理

1. nn.LayerNorm函数

nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None)

  • normalized_shape:归一化的维度
  • eps:加在方差上的数字,避免分母为0,default=1e-5
  • elementwise_affine:bool,True的话会有一个默认的affine参数,即上述公式中的,前者开始为1,后者为0,二者均可学习随着训练过程而变化,default=True

2. LayerNorm在Transformer中的应用

在transformer中一般采用LayerNorm,LayerNorm也是归一化的一种方法,与BatchNorm不同的是它是对每单个batch进行的归一化,而batchnorm是对所有batch一起进行归一化的

在Transformer中,给定输入tensor[seq_length, batch_size, d_model],其中seq_length为序列长度,batch_size为batch大小,d_model为embedding的维度,如下所示,我们有多个batch,其中第一个batch的序列长度为3,包括“I, Love, You”,他们的词向量维度均为6,在LayerNorm的时候,分别对“I”,“Love”,“You”向量进行归一化,即相同颜色的数字

2.1 LayerNorm的官方实现

下面代码展示了nn.LayerNorm的官方使用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch.nn as nn
import torch

input = torch.arange(1, 19).view(3, 1, 6).type(torch.float32)
print(input)

'''
tensor([[[ 1., 2., 3., 4., 5., 6.]],

[[ 7., 8., 9., 10., 11., 12.]],

[[13., 14., 15., 16., 17., 18.]]])
'''

# 官方nn.LayerNorm实现
norm = nn.LayerNorm(6)
output = norm(input)
print(output)

'''
tensor([[[-1.4638, -0.8783, -0.2928, 0.2928, 0.8783, 1.4638]],

[[-1.4638, -0.8783, -0.2928, 0.2928, 0.8783, 1.4638]],

[[-1.4638, -0.8783, -0.2928, 0.2928, 0.8783, 1.4638]]],
grad_fn=<NativeLayerNormBackward0>)
'''

# 手动计算验证
mean = torch.mean(input, dim=-1, keepdim=True)
# 这里方差计算是除了N,不是N-1
std = torch.std(input, correction=0, dim=-1, keepdim=True)
output = (input - mean) / (std + 1e-5)
print(output)

'''
tensor([[[-1.4638, -0.8783, -0.2928, 0.2928, 0.8783, 1.4638]],

[[-1.4638, -0.8783, -0.2928, 0.2928, 0.8783, 1.4638]],

[[-1.4638, -0.8783, -0.2928, 0.2928, 0.8783, 1.4638]]])
'''

2.2 LayerNorm的自定义实现

下面代码展示了nn.LayerNorm的自定义实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch.nn as nn
import torch

class LayerNorm(nn.Module):
"Construct a layernorm module."

def __init__(self, features, eps=1e-5):
super(LayerNorm, self).__init__()
self.a_2 = nn.Parameter(torch.ones(features))
self.b_2 = nn.Parameter(torch.zeros(features))
self.eps = eps

def forward(self, x):
mean = x.mean(-1, keepdim=True)
std = x.std(-1, correction=0, keepdim=True)
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


input = torch.arange(1, 19).view(3, 1, 6).type(torch.float32)
print(input)
'''
tensor([[[ 1., 2., 3., 4., 5., 6.]],

[[ 7., 8., 9., 10., 11., 12.]],

[[13., 14., 15., 16., 17., 18.]]])
'''

norm = LayerNorm(6)
output = norm(input)
print(output)
'''
tensor([[[-1.4638, -0.8783, -0.2928, 0.2928, 0.8783, 1.4638]],

[[-1.4638, -0.8783, -0.2928, 0.2928, 0.8783, 1.4638]],

[[-1.4638, -0.8783, -0.2928, 0.2928, 0.8783, 1.4638]]],
grad_fn=<AddBackward0>)
'''

mean = torch.mean(input, dim=-1, keepdim=True)
std = torch.std(input, correction=0, dim=-1, keepdim=True)
output = (input - mean) / (std + 1e-5)
print(output)
'''
tensor([[[-1.4638, -0.8783, -0.2928, 0.2928, 0.8783, 1.4638]],

[[-1.4638, -0.8783, -0.2928, 0.2928, 0.8783, 1.4638]],

[[-1.4638, -0.8783, -0.2928, 0.2928, 0.8783, 1.4638]]])
'''

3. nn.LayerNorm的底层逻辑

根据上述例子我们可以了解到LayerNorm的归一化底层逻辑,给定LayerNorm归一化维度,他会将输入tensor的最后几个维度进行整体归一化。什么意思呢?假设我们的输入为(1, 3, 4, 4)的变量,并对其进行LayerNorm,这里我们展示两个例子

注意:这里的例子只是帮助理解LayerNorm函数的用法,并不是说四维tensor就要按照下面两种方式处理,正常来说,CNN中很少用LayerNorm

如下图所示,左边为第一种归一化方法,对所有channel所有像素计算;右边为第二种归一化方法,对所有channel的每个像素分别计算

  • 计算一个batch中所有channel中所有参数的均值和方差,然后进行归一化,即(3, 5, 5)
  • 计算一个batch中所有channel中的每一个参数的均值和方差进行归一化,即(3, 1, 1),计算25次

3.1 第一种计算

直接给出计算代码

注意:输入为(1, 3, 4, 4),layernorm的normalized_shape为[3, 5, 5],也就是说对后三维度进行归一化操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch.nn as nn
import torch


input = torch.arange(1, 49).view(3, 4, 4).type(torch.float32)
input = input.unsqueeze(0)
'''
tensor([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]],

[[17., 18., 19., 20.],
[21., 22., 23., 24.],
[25., 26., 27., 28.],
[29., 30., 31., 32.]],

[[33., 34., 35., 36.],
[37., 38., 39., 40.],
[41., 42., 43., 44.],
[45., 46., 47., 48.]]]])
'''


# 直接使用nn.LayerNorm函数计算
norm = nn.LayerNorm([3, 4, 4])
print(norm(input))
'''
tensor([[[[-1.6963, -1.6242, -1.5520, -1.4798],
[-1.4076, -1.3354, -1.2632, -1.1910],
[-1.1189, -1.0467, -0.9745, -0.9023],
[-0.8301, -0.7579, -0.6858, -0.6136]],

[[-0.5414, -0.4692, -0.3970, -0.3248],
[-0.2526, -0.1805, -0.1083, -0.0361],
[ 0.0361, 0.1083, 0.1805, 0.2526],
[ 0.3248, 0.3970, 0.4692, 0.5414]],

[[ 0.6136, 0.6858, 0.7579, 0.8301],
[ 0.9023, 0.9745, 1.0467, 1.1189],
[ 1.1910, 1.2632, 1.3354, 1.4076],
[ 1.4798, 1.5520, 1.6241, 1.6963]]]],
grad_fn=<NativeLayerNormBackward0>)
'''

# 手动计算
mean = torch.mean(input)
std = torch.std(input, correction=0)
x = (input-mean)/(std+1e-5)
print(x)
'''
tensor([[[[-1.6963, -1.6241, -1.5520, -1.4798],
[-1.4076, -1.3354, -1.2632, -1.1910],
[-1.1189, -1.0467, -0.9745, -0.9023],
[-0.8301, -0.7579, -0.6858, -0.6136]],

[[-0.5414, -0.4692, -0.3970, -0.3248],
[-0.2526, -0.1805, -0.1083, -0.0361],
[ 0.0361, 0.1083, 0.1805, 0.2526],
[ 0.3248, 0.3970, 0.4692, 0.5414]],

[[ 0.6136, 0.6858, 0.7579, 0.8301],
[ 0.9023, 0.9745, 1.0467, 1.1189],
[ 1.1910, 1.2632, 1.3354, 1.4076],
[ 1.4798, 1.5520, 1.6241, 1.6963]]]])
'''

当然如果要灵活的进行操作,可以将tensor提前resize以下,这样LayerNorm就不需要传入list列表了,比如这里将输入resize为[1, 3*4*4],这样初始化LayerNorm(3*4*4)即可,等操作完成后再resize回来

3.2 第二种计算

直接给出计算代码

注意:我们的输入是(1, 3, 4, 4),如果要完成第二种方法,我们layernorm只需要提供一个参数,即norm = nn.LayerNorm(3),但是如果只提供一个参数,默认为对最后一维进行归一化,所以我们需要将输入进行变化,即变为(1, 4, 4, 3)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch.nn as nn
import torch


input = torch.arange(1, 49).view(3, 4, 4).type(torch.float32)
input = input.unsqueeze(0) # [1, 3, 5, 5]

# [1, 3, 5, 5] -> [1, 5, 5, 3]
input = input.permute(0, 2, 3, 1).contiguous()
print(input) # [1, 5, 5, 3]
'''
tensor([[[[ 1., 17., 33.],
[ 2., 18., 34.],
[ 3., 19., 35.],
[ 4., 20., 36.]],

[[ 5., 21., 37.],
[ 6., 22., 38.],
[ 7., 23., 39.],
[ 8., 24., 40.]],

[[ 9., 25., 41.],
[10., 26., 42.],
[11., 27., 43.],
[12., 28., 44.]],

[[13., 29., 45.],
[14., 30., 46.],
[15., 31., 47.],
[16., 32., 48.]]]])
'''

# LayerNorm函数计算
norm = nn.LayerNorm(3)
print(norm(input))
'''
tensor([[[[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247]],

[[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247]],

[[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247]],

[[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247]]]], grad_fn=<NativeLayerNormBackward0>)
'''


# 手动计算
mean = torch.mean(input, dim=3, keepdim=True)
std = torch.std(input, correction=0, dim=3, keepdim=True)
x = (input-mean)/(std+1e-5)
print(x)
'''
tensor([[[[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247]],

[[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247]],

[[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247]],

[[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247]]]])
'''

# 最后将输出resize回输入维度 [1, 5, 5, 3] -> [1, 3, 5, 5]
x = x.permute(0, 3, 1, 2)
print(x) # [1, 3, 5, 5]
'''
tensor([[[[-1.2247, -1.2247, -1.2247, -1.2247],
[-1.2247, -1.2247, -1.2247, -1.2247],
[-1.2247, -1.2247, -1.2247, -1.2247],
[-1.2247, -1.2247, -1.2247, -1.2247]],

[[ 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000]],

[[ 1.2247, 1.2247, 1.2247, 1.2247],
[ 1.2247, 1.2247, 1.2247, 1.2247],
[ 1.2247, 1.2247, 1.2247, 1.2247],
[ 1.2247, 1.2247, 1.2247, 1.2247]]]])
'''