4.5 读取和存储

    我们可以直接使用函数和load函数分别存储和读取Tensorsave使用Python的pickle实用程序将对象进行序列化,然后将序列化的对象保存到disk,使用save可以保存各种对象,包括模型、张量和字典等。而load使用pickle unpickle工具将pickle的对象文件反序列化为内存。

    下面的例子创建了Tensor变量x,并将其存在文件名同为x.pt的文件里。

    然后我们将数据从存储的文件读回内存。

    1. x2 = torch.load('x.pt')
    2. x2

    输出:

    1. tensor([1., 1., 1.])

    我们还可以存储一个Tensor列表并读回内存。

    1. y = torch.zeros(4)
    2. torch.save([x, y], 'xy.pt')
    3. xy_list = torch.load('xy.pt')
    4. xy_list

    输出:

    1. [tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]

    输出:

      在PyTorch中,Module的可学习参数(即权重和偏差),模块模型包含在参数中(通过访问)。state_dict是一个从参数名称隐射到参数Tesnor的字典对象。

      1. class MLP(nn.Module):
      2. def __init__(self):
      3. super(MLP, self).__init__()
      4. self.hidden = nn.Linear(3, 2)
      5. self.act = nn.ReLU()
      6. self.output = nn.Linear(2, 1)
      7. def forward(self, x):
      8. a = self.act(self.hidden(x))
      9. return self.output(a)
      10. net = MLP()
      11. net.state_dict()

      输出:

      1. OrderedDict([('hidden.weight', tensor([[ 0.2448, 0.1856, -0.5678],
      2. [ 0.2030, -0.2073, -0.0104]])),
      3. ('output.weight', tensor([[-0.4556, 0.4084]])),
      4. ('output.bias', tensor([-0.3573]))])

      注意,只有具有可学习参数的层(卷积层、线性层等)才有中的条目。优化器(optim)也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。

      1. optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
      2. optimizer.state_dict()

      输出:

      4.5.2.2 保存和加载模型

      PyTorch中保存和加载训练模型有两种常见的方法:

      • 仅保存和加载模型参数(state_dict);
      • 保存和加载整个模型。

      1. 保存和加载state_dict(推荐方式)

      1. torch.save(model.state_dict(), PATH) # 推荐的文件后缀名是pt或pth

      加载:

      1. model = TheModelClass(*args, **kwargs)
      2. model.load_state_dict(torch.load(PATH))

      2. 保存和加载整个模型

      保存:

      1. torch.save(model, PATH)

      加载:

      1. model = torch.load(PATH)

      我们采用推荐的方法一来实验一下:

      输出:

      1. tensor([[1],
      2. [1]], dtype=torch.uint8)

      因为这netnet2都有同样的模型参数,那么对同一个输入X的计算结果将会是一样的。上面的输出也验证了这一点。

      • 通过save函数和load函数可以很方便地读写Tensor
      • 通过save函数和load_state_dict函数可以很方便地读写模型的参数。