Mixed precision training
Why mixed precision
Benefits:
- Faster on modern GPUs that support half-precision (FP16, BFLOAT16) arithmetic.
- Check the A100 specs: FP32 peaks at 19.5 TFLOPS, while FP16 with Tensor Cores reaches 312 TFLOPS, roughly 16x faster.
- Saves roughly 50% of memory for tensors stored in half precision.
Challenges:
- FP16 has a narrower dynamic range and lower precision than FP32.
- Gradients can underflow or overflow, leading to training instability (a quick demonstration follows below).
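To make the range problem concrete, here is a small demonstration (plain torch tensor ops, runs on CPU) of how FP16 loses gradient-sized values to underflow and large values to overflow:

```python
import torch

# FP16 dynamic range: max normal value is 65504, smallest normal is ~6.1e-5.
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.float16).tiny)  # 6.103515625e-05

# A gradient-sized value survives in FP32 but underflows to 0 in FP16.
print(torch.tensor(1e-8).half())        # tensor(0., dtype=torch.float16)

# A large activation overflows to inf in FP16.
print(torch.tensor(70000.0).half())     # tensor(inf, dtype=torch.float16)
```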
Why not FP16 everywhere?
- Some operations, e.g. loss computation, reduction operations (sum, mean), and gradient accumulation (another reduction-like operation), require higher precision. Using FP16 everywhere can make training unstable (lots of Inf or NaN values); the snippet below shows how FP16 accumulation silently stalls.
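A toy illustration of why reductions are fragile in FP16: once a running sum reaches 2048, the spacing between representable FP16 values grows to 2, so adding 1.0 no longer changes the result. (Plain torch tensor ops, runs on CPU.)

```python
import torch

acc16 = torch.tensor(0.0, dtype=torch.float16)
acc32 = torch.tensor(0.0, dtype=torch.float32)
for _ in range(4096):
    acc16 += torch.tensor(1.0, dtype=torch.float16)  # stalls at 2048
    acc32 += 1.0                                      # accumulates correctly
print(acc16)  # tensor(2048., dtype=torch.float16)
print(acc32)  # tensor(4096.)
```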
When to use FP16 and FP32?
- FP32 master weights: a master copy of the model weights is kept in FP32.
- FP16 copies for computation: the forward pass uses an FP16 copy of the weights.
- FP16 is also used in the backward pass, except for gradient accumulation: FP16 gradients are cast to FP32 before being accumulated (see the toy sketch after this list).
- Operations that are kept in FP32:
  - Loss computation
  - Reduction operations
  - Normalization layers: batch norm, layer norm
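Below is a minimal toy sketch of the master-weights pattern described above (not PyTorch AMP itself; the parameter, data, and learning rate are made up for illustration). The FP32 master copy is the source of truth, each step computes with an FP16 copy, and the gradient is cast back to FP32 before the update:

```python
import torch

master_w = torch.randn(1024, dtype=torch.float32)  # FP32 master weights
x = torch.randn(1024)
lr = 1e-3

for step in range(3):
    w16 = master_w.half().requires_grad_(True)   # FP16 working copy for this step
    pred = x.half() * w16                        # forward pass in FP16
    loss = pred.float().pow(2).mean()            # loss (a reduction) in FP32
    loss.backward()                              # backward pass; w16.grad is FP16
    master_w -= lr * w16.grad.float()            # cast grad to FP32, update master weights
```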
autocast
torch.autocast in PyTorch decides, per operation, whether to run it in FP16 or keep it in FP32.
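A small example of autocast's per-op decisions, assuming a CUDA GPU is available (FP16 autocast is primarily a CUDA feature): a matmul is executed in FP16, while a loss such as mse_loss stays in FP32.

```python
import torch
import torch.nn.functional as F

a = torch.randn(8, 8, device="cuda")
b = torch.randn(8, 8, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    c = a @ b                                   # matmul: autocast runs it in FP16
    loss = F.mse_loss(c, torch.zeros_like(c))   # loss: autocast keeps it in FP32

print(c.dtype)     # torch.float16
print(loss.dtype)  # torch.float32
```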
GradScaler
- Loss Scaling: Scales up the loss value to prevent gradients in FP16 from underflowing to zero.
- Gradient Unscaling: After backpropagation, divides the gradients by the same factor to restore their original scale before the optimizer step.
- Dynamic Adjustment: Automatically adjusts the scaling factor to balance preventing underflow against avoiding overflow (the full loop below shows all three steps).
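Putting it together, here is the standard autocast + GradScaler training-loop pattern (the tiny model, synthetic data, and hyperparameters are placeholders; assumes a CUDA GPU):

```python
import torch
import torch.nn as nn

device = "cuda"
model = nn.Linear(128, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

for step in range(10):
    x = torch.randn(32, 128, device=device)
    y = torch.randint(0, 10, (32,), device=device)

    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logits = model(x)           # forward pass in FP16 where safe
        loss = loss_fn(logits, y)   # loss; autocast runs loss/reduction ops in FP32

    scaler.scale(loss).backward()   # scale the loss, backprop scaled gradients
    scaler.step(optimizer)          # unscales gradients; skips the step on inf/NaN
    scaler.update()                 # dynamically adjusts the scale factor
```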