torch.utils.checkpoint

    注意

    checkpointing的实现方法是在向后传播期间重新运行已被checkpint的前向传播段。 所以会导致像RNG这类(模型)的持久化的状态比实际更超前。默认情况下,checkpoint包含了使用RNG状态的逻辑(例如通过dropout),与non-checkpointed传递相比,checkpointed具有更确定的输出。RNG状态的存储逻辑可能会导致一定的性能损失。如果不需要确定的输出,设置全局标志(global flag) 忽略RNG状态在checkpoint时的存取。

    checkpoint模型或模型的一部分

    checkpoint通过计算换内存空间来工作。与向后传播中存储整个计算图的所有中间激活不同的是,checkpoint不会保存中间激活部分,而是在反向传递中重新计算它们。它被应用于模型的任何部分。

    具体来说,在正向传播中,function将以torch.no_grad()方式运行 ,即不存储中间激活,但保存输入元组和 function的参数。在向后传播中,保存的输入变量以及 会被取回,并且function在正向传播中被重新计算.现在跟踪中间激活,然后使用这些激活值来计算梯度。

    Checkpointing 在 torch.autograd.grad()中不起作用, 仅作用于 .

    警告

    如果function在向后执行和前向执行不同,例如,由于某个全局变量,checkpoint版本将会不同,并且无法被检测到。

    参数:

    • function - 描述在模型的正向传递或模型的一部分中运行的内容。它也应该知道如何处理作为元组传递的输入。例如,在LSTM中,如果用户通过 ,应正确使用第一个输入作为第二个输入(activation, hidden)functionactivationhidden

    用于checkpoint sequential模型的辅助函数

    checkpointing工作方式: checkpoint().

    警告

    Checkpointing无法作用于, 只作用于torch.autograd.backward().

    参数:

    • functions – 按顺序执行的模型, 一个 对象,或者一个由modules或functions组成的list。
    • inputs – 输入,Tensor组成的元组
    Returns:按顺序返回每个*inputs的结果

    例子