nn.Sequential && nn.ModuleList

在介绍nn.Sequential和nn.ModuleDict之前,我们需要知道在pytorch构建的model核心是nn.Module模块,下面举个例子

1
2
3
4
5
6
7
8
class model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv = nn.Conv2d(3, 20, 5)

def forward(self, x):
x = F.relu(conv(x))
return x

在了解这个基本概念之后,我们分别介绍这两个模块

nn.Sequential

nn.Sequential继承自nn.Module模块,因此他自带forward函数,下面我们看一个例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
model = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)
print(model)
'''
Sequential(
(0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
(1): ReLU()
(2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
(3): ReLU()
)
'''

# 给每一步的模块进行命名
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1,20,5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(20,64,5)),
('relu2', nn.ReLU())
]))
print(model)
'''
Sequential(
(conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
(relu1): ReLU()
(conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
(relu2): ReLU()
)
'''

input = torch.randn([1, 1, 10, 10])
output = model(input)
print(output.size()) # torch.Size([1, 64, 2, 2])

如上所示,我们可以得到一些结论

  1. 在nn.Sequential里面的每一个操作是逐步执行的,不可改变顺序,如果第一步的输出与第二步的输入不匹配就会报错
  2. 可以通过OrderedDict来改变nn.Sequential里面每一步的名字。注意,即使改变了名字,索引时也需要用0,1,2…,例如model[0]=Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1),model[‘conv1’]会报错

nn.ModuleList

nn.ModuleDict没有继承自nn.Module,所以不能像nn.Sequential那样有forward功能。可以将其看做一个列表的形式,能够将多个操作存放在一个列表里

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
for i, l in enumerate(self.linears):
x = self.linears[i](x)
return x

model = MyModule()
input = torch.randn([1, 10])
output = model(input)
print(output.size()) # torch.Size([1, 10])

如上所示,这里总结nn.ModelList的一些特点

  1. nn.ModelList是单纯的列表形式,当我们想快速构建一些操作(例如例子中的linear操作时,可以使用modellist)
  2. nn.ModelList不具备forward功能,所以我们调用里面的操作时,需要进行索引,然后才能运行这个操作
  3. nn.ModelList列表内的操作可以是乱序的,比如我先用list[3],再用list[0],而nn.Sequential的执行顺序不能打乱

为什么不能用python中的list来代替nn.ModelList呢?

因为nn.ModelList可以将里面的列表操作自动注册到整个网络中,但是如果是python的list,则会出问题,如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class net_modlist(nn.Module):
def __init__(self):
super(net_modlist, self).__init__()
self.modlist = nn.ModuleList([
nn.Conv2d(1, 20, 5),
nn.Conv2d(20, 64, 5),])

def forward(self, x):
for m in self.modlist:
x = m(x)
return x

model = net_modlist()
for param in model.parameters():
print(type(param.data), param.size())

'''
nn.ModuleList
<class 'torch.Tensor'> torch.Size([20, 1, 5, 5])
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([64, 20, 5, 5])
<class 'torch.Tensor'> torch.Size([64])

将nn.ModuleList换为单纯的list
None # 输出为None,表示conv操作并没有加入到模型参数中
'''