1. pytorch获取网络结构
在写深度学习程序时,我们通常要将网络打印出来查看网络结构,一个最简单的方法就是直接print(model)来打印模型结构,这里我们以下面程序为例
1 | import torch.nn as nn |
这里我们实例化了一个网络,然后直接print(net),得到的结果如下
1 | Net( |
1.1 named_modules方法
上述方法可以打印出网络结构,但是我们无法获取到每一层的更具体的信息,例如这里的fc1层实际名字叫做head.fc1,我们希望迭代的获得这些信息,因此就要用到named_modules方法
named_modules()
- 返回网络中所有模块的迭代器,迭代器生成每个模块的名字以及模块本身
下面通过例子来更加深入的理解这个方法,如下所示,继续上面的网络,我们将其每一部分输出,可以看到如下结果,以fc1层为例,方法返回head.fc1以及Linear(in_features=320, out_features=50, bias=True)
1 | for name, module in net.named_modules(): |
1.2 named_parameters方法
named_parameters()
- 返回一个迭代器,包含网络中所有的参数,迭代器生成每个参数的名字以及参数本身
由于这里的参数打印出来太占空间,因此我仅输出它们的尺寸
1 | for name, module in net.named_parameters(): |