保存与加载模型
首先给出PyTorch官网的两个教程:
==这里讲一种常用的方法==
保存&&加载
torch.save(x, path)
- x:要保存的信息
- path:保存的路径
注意这个x可以是一个简单的Tensor,也可以是我们的模型参数
torch.load(path)
- path:要加载的模型路径
此函数返回和之前保存的一模一样的x信息,即之前保存的x是什么,这个函数就返回什么
这里举两个例子方便理解,一个是Tensor
的例子,另一个是Model
的例子
(1)Tensor
1 | import torch |
此时当前文件目录下会出现tensor.pth
文件,也就是说我们用torch.save()
保存了变量x
,然后用torch.load()
加载赋值给y
输出
(2)Model
在训练模型的时候,我们往往需要保存模型的epoch
,model参数
以及optimizer
的信息,保存的代码如下
1 | torch.save({'epoch': epoch, |
重新加载模型的程序如下
1 | # 加载模型参数 |
注意上面的load_checkpoint函数,如果在训练时用了DataParallel
函数,那么最终参数会带有module
,此时就应该将其去掉
没有使用DataParallel
的参数形式

使用DataParallel
的参数形式,可以发现参数前带有module

我们在保存模型时都保存了些什么呢?下面程序展示了保存的模型和优化器的一些信息,从输出可以看出,我们传入torch.save()
中的就是模型中卷积等的weight
和bias
等信息。那么为什么使用DataParallel
之后加载参数需要去掉module
呢,这是因为我们真实的模型中是没有module
这个前缀的,是conv1.weight
或者conv1.bias
,而我们使用并行计算时,参数就会被归到module
下,就变为了module.conv1.weight
以及module.conv1.bias
,如果在load的时候不把前缀module.
去掉,模型就无法匹配参数,也就没法恢复了,所以在恢复参数的时候要注意索引是否一致
1 | import torch.nn as nn |