使用自定义损失函数时,如何在 PyTorch 中执行内存高效的反向传播?

How to Perform Memory-efficient Backpropagation in PyTorch When Using a Custom Loss Function?

提问人:utkutpcgl 提问时间:10/21/2023 更新时间:10/22/2023 访问量:23

问:

简介

我正在使用 PyTorch 进行一个大规模的深度学习项目,并在反向传播过程中遇到内存问题。我已经实现了一个自定义损失函数,我需要知道是否有一种更节省内存的方法来执行反向传播,而不会影响自定义损失计算。

代码

import torch
import torch.nn as nn

class CustomLoss(nn.Module):
    def forward(self, x, y):
        return torch.sum(x * y)

# My neural network
class Net(nn.Module):
    # ...

我尝试使用 PyTorch 的内置方法进行反向传播,但它们会消耗大量内存。我原以为可以优化自定义损失函数以提高内存利用率。

到底发生了什么?

在反向传播期间,内存消耗激增,导致我的脚本崩溃。

Python PyTorch 反向传播

评论


答:

0赞 Iskander14yo 10/21/2023 #1

考虑:

  1. 使更小。batch_size
  2. 使用(计算和内存之间的权衡,请参阅此处)创建检查点torch.utils.checkpoint.checkpoint)
  3. 尽可能使用就地操作。
0赞 Karl 10/22/2023 #2

backprop 期间的内存消耗是由于需要存储每个模型参数的 grad 参数,以及为优化器状态存储的参数。它与您的损失函数无关。

您可以通过使用较小的批大小和梯度累积训练来减少内存使用量。