# PyTorch in Practice: Building ResNet50 from Scratch (Training, Testing, and ONNX Conversion End to End)
## 1. ResNet50 Fundamentals

ResNet50 is a landmark model in computer vision. Its core innovation is the **residual connection** design. Imagine learning to ride a bicycle: if after every fall you remember that you rode two meters farther than last time, that way of steadily accumulating experience is the idea behind residual connections — the network keeps the features it has already learned and only learns the incremental difference.

Why choose the CIFAR10 dataset? This collection of 60,000 32x32 images is the multiplication table of machine learning:

- 10 classes (airplane, automobile, bird, and so on)
- 50,000 training images, 10,000 test images
- The images are small, but large enough to validate that a model architecture works

Implementing ResNet50 in PyTorch requires three key files:

- `resnet50.py` — model architecture definition
- `train.py` — training pipeline
- `test_pth.py` — model testing

> Tip: in industrial deployment it is common to validate the model structure on a small dataset first, then migrate to a large dataset such as ImageNet.

## 2. Building the Model

### 2.1 Implementing the Residual Block

ResNet50 is built from two kinds of basic blocks. A trimmed-down version:

```python
class BasicBlock(nn.Module):
    def __init__(self, in_channel, outs, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channel, outs[0], kernel_size=1, stride=stride)
        self.conv2 = nn.Conv2d(outs[0], outs[1], kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(outs[1], outs[2], kernel_size=1)
        # Shortcut connection, used when dimensions do not match
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channel != outs[2]:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, outs[2], kernel_size=1, stride=stride),
                nn.BatchNorm2d(outs[2])
            )

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        out = self.conv3(out)
        out += self.shortcut(x)  # residual connection
        return F.relu(out)
```

Key design points:

- **Bottleneck structure**: 1x1 convolutions first reduce and then restore the channel count (e.g. 256 → 64 → 256), cutting the computation cost
- **Identity mapping**: when input and output dimensions match, the shortcut passes the input through unchanged
- **Downsampling**: a stride-2 convolution halves the feature-map size

### 2.2 The Full Network

The layer structure of ResNet50 stacks like building blocks:

```python
class ResNet50(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four stages, each containing several residual blocks
        self.layer1 = self._make_layer(64, [64, 64, 256], 3, stride=1)
        self.layer2 = self._make_layer(256, [128, 128, 512], 4, stride=2)
        self.layer3 = self._make_layer(512, [256, 256, 1024], 6, stride=2)
        self.layer4 = self._make_layer(1024, [512, 512, 2048], 3, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(2048, 10)  # CIFAR10 has 10 classes
```
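The `ResNet50` class above calls `self._make_layer`, which the excerpt never defines. Below is a minimal self-contained sketch of how that helper and the `forward` pass could look, under the assumption that the first block of each stage handles the stride and width change while the remaining blocks keep dimensions (BatchNorm layers are added here, since the trimmed-down block omits them):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    def __init__(self, in_channel, outs, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channel, outs[0], kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(outs[0])
        self.conv2 = nn.Conv2d(outs[0], outs[1], kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(outs[1])
        self.conv3 = nn.Conv2d(outs[1], outs[2], kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(outs[2])
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channel != outs[2]:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, outs[2], kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(outs[2]),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out = out + self.shortcut(x)  # residual connection
        return F.relu(out)

class ResNet50(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, [64, 64, 256], 3, stride=1)
        self.layer2 = self._make_layer(256, [128, 128, 512], 4, stride=2)
        self.layer3 = self._make_layer(512, [256, 256, 1024], 6, stride=2)
        self.layer4 = self._make_layer(1024, [512, 512, 2048], 3, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(2048, num_classes)

    def _make_layer(self, in_channel, outs, num_blocks, stride):
        # First block may downsample / change width; the rest keep dimensions
        layers = [BasicBlock(in_channel, outs, stride=stride)]
        for _ in range(num_blocks - 1):
            layers.append(BasicBlock(outs[2], outs, stride=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.maxpool(out)
        out = self.layer4(self.layer3(self.layer2(self.layer1(out))))
        out = self.avgpool(out)
        return self.fc(torch.flatten(out, 1))

model = ResNet50()
print(model(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 10])
```

Thanks to the adaptive average pooling at the end, the network accepts any reasonably sized input, which is why a 64x64 tensor works here even though training uses 224x224.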
## 3. The Training Pipeline

### 3.1 Data Preprocessing

CIFAR10 images can be resized to 224x224 (the size used in the original paper):

```python
transform = transforms.Compose([
    transforms.Resize(224),  # upscaling small images loses information
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])  # ImageNet statistics
])
```

In real projects the following is preferable:

```python
# A more sensible way to handle small images
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010))  # CIFAR10 statistics
])
```

### 3.2 Key Training Parameters

```python
# Initialization
model = ResNet50().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(10):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f"Epoch: {epoch} [{batch_idx}/{len(train_loader)}] Loss: {loss.item():.4f}")
```

A learning-rate warmup schedule is recommended:

```python
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.01,
    steps_per_epoch=len(train_loader), epochs=10
)
```

## 4. Testing and Evaluating the Model

### 4.1 Measuring Accuracy

```python
def evaluate(model, test_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
    acc = 100. * correct / len(test_loader.dataset)
    print(f"Accuracy: {acc:.2f}%")
    return acc
```

### 4.2 Single-Image Prediction

```python
def predict(image_path):
    img = Image.open(image_path).convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        output = model(img_tensor)
        prob = F.softmax(output, dim=1)
    print(f"Predicted: {classes[output.argmax()]} "
          f"with {prob.max().item()*100:.2f}% confidence")
```
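The training loop in Section 3.2 iterates over a `train_loader` that is never constructed in the excerpt. The sketch below wires up the same loop structure with a synthetic `TensorDataset` standing in for CIFAR10 and a tiny linear model standing in for ResNet50, so it runs anywhere in seconds; the stand-in data and model are assumptions for illustration, but the loop body is the one shown above:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for the CIFAR10 loader (same (image, label) batch format)
images = torch.randn(64, 3, 32, 32)
labels = torch.randint(0, 10, (64,))
train_loader = DataLoader(TensorDataset(images, labels), batch_size=16, shuffle=True)

# Tiny stand-in model so the loop finishes quickly; the loop body is unchanged
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

model.train()
for data, target in train_loader:
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
print(f"final batch loss: {loss.item():.4f}")
```

Swapping in the real loader is a matter of replacing the `TensorDataset` with `torchvision.datasets.CIFAR10(..., transform=transform)` from Section 3.1.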
## 5. Model Conversion and Deployment

### 5.1 Converting PyTorch to ONNX

```python
dummy_input = torch.randn(1, 3, 224, 224).to(device)
torch.onnx.export(
    model, dummy_input, "resnet50.onnx",
    input_names=["input"], output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "output": {0: "batch_size"}
    }
)
```

### 5.2 Validating the ONNX Model

```python
import onnxruntime as ort

sess = ort.InferenceSession("resnet50.onnx")
input_name = sess.get_inputs()[0].name

# Preprocessing for the test data
def preprocess(image_path):
    img = Image.open(image_path)
    img = transform(img).numpy()
    return img[np.newaxis, ...]  # add the batch dimension

output = sess.run(None, {input_name: preprocess("test.jpg")})
print(np.argmax(output[0]))
```

Common issue: if you see `Exporting the operator ... to ONNX opset version 11 is not supported`, update your PyTorch version or specify a lower opset version (`opset_version=10`).

## 6. Performance Optimization Tips

Mixed-precision training:

```python
scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
    output = model(data)
    loss = criterion(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```

Gradient accumulation (for small-batch scenarios):

```python
accum_steps = 4
for batch_idx, (data, target) in enumerate(train_loader):
    with torch.cuda.amp.autocast():
        output = model(data)
        loss = criterion(output, target) / accum_steps
    scaler.scale(loss).backward()
    if (batch_idx + 1) % accum_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```

Model pruning (deployment optimization):

```python
from torch.nn.utils import prune

parameters_to_prune = [
    (module, "weight")
    for module in filter(lambda m: isinstance(m, nn.Conv2d), model.modules())
]
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2  # prune 20% of the weights
)
```
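After global unstructured pruning it is worth verifying the sparsity actually achieved, since `amount=0.2` is applied across all listed parameters jointly rather than per layer. A minimal sketch of that check, using a small stand-in convolutional net (an assumption for speed; the same check applies to the ResNet50 above):

```python
import torch
import torch.nn as nn
from torch.nn.utils import prune

# Small stand-in network; swap in the trained ResNet50 in practice
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
)

parameters_to_prune = [
    (m, "weight") for m in model.modules() if isinstance(m, nn.Conv2d)
]
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,  # zero the 20% of conv weights with the smallest |w|, globally
)

# Count zeroed weights across all pruned parameters
zeros = sum(int((m.weight == 0).sum()) for m, _ in parameters_to_prune)
total = sum(m.weight.numel() for m, _ in parameters_to_prune)
print(f"global sparsity: {zeros / total:.2%}")
```

Note that pruning only applies a mask; call `prune.remove(module, "weight")` on each module before export if you want the zeros baked into the weights.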