nn.functional.fold/unfold

nn.functional.fold/unfold

官方文档:fold, unfold

作用:fold和unfold的作用恰好相反,unfold是用一个滑窗来提取图像中的像素值,类似于卷积操作,但是只提取不计算,fold恰好相反将滑窗提取的值返回为一个图像

nn.functional.unfold(input, kernel_size, dilation=1, padding=0, stride=1)

  • input: 输入tensor
  • kernel_size: 提取时的滑窗大小
  • dilation: 滑窗是否有空洞
  • padding: 是否对原图进行填充
  • stride: 滑窗移动的步长

下面举一个例子直观解释

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
x = torch.Tensor([[[[  1,  2,  3,  4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[ 13, 14, 15, 16]]]])
x = F.unfold(x, kernel_size=(2, 2), padding=0, stride=2)
print(x)
print(x.size())

'''
tensor([[[ 1., 3., 9., 11.],
[ 2., 4., 10., 12.],
[ 5., 7., 13., 15.],
[ 6., 8., 14., 16.]]])
torch.Size([1, 4, 4])
'''

执行过程很简单,用一个的窗在图上滑动,步长为2,第一次覆盖的内容为1256,第二次为3478,以此类推,每次滑窗的结果用一个列向量表示,列数就是滑窗提取的次数。如果我们要得到每次滑窗的结果,例如第一次提取的结果,用表达式x[:,:,0]即可

nn.functional.fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1)

  • input: 输入tensor
  • output_size: 输出图像的大小(必须指定)
  • kernel_size: 在图像中填充的形状
  • dilation: 滑窗是否有空洞
  • padding: 是否对原图进行填充
  • stride: 存放窗的tensor时移动的步长

网上很少有讲这个函数的,都说是unfold的逆过程,我们依然用几个例子来对其进行详细的解释

1. 第一个例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
x = torch.Tensor([[[[  1,  2,  3,  4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[ 13, 14, 15, 16]]]])
x = F.unfold(x, kernel_size=(2, 2), padding=0, stride=2)
x = F.fold(x, output_size=(4,4), kernel_size=(2,2), padding=0, stride=2)

'''
tensor([[[ 1., 3., 9., 11.],
[ 2., 4., 10., 12.],
[ 5., 7., 13., 15.],
[ 6., 8., 14., 16.]]])
torch.Size([1, 4, 4])
tensor([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]])
torch.Size([1, 1, 4, 4])
'''

fold函数是如何执行的呢,他会提取unfold函数的每一列,首先提取1256这一列,然后根据kernel_size的大小将1256重新resize并填到output的第一个位置,如下

1
2
3
4
[[ 1.,  2.,  0.,  0.],
[ 5., 6., 0., 0.],
[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.]]

随后提取第二列2.6.10.14,resize为的形状,根据步长为2添加到output的下一个位置,并以此类推

1
2
3
4
[[ 1.,  2.,  3.,  4.],
[ 5., 6., 7., 8.],
[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.]]

注意:output,kernel以及stride必须满足一定的关系(参考文档)

知道原理以后我们可以自由操作上述tensor,但是注意,如果步长等设置不合适的话,最后的结果是有overlap的,下面我们展示两个例子

2. 第二个例子

自由操作tensor

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
x = torch.Tensor([[[[  1,  2,  3,  4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[ 13, 14, 15, 16]]]])
x = F.unfold(x, kernel_size=(2, 2), padding=0, stride=2)
x = F.fold(x, output_size=(4,4), kernel_size=(4,1), padding=0, stride=1)

'''
tensor([[[ 1., 3., 9., 11.],
[ 2., 4., 10., 12.],
[ 5., 7., 13., 15.],
[ 6., 8., 14., 16.]]])
torch.Size([1, 4, 4])

# tensor又变回了原来的样子
tensor([[[[ 1., 3., 9., 11.],
[ 2., 4., 10., 12.],
[ 5., 7., 13., 15.],
[ 6., 8., 14., 16.]]]])
torch.Size([1, 1, 4, 4])
'''

overlap的情况

根据上述讲的可以自己推一下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
x = torch.Tensor([[[[  1,  2,  3,  4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[ 13, 14, 15, 16]]]])
x = F.unfold(x, kernel_size=(2, 2), padding=0, stride=2)
x = F.fold(x, output_size=(3,3), kernel_size=(2,2), padding=0, stride=1)

'''
tensor([[[ 1., 3., 9., 11.],
[ 2., 4., 10., 12.],
[ 5., 7., 13., 15.],
[ 6., 8., 14., 16.]]])
torch.Size([1, 4, 4])
tensor([[[[ 1., 5., 4.],
[14., 34., 20.],
[13., 29., 16.]]]])
torch.Size([1, 1, 3, 3])
'''

3. kernel size小于列向量的情况

上面讲了,fold每次都会对列向量进行提取,之前的例子都是kernel size等于列向量,如果我们的kernel size小于列向量就会出现以下情况

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
x = torch.Tensor([[[[  1,  2,  3,  4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[ 13, 14, 15, 16]]]])
x = F.unfold(x, kernel_size=(2, 2), padding=0, stride=2)
x = F.fold(x, output_size=(2,2), kernel_size=(1,1), padding=0, stride=1)

'''
tensor([[[ 1., 3., 9., 11.],
[ 2., 4., 10., 12.],
[ 5., 7., 13., 15.],
[ 6., 8., 14., 16.]]])
torch.Size([1, 4, 4])
tensor([[[[ 1., 3.],
[ 9., 11.]],

[[ 2., 4.],
[10., 12.]],

[[ 5., 7.],
[13., 15.]],

[[ 6., 8.],
[14., 16.]]]])
torch.Size([1, 4, 2, 2])
'''

解释一下,我们第一次提取的应该是1256,但是由于我们的kernel太小了,,只能提取一个元素,因此就是1,我们的output size是,步长为1,所以第一次提取的结果如下

1
2
[[  1,  0],
[ 0, 0]]

第二次提取时,就需要移动了,提取的不是列向量中的2,而是横向移动的3,接着放到刚才那个元素后面

1
2
[[  1,  3],
[ 0, 0]]

之后的过程以此类推,直到我们提取到11,这时我们的行向量提取完了,但是列向量没有,所以我们从第二列开始重复刚才的过程即可,可以看到最终我们输出向量大小为[1,4,2,2],4就是我们提取了4次行向量,两个2就是每次提取的大小(即output size)

最后加一个复杂的具有padding的例子

padding就是在对tensor进行操作之前在tensor四周补0或其他的值。例子中仅对unfold进行padding,如果对fold进行padding也同理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
x = torch.Tensor([[[[  1,  2,  3,  4,  5,  6, 7,  8],
[ 9, 10, 11, 12, 13, 14, 15, 16],
[ 17, 18, 19, 20, 21, 22, 23, 24],
[ 25, 26, 27, 28, 29, 30, 31, 32],
[ 33, 34, 35, 36, 37, 38, 39, 40],
[ 41, 42, 43, 44, 45, 46, 47, 48],
[ 49, 50, 51, 52, 53, 54, 55, 56],
[ 57, 58, 59, 60, 61, 62, 63, 64]]]])
x = F.unfold(x, kernel_size=(6,6), padding=1, stride=4)
x = F.fold(x, output_size=(12,12), kernel_size=(6,6), padding=0, stride=6)

'''
tensor([[[ 0., 0., 0., 28.],
[ 0., 0., 25., 29.],
[ 0., 0., 26., 30.],
[ 0., 0., 27., 31.],
[ 0., 0., 28., 32.],
[ 0., 0., 29., 0.],
[ 0., 4., 0., 36.],
[ 1., 5., 33., 37.],
[ 2., 6., 34., 38.],
[ 3., 7., 35., 39.],
[ 4., 8., 36., 40.],
[ 5., 0., 37., 0.],
[ 0., 12., 0., 44.],
[ 9., 13., 41., 45.],
[10., 14., 42., 46.],
[11., 15., 43., 47.],
[12., 16., 44., 48.],
[13., 0., 45., 0.],
[ 0., 20., 0., 52.],
[17., 21., 49., 53.],
[18., 22., 50., 54.],
[19., 23., 51., 55.],
[20., 24., 52., 56.],
[21., 0., 53., 0.],
[ 0., 28., 0., 60.],
[25., 29., 57., 61.],
[26., 30., 58., 62.],
[27., 31., 59., 63.],
[28., 32., 60., 64.],
[29., 0., 61., 0.],
[ 0., 36., 0., 0.],
[33., 37., 0., 0.],
[34., 38., 0., 0.],
[35., 39., 0., 0.],
[36., 40., 0., 0.],
[37., 0., 0., 0.]]])
torch.Size([1, 36, 4])
tensor([[[[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 1., 2., 3., 4., 5., 4., 5., 6., 7., 8., 0.],
[ 0., 9., 10., 11., 12., 13., 12., 13., 14., 15., 16., 0.],
[ 0., 17., 18., 19., 20., 21., 20., 21., 22., 23., 24., 0.],
[ 0., 25., 26., 27., 28., 29., 28., 29., 30., 31., 32., 0.],
[ 0., 33., 34., 35., 36., 37., 36., 37., 38., 39., 40., 0.],
[ 0., 25., 26., 27., 28., 29., 28., 29., 30., 31., 32., 0.],
[ 0., 33., 34., 35., 36., 37., 36., 37., 38., 39., 40., 0.],
[ 0., 41., 42., 43., 44., 45., 44., 45., 46., 47., 48., 0.],
[ 0., 49., 50., 51., 52., 53., 52., 53., 54., 55., 56., 0.],
[ 0., 57., 58., 59., 60., 61., 60., 61., 62., 63., 64., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])
torch.Size([1, 1, 12, 12])
'''