Viewing model parameters in PyTorch


1. Building the model

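First, build a small network with randomly initialized weights. All of the state_dict() and named_parameters() calls below are run after this model has been created and has processed one batch.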

import torch
import torch.nn as nn

torch.manual_seed(0)  # fix the seed so the random initialization is reproducible


class Model(nn.Module):
    def __init__(self, in_channel=2, out_channel=4):
        super().__init__()
        # Linear -> BatchNorm1d -> ReLU, registered under the attribute name "l1"
        self.l1 = nn.Sequential(
            nn.Linear(in_channel, out_channel),
            nn.BatchNorm1d(out_channel),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.l1(x)


model = Model()
x = torch.randn(2, 2)  # a batch of 2 samples with 2 features each
y = model(x)
print(y)
'''
tensor([[0.9999, 0.0000, 0.0000, 1.0000],
        [0.0000, 1.0000, 0.9960, 0.0000]], grad_fn=<ReluBackward0>)
'''
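The forward pass above runs in training mode, which is why the BatchNorm1d buffers shown later are no longer at their initial values. The following is a minimal sketch of that behavior, assuming a standalone `nn.BatchNorm1d(4)` layer named `bn` that is not part of the article's model and is used only for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)  # standalone layer, used only for this illustration

before = int(bn.num_batches_tracked)       # no batch has been seen yet
bn.train()                                 # training mode (the default for a new module)
bn(torch.randn(2, 4))                      # a forward pass updates the running statistics
after_train = int(bn.num_batches_tracked)

bn.eval()                                  # evaluation mode: buffers are read, not updated
bn(torch.randn(2, 4))
after_eval = int(bn.num_batches_tracked)

print(before, after_train, after_eval)     # 0 1 1
```

Only the training-mode call changes the counter, so the values printed from state_dict() depend on how many training-mode batches the model has processed.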

2. state_dict()

model.state_dict()

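This returns a dictionary holding the whole state of the model, buffers included. The official documentation describes it as follows: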

Returns a dictionary containing a whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names.

# Names of every parameter and buffer in the model
print(model.state_dict().keys())
'''
odict_keys(['l1.0.weight', 'l1.0.bias', 'l1.1.weight', 'l1.1.bias', 'l1.1.running_mean', 'l1.1.running_var', 'l1.1.num_batches_tracked'])
'''


# The full dictionary: names mapped to tensors
print(model.state_dict())
'''
OrderedDict([('l1.0.weight', tensor([[-0.0053, 0.3793],
        [-0.5820, -0.5204],
        [-0.2723, 0.1896],
        [-0.0140, 0.5607]])),
('l1.0.bias', tensor([-0.0628, 0.1871, -0.2137, -0.1390])),
('l1.1.weight', tensor([1., 1., 1., 1.])),
('l1.1.bias', tensor([0., 0., 0., 0.])),
('l1.1.running_mean', tensor([ 0.0021, 0.0166, -0.0129, -0.0015])),
('l1.1.running_var', tensor([0.9108, 0.9844, 0.9002, 0.9231])),
('l1.1.num_batches_tracked', tensor(1))])
'''


# Iterate over (name, tensor) pairs; state_dict() is called only once
for name, tensor in model.state_dict().items():
    print(name, tensor)
'''
l1.0.weight tensor([[-0.0053, 0.3793],
        [-0.5820, -0.5204],
        [-0.2723, 0.1896],
        [-0.0140, 0.5607]])
l1.0.bias tensor([-0.0628, 0.1871, -0.2137, -0.1390])
l1.1.weight tensor([1., 1., 1., 1.])
l1.1.bias tensor([0., 0., 0., 0.])
l1.1.running_mean tensor([ 0.0021, 0.0166, -0.0129, -0.0015])
l1.1.running_var tensor([0.9108, 0.9844, 0.9002, 0.9231])
l1.1.num_batches_tracked tensor(1)
'''
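The keys printed above mix learnable parameters with BatchNorm buffers. As a rough sketch of how to tell them apart, the snippet below rebuilds the same layer stack as a bare `nn.Sequential` named `net` (an assumption made so the snippet runs on its own, which is also why the names lack the `l1.` prefix) and compares state_dict() with named_parameters() and named_buffers():

```python
import torch
import torch.nn as nn

# Same layers as the article's Model.l1, rebuilt here so the snippet is self-contained.
net = nn.Sequential(nn.Linear(2, 4), nn.BatchNorm1d(4), nn.ReLU())

param_names = [name for name, _ in net.named_parameters()]
buffer_names = [name for name, _ in net.named_buffers()]
state_keys = list(net.state_dict().keys())

print(param_names)   # learnable tensors: weights and biases
print(buffer_names)  # non-learnable state: BatchNorm running statistics
print(set(state_keys) == set(param_names + buffer_names))  # True

# Tensors returned by state_dict() are detached from the autograd graph.
print(net.state_dict()["0.weight"].requires_grad)  # False
```

This also explains why the tensors printed from state_dict() carry no `requires_grad=True` marker, while those from parameters() do.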

3. named_parameters()

named_parameters(prefix='', recurse=True)

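Like state_dict(), this method exposes the model's parameters, but it returns an iterator of (name, Parameter) pairs rather than a dictionary, and it does not include buffers such as running_mean.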

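  • prefix: a string prepended to every parameter name
  • recurse: if True, yields the parameters of this module and all of its submodules; if False, yields only the parameters that are direct members of this module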
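The effect of both arguments can be sketched with a small hypothetical module. `Wrapper` and its `scale` parameter are not part of the article's model; they are assumptions introduced here so that the module owns one direct parameter in addition to those of its submodules:

```python
import torch
import torch.nn as nn


class Wrapper(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))  # parameter owned directly by Wrapper
        self.l1 = nn.Sequential(nn.Linear(2, 4), nn.BatchNorm1d(4))


w = Wrapper()

# recurse=True (the default): direct parameters plus those of all submodules
all_names = [name for name, _ in w.named_parameters(prefix="net")]

# recurse=False: only parameters that are direct members of Wrapper
own_names = [name for name, _ in w.named_parameters(prefix="net", recurse=False)]

print(all_names)
# ['net.scale', 'net.l1.0.weight', 'net.l1.0.bias', 'net.l1.1.weight', 'net.l1.1.bias']
print(own_names)
# ['net.scale']
```

For the article's Model, whose only attribute is the `l1` container, `recurse=False` would yield nothing because the module has no direct parameters.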

It returns an iterator over all model parameters, yielding each parameter's name together with the parameter itself:

(string, Parameter) – Tuple containing the name and parameter

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Note: there is also a parameters() method. The only difference between the two is that parameters() yields just the parameters, without their names.
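A quick way to check this relationship is to compare the two iterators element by element. The sketch below assumes a bare `nn.Sequential` named `net` with the same layers as the article's model:

```python
import torch.nn as nn

net = nn.Sequential(nn.Linear(2, 4), nn.BatchNorm1d(4), nn.ReLU())

# Both methods walk the module tree in the same order and yield the same objects.
same_objects = all(
    p is q for p, (_, q) in zip(net.parameters(), net.named_parameters())
)
n_unnamed = len(list(net.parameters()))
n_named = len(list(net.named_parameters()))

print(same_objects, n_unnamed, n_named)  # True 4 4
```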

# Both methods return generators, so printing them shows no values
print(model.parameters())
print(model.named_parameters())
'''
<generator object Module.parameters at 0x7f5b7fc42b30>
<generator object Module.named_parameters at 0x7f5b7fc42b30>
'''


# parameters(): values only
for param in model.parameters():
    print(param)
'''
Parameter containing:
tensor([[-0.0053, 0.3793],
        [-0.5820, -0.5204],
        [-0.2723, 0.1896],
        [-0.0140, 0.5607]], requires_grad=True)
Parameter containing:
tensor([-0.0628, 0.1871, -0.2137, -0.1390], requires_grad=True)
Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True)
'''


# named_parameters(): (name, value) pairs, here with a custom prefix
for name, param in model.named_parameters(prefix="xxxxxxx"):
    print(name, param)
'''
xxxxxxx.l1.0.weight Parameter containing:
tensor([[-0.0053, 0.3793],
        [-0.5820, -0.5204],
        [-0.2723, 0.1896],
        [-0.0140, 0.5607]], requires_grad=True)
xxxxxxx.l1.0.bias Parameter containing:
tensor([-0.0628, 0.1871, -0.2137, -0.1390], requires_grad=True)
xxxxxxx.l1.1.weight Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)
xxxxxxx.l1.1.bias Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True)
'''


# Filter parameters by name: print the shape of every bias
print("----------------model parameters---------------------")
for name, param in model.named_parameters(prefix=""):
    if 'bias' in name:
        print(name, param.size())
'''
----------------model parameters---------------------
l1.0.bias torch.Size([4])
l1.1.bias torch.Size([4])
'''
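The same name-based iteration is commonly used to count parameters. A minimal sketch, again assuming a bare `nn.Sequential` named `net` with the article's layers (so the names lack the `l1.` prefix):

```python
import torch.nn as nn

net = nn.Sequential(nn.Linear(2, 4), nn.BatchNorm1d(4), nn.ReLU())

# Number of scalar elements in each named parameter
counts = {name: p.numel() for name, p in net.named_parameters()}
total = sum(counts.values())
trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)

print(counts)            # {'0.weight': 8, '0.bias': 4, '1.weight': 4, '1.bias': 4}
print(total, trainable)  # 20 20
```

Buffers such as running_mean are not counted here, because named_parameters() does not yield them.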