Pytorch 自动混合精度 - 将整个代码块转换为 float32

Pytorch automatic mixed precision - cast whole code block to float32

提问人:The Guy with The Hat 提问时间:8/11/2023 最后编辑:The Guy with The Hat 更新时间:8/11/2023 访问量:228

问:

我有一个复杂的模型,我想以混合精度进行训练。为此,我使用 torch.amp 包。我可以使用 为整个模型启用 AMP。但是,模型训练不稳定,因此我想强制模型的某些区域浮动32。with torch.cuda.amp.autocast(enabled=enable_amp, dtype=torch.float16):

以下是我尝试或考虑过的:

据我所知,有两种官方认可的解决方案:禁用模块的 AMP 并在模块开始时转换所有输入张量,或者按照本答案中的说明使用。但是,这两者都有问题。第一种需要手动将每个输入张量转换为 float32。第二个需要将装饰器添加到转发函数中,因此我需要将其单独添加到每个模块中,或者创建一个包含其他模块的新容器模块。这两种解决方案都不适合我,因为我想为模型的许多不同部分测试启用和禁用 float16,因此我需要不断添加和删除代码来转换数十个张量和/或模块。custom_fwdcustom_fwd

我想要的是能够为整个代码块(如 )转换为 float32,但我不知道如何可靠地做到这一点。 不会将 float16 张量转换为 float32,它只会禁用将 float32 张量转换为 float16。 似乎可能有效,但这不是官方记录的用法,并且根据文档,我不相信它将来会可靠地工作。这个模型将继续被一群人使用/更新多年,所以如果 pytorch 更新更改了未记录的功能,我不想冒着将来中断它的风险。with [cast everything to float32]:with torch.cuda.amp.autocast(enabled=False):with torch.cuda.amp.autocast(dtype=torch.float32):

有谁知道一种方法可以可靠地将整个代码块中的所有内容转换为 float32?

Python PyTorch 强制转换 自动混合精度

评论


答: 暂无答案