阿里二面追问:FP16训练如何避免NaN?

最近遇到一个训练代码,混合精度使用 apex,多卡还是 torch ddp+mp.spawn 子进程启动的方式,性能受限于 python 的 gil 锁。
其实对于混合精度训练 pytorch 已经 merge 进了 amp,fsdp 也支持了 mixed precision policy,多卡训练有 torchrun 启动器,还支持多机分布式。
 pytorch 已经有很多新的 feature,为什么不去用呢?
1️⃣起源

17 年, 提出了混合精度训练:Mixed Precision Training。
用 fp16 去表示 fp32 计算,在训练中会有一些数值问题:
🌟精度下溢/上溢
🌟fp16 数值范围和分布不匹配,导致梯度归零
首先第一个问题,fp16 的表示范围小于 fp32,因此会产生 underflow、overflow
第二个问题,对于 activation gradient 的分布,很大一部分较小的值在 fp16 下是不可表示的,会发生下溢 underflow 被置为 0,导致反向传播中梯度就丢失了。
混合精度训练中的数值问题和模型量化中遇到的数值问题其实很类似,都是从高精度表示范围映射到低精度表示范围,在量化中是通过 calibration 校准进行不同精度范围的线性映射。
在混合精度训练中是引入了 loss scaling 梯度缩放,防止非常小的梯度值在使用 fp16 进行表示时 underflow 下溢变成 0。
在前向传播计算得到损失值 loss,开始反向传播 backward 之前,对 loss 进行缩放,乘以一个大于 1 的常数,称为缩放因子 S,例如 1024、4096 等。
然后用缩放后的 loss 再去进行 backward,又由于求导是基于链式法则,反向传播过程中所有的梯度值都会被进行同等缩放。
但放大的梯度导致后续所有依赖梯度大小进行计算的操作都会失真,权重更新时 w=w-lr*grad,导致权重更新量也会被放大 S 倍,相当于变相增大了学习率 lr,导致训练过程不稳定,且与 fp32 训练的行为不一致。
因此在 backward 之后,更新模型权重之前,包括所有需要依赖梯度大小进行计算的操作之前。
要在 fp32 精度下对权重梯度进行反缩放 unscaling,除以之前放大的缩放因子 S,unscaling 之后的权重梯度也可以应用梯度裁剪和权重衰减等依赖于梯度大小值的操作。
然后 optimizer 里面会 copy 一份 fp32 的主权重 master weights 进行参数更新,更新之后的 master weights 再 cast 到 fp16 同步给模型参数。
整个混合精度训练的流程可以表示为:
✴️model fp16 weights+fp16 activations
✴️fp16 forward
✴️fp16 activations
✴️fp32 loss scaling
✴️weight&activation fp16 backward
✴️fp32 grad unscaling
✴️fp32 grad clip&decay
✴️fp32 master weights update
✴️cast to fp16 model weights
2️⃣amp
18 年 nv 以 pytorch 三方扩展的形式推出了 apex,以支持混合精度,20 年 pytorch1.6 merge 进了 torch.cuda.amp,配合 autocast 实现混合精度训练
autocast 是混合精度的上下文管理器,在 context 里面会自动选择 op 对应的计算精度,主要基于白名单机制进行自动类型转换。
3️⃣fsdp
fsdp 是 21 年 pytorch1.11 的引入的新特性,核心思想来源于 deepspeed zero,在其上又扩展了对混合精度的支持。
与 amp 相比,fsdp 灵活度更高,可以通过 FSDPModule warp 设置不同的混合精度策略。
以及为了更高精度的数值结果,在 fp16 activation 计算和模型参数 all-gather 的基础上,可以使用 fp32 进行 gradient 的 reduce-scatter 和 optimizer 的参数更新。
QA:这个时候再思考下,数值计算在什么情况下会发生 nan,迁移到混合精度训练流程里面哪些地方可能会产生 nan,怎么检测 nan,怎么避免 nan,哪些是框架已经做的,哪些还需要自己处理的,问问自己是否有了答案?
📳对于想求职算法岗的同学,如果想参加高质量项目辅导,提升面试能力,欢迎后台联系。
全部评论

相关推荐

评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务