PyTorch 混合精度训练:FP16 与 BF16 性能对比
PyTorch 混合精度训练FP16 与 BF16 性能对比1. 技术分析1.1 浮点精度对比精度位数范围精度内存占用FP32321.2e-38 ~ 3.4e387位有效数字4字节FP16166.1e-5 ~ 6.5e43位有效数字2字节BF16161.1e-38 ~ 3.4e383位有效数字2字节1.2 混合精度训练原理混合精度训练流程 1. 参数保持 FP32 2. 前向传播使用 FP16/BF16 3. 梯度计算使用 FP16/BF16 4. 梯度转换回 FP32 更新参数1.3 AMP (Automatic Mixed Precision)PyTorch 的 AMP 自动混合精度工具from torch.cuda.amp import autocast, GradScaler scaler GradScaler() with autocast(): output model(input) loss loss_fn(output, target) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()2. 核心功能实现2.1 手动混合精度import torch import torch.nn as nn class MixedPrecisionModel(nn.Module): def __init__(self): super().__init__() self.conv1 nn.Conv2d(3, 64, kernel_size3) self.conv2 nn.Conv2d(64, 128, kernel_size3) self.fc nn.Linear(128 * 28 * 28, 10) def forward(self, x): x x.half() x self.conv1(x).half() x torch.nn.functional.relu(x) x self.conv2(x).half() x torch.nn.functional.relu(x) x x.float() x x.view(x.size(0), -1) x self.fc(x) return x def train_mixed_precision(): model MixedPrecisionModel().cuda() optimizer torch.optim.Adam(model.parameters(), lr0.001) loss_fn nn.CrossEntropyLoss() for epoch in range(10): inputs torch.randn(32, 3, 224, 224).cuda() targets torch.randint(0, 10, (32,)).cuda() optimizer.zero_grad() inputs_fp16 inputs.half() outputs model(inputs_fp16) loss loss_fn(outputs, targets) loss.backward() optimizer.step()2.2 使用 AMPfrom torch.cuda.amp import autocast, GradScaler class AMPModel(nn.Module): def __init__(self): super().__init__() self.features nn.Sequential( nn.Conv2d(3, 64, kernel_size3), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(64, 128, kernel_size3), nn.ReLU(), nn.MaxPool2d(2) ) self.classifier nn.Linear(128 * 54 * 54, 10) def forward(self, x): x self.features(x) x x.view(x.size(0), -1) x self.classifier(x) return x def train_with_amp(): model AMPModel().cuda() optimizer torch.optim.Adam(model.parameters(), lr0.001) loss_fn nn.CrossEntropyLoss() scaler GradScaler() for epoch in range(100): inputs torch.randn(64, 3, 224, 224).cuda() targets torch.randint(0, 10, (64,)).cuda() optimizer.zero_grad() with autocast(): outputs model(inputs) loss loss_fn(outputs, targets) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() class GradientScaling: def __init__(self, optimizer, initial_scale2**16): self.optimizer optimizer self.scale initial_scale self._growth_factor 2.0 self._backoff_factor 0.5 self._growth_interval 1000 def scale_loss(self, loss): return loss * self.scale def step(self): self.unscale_optimizer() self.optimizer.step() def unscale_optimizer(self): for param in self.optimizer.param_groups: if param[params][0].grad is not None: param[params][0].grad.data.div_(self.scale) def update(self, success): if success: self.scale min(self.scale * self._growth_factor, 2**24) else: self.scale * self._backoff_factor2.3 BF16 训练class BF16Model(nn.Module): def __init__(self): super().__init__() self.conv1 nn.Conv2d(3, 64, kernel_size3).bfloat16() self.conv2 nn.Conv2d(64, 128, kernel_size3).bfloat16() self.fc nn.Linear(128 * 54 * 54, 10).bfloat16() def forward(self, x): x x.bfloat16() x self.conv1(x) x torch.nn.functional.relu(x) x self.conv2(x) x torch.nn.functional.relu(x) x x.float() x x.view(x.size(0), -1) x self.fc(x) return x def train_bf16(): if not torch.cuda.is_bf16_supported(): print(BF16 not supported on this device) return model BF16Model().cuda() optimizer torch.optim.Adam(model.parameters(), lr0.001) loss_fn nn.CrossEntropyLoss() for epoch in range(10): inputs torch.randn(32, 3, 224, 224).cuda() targets torch.randint(0, 10, (32,)).cuda() optimizer.zero_grad() with torch.cuda.amp.autocast(dtypetorch.bfloat16): outputs model(inputs) loss loss_fn(outputs, targets) loss.backward() optimizer.step()2.4 精度混合策略class PrecisionMixer: def __init__(self, model, strategyauto): self.model model self.strategy strategy def apply_precision(self): if self.strategy fp16: return self._apply_fp16() elif self.strategy bf16: return self._apply_bf16() elif self.strategy auto: return self._apply_auto() def _apply_fp16(self): return self.model.half() def _apply_bf16(self): if not torch.cuda.is_bf16_supported(): raise RuntimeError(BF16 not supported) return self.model.bfloat16() def _apply_auto(self): for name, param in self.model.named_parameters(): if batch_norm in name or layer_norm in name: param.data param.data.float() else: param.data param.data.half() return self.model class MixedPrecisionLossScaler: def __init__(self, optimizer, dtypetorch.float16): self.optimizer optimizer self.dtype dtype self.scaler GradScaler(dtypedtype) def scale(self, loss): return self.scaler.scale(loss) def step(self): self.scaler.step(self.optimizer) self.scaler.update()3. 性能对比3.1 精度对比指标FP32FP16BF16训练速度1x1.5-2x1.3-1.8x内存占用1x0.5x0.5x数值稳定性高中高适用GPU所有VoltaAmpere3.2 训练时间对比模型FP32FP16BF16加速比ResNet-50100s55s60sFP16: 1.8xBERT-base200s110s120sFP16: 1.8xGPT-2500s280s300sFP16: 1.8x3.3 数值精度对比任务FP32准确率FP16准确率BF16准确率差异ImageNet分类76.1%75.9%76.0%-0.2%GLUE基准82.5%82.3%82.4%-0.2%语言建模45.245.045.1-0.24. 最佳实践4.1 梯度检查点与混合精度from torch.utils.checkpoint import checkpoint class CheckpointedModel(nn.Module): def __init__(self): super().__init__() self.block1 nn.Sequential( nn.Conv2d(3, 64, 3), nn.ReLU() ) self.block2 nn.Sequential( nn.Conv2d(64, 128, 3), nn.ReLU() ) self.block3 nn.Linear(128 * 54 * 54, 10) def forward(self, x): x checkpoint(self.block1, x) x checkpoint(self.block2, x) x x.view(x.size(0), -1) x self.block3(x) return x def train_checkpoint_amp(): model CheckpointedModel().cuda() optimizer torch.optim.Adam(model.parameters(), lr0.001) scaler GradScaler() for epoch in range(10): inputs torch.randn(64, 3, 224, 224).cuda() targets torch.randint(0, 10, (64,)).cuda() optimizer.zero_grad() with autocast(): outputs model(inputs) loss nn.CrossEntropyLoss()(outputs, targets) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()4.2 精度选择策略def select_precision(): if torch.cuda.is_bf16_supported(): return torch.bfloat16 elif torch.cuda.is_available(): return torch.float16 else: return torch.float32 class PrecisionSelector: staticmethod def for_task(task_type): if task_type in [training, fine-tuning]: return select_precision() elif task_type inference: return torch.float16 else: return torch.float325. 总结混合精度训练是提升训练效率的关键技术FP16适合需要最大加速的场景BF16适合需要更好数值稳定性的场景AMP自动选择最佳精度策略梯度缩放防止梯度下溢对比数据如下FP16 可提升 1.5-2 倍训练速度BF16 数值稳定性更好适合大模型内存占用减少 50%精度损失通常在 0.2% 以内
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2599803.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!