Getting the modules and parameters of a network in PyTorch

1. Getting the network structure in PyTorch

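When writing deep learning programs, we often want to print a network to inspect its structure. The simplest way is to call print(model) directly. We will use the following program as an example: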

```python
import torch.nn as nn

class SubNet(nn.Module):
    def __init__(self):
        super(SubNet, self).__init__()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.head = SubNet()  # a nested submodule

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # Flatten to 320 features per sample (20 channels * 4 * 4, e.g. for a 12x12 input)
        x = x.view(-1, 320)
        x = self.head(x)
        return x

net = Net()
print(net)
```

Here we instantiate the network and call print(net) directly, which gives the following output:

```
Net(
  (conv1): Conv2d(3, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (head): SubNet(
    (fc1): Linear(in_features=320, out_features=50, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
  )
)
```
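print(net) lists the layers, but it does not check that they fit together. The sketch below adds a forward pass as a quick sanity check; it is not part of the original example. The line x.view(-1, 320) only yields one row per sample if conv2 outputs 20×4×4 = 320 values. With two unpadded 5×5 convolutions, that requires a 12×12 input. The 12×12 input size is our assumption, since the original does not state one.

```python
import torch
import torch.nn as nn

# Same network as in the example above, repeated so this snippet runs on its own
class SubNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        return self.fc2(self.fc1(x))

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.head = SubNet()

    def forward(self, x):
        x = self.conv1(x)    # 12x12 -> 8x8 (5x5 kernel, no padding)
        x = self.conv2(x)    # 8x8 -> 4x4, i.e. 20 * 4 * 4 = 320 values per sample
        x = x.view(-1, 320)  # -> (batch, 320)
        return self.head(x)  # -> (batch, 10)

net = Net()
x = torch.randn(2, 3, 12, 12)  # assumed input: a batch of two 3-channel 12x12 images
out = net(x)
print(out.shape)  # torch.Size([2, 10])
```

With a different input size, such as 28×28, view(-1, 320) would not fail. It would silently regroup the values into the wrong number of rows, so checking the output shape like this is worthwhile.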

1.1 The named_modules method

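print(net) shows the network structure, but it does not give us detailed, programmatic access to each layer. For example, the full (qualified) name of the fc1 layer is actually head.fc1. To iterate over this information, we use the named_modules method.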

named_modules()

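  • Returns an iterator over all modules in the network, yielding both the name of each module and the module itself. The network itself is included and comes first, with the empty string "" as its name.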

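To understand this method better, let's continue with the network above and print every item it yields. Taking the fc1 layer as an example, the method returns the name head.fc1 together with the module Linear(in_features=320, out_features=50, bias=True):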

```python
# The root module (net itself) is yielded first, with the empty name "",
# which prints as the blank first line of the output below
for name, module in net.named_modules():
    print(name)
    print(module)

"""

Net(
  (conv1): Conv2d(3, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (head): SubNet(
    (fc1): Linear(in_features=320, out_features=50, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
  )
)
conv1
Conv2d(3, 10, kernel_size=(5, 5), stride=(1, 1))
conv2
Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
head
SubNet(
  (fc1): Linear(in_features=320, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=10, bias=True)
)
head.fc1
Linear(in_features=320, out_features=50, bias=True)
head.fc2
Linear(in_features=50, out_features=10, bias=True)
"""
```

1.2 The named_parameters method

named_parameters()

  • Returns an iterator over all parameters in the network, yielding both the name of each parameter and the parameter itself.

Printing the parameters themselves would take up too much space, so I print only their sizes:

```python
# Each item is a parameter tensor, not a module, so it is named "param" here
for name, param in net.named_parameters():
    print(name)
    print(param.size())

"""
conv1.weight
torch.Size([10, 3, 5, 5])
conv1.bias
torch.Size([10])
conv2.weight
torch.Size([20, 10, 5, 5])
conv2.bias
torch.Size([20])
head.fc1.weight
torch.Size([50, 320])
head.fc1.bias
torch.Size([50])
head.fc2.weight
torch.Size([10, 50])
head.fc2.bias
torch.Size([10])
"""
```