CGAN实现过程

@TOC

本文用MNIST数据集进行训练,并用图解的方法展示了CGAN与GAN中输入的区别,帮助理解CGAN的运行过程

一、原理

如下图所示,我们在输入噪声z时,额外加上一个限制条件conditionz和c通过生成器G得到生成的图片

二、参数初始化

有了上面的原理解释,我们就可以来初始化我们的参数了,大致可以看出我们有如下几个参数:噪声z,条件c,真实图片x,生成器和判别器的初始化参数

  • G的输入:z_y_vec_
  • D的输入:xy_fill_
  • 模型参数的初始化
  • 测试时用的噪声sample_z_以及对应的标签sample_y_

这里输入的单个噪声维度为z_dim=62,当然这里还有很多其他的初始化,比如optimizer等,因为本文主要介绍模型的的具体执行过程,所以只对变量得初始化做介绍

1. G的输入

  • 输入噪声z:z_: (64, 62)
  • 输入条件c:y_vec_:(64, 10)

最终G的输入:横向拼接z+c (64, 72)

1
2
3
4
5
6
7
8
9
G:
torch.Size([64, 72])
tensor([[0.8920, 0.9742, 0.6876, ..., 0.0000, 0.0000, 0.0000],
[0.5271, 0.6423, 0.7480, ..., 0.0000, 1.0000, 0.0000],
[0.9545, 0.6324, 0.9603, ..., 0.0000, 0.0000, 0.0000],
...,
[0.1931, 0.7773, 0.8154, ..., 0.0000, 0.0000, 0.0000],
[0.0049, 0.7129, 0.3272, ..., 0.0000, 0.0000, 0.0000],
[0.2902, 0.1194, 0.0020, ..., 0.0000, 1.0000, 0.0000]])

2. D的输入

  • 输入真实数据:x: (64, 1, 28, 28)
  • 输入生成数据:G(z):(64, 1, 28, 28)
  • 输入条件:c:y_fill_:(64, 10, 28, 28)

最终D的输入:横向拼接x+c (64, 11, 28, 28),也就是说取batch中的一个值,维度为(1,28, 28),将其作为(11, 28, 28)的第一维,剩下的十维如果标签为0则第二维为全1,剩下的为全0,如果标签为1则第三维为全1,剩下的为全0,以此类推

1
2
3
4
5
6
7
8
9
10
11
12
13
D:
torch.Size([64, 11, 28, 28])
tensor([[[[ 0.1099, -0.5590, 0.9668, ..., 3.0843, 0.6788, -0.4171],
[ 0.8949, -0.3523, -0.4086, ..., -0.8257, -2.1445, 1.0512],
[ 1.5333, -0.0918, -1.1146, ..., -1.1746, -0.4689, 0.3702],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
...,
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],

3. 模型参数初始化

1
2
3
4
5
6
7
8
9
10
11
def initialize_weights(net):
for m in net.modules():
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0, 0.02)
m.bias.data.zero_()
elif isinstance(m, nn.ConvTranspose2d):
m.weight.data.normal_(0, 0.02)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.02)
m.bias.data.zero_()

4. 测试噪声

在测试时我们只需要设置G的输入就可以了,也就是说我们需要:

  • 输入噪声z:z_: (100, 62)
  • 输入条件c:y_vec_:(100, 10)

最终G的输入:横向拼接z+c (100, 72)

下面给出代码和输出

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# fixed noise
sample_z_ = torch.randn((100, 62))
for i in range(10):
sample_z_[i*10] = torch.rand(1, 62)
for j in range(1, 10):
sample_z_[i*10 + j] = sample_z_[i*10]
print(sample_z_)
"""
sample_z_:(100, 62)
0-9: same value
10-19: same value
...
90-99: same value
"""
temp = torch.zeros((10, 1)) # (10,1)---> 0,0,0,0,0,0,0,0,0,0
for i in range(10):
temp[i, 0] = i # (10, 1) ---> 0,1,2,3,4,5,6,7,8,9
# print("temp: ", temp)

temp_y = torch.zeros((100, 1)) #(100,1)---> 0,0,0,0,...,0,0,0,0
for i in range(10): #(100,1)---> 0,1,2,3,...,6,7,8,9
temp_y[i*10: (i+1)*10] = temp
# print("temp_y: ", temp_y)
sample_y_ = torch.zeros((100, 10)).scatter_(1, temp_y.type(torch.LongTensor), 1)
print(sample_y_) #(100,10)
'''
tensor([[0.3944, 0.9880, 0.4956, ..., 0.0602, 0.9869, 0.5094],
[0.3944, 0.9880, 0.4956, ..., 0.0602, 0.9869, 0.5094],
[0.3944, 0.9880, 0.4956, ..., 0.0602, 0.9869, 0.5094],
...,
[0.2845, 0.7694, 0.9878, ..., 0.3211, 0.0242, 0.0332],
[0.2845, 0.7694, 0.9878, ..., 0.3211, 0.0242, 0.0332],
[0.2845, 0.7694, 0.9878, ..., 0.3211, 0.0242, 0.0332]])
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
'''

下面给出详细的解释,我们知道G的输入有噪声以及条件,这里我们有100组噪声,每10组噪声的组内取值是完全相同的,但是组内的10个噪声每个噪声的条件是不同的,分别代表了数字0-9

也就是说我们希望用相同的噪声生成0-9一共十个数字,生成十组

三、执行过程

图中的红线代表一个执行流程,绿线代表一个执行流程,红色的方框为这一步反向传播的网络。因为判别器与生成器是分开训练的,用两个图来表示,左边是第一步训练判别器,右边是第二步训练生成器

  • step1:首先将样本进行输入,用BCE_loss来评估得到D_real_loss,然后将G生成的数据进行输入,同理评估得到D_fake_loss,将二者相加进行反向传播优化D。注意这一步不要优化G
  • step2:直接将G生成的数据进行输入,评估得到G_loss,反向传播优化G。注意这一步虽然是G生成的数据,但是通过D以后要与real进行求损失

四、测试

训练完后直接进行测试即可,最后测试生成的图片如下: