# PyTorch实战:从零构建ResNet50模型(CIFAR10训练+测试+ONNX转换)
## 1. ResNet50模型基础认知

第一次接触ResNet50时,我被它的残差连接设计惊艳到了。传统神经网络随着层数增加会出现梯度消失问题,而ResNet通过跨层直连通道,让信息能够无损传递到更深层。这就好比在高速公路上设置应急车道:即使主路拥堵,车辆仍能通过应急道快速通行。

ResNet50名称中的"50"代表50层深度,实际包含49个卷积层和1个全连接层,其核心结构是下图所示的残差块。每个残差块包含:1x1卷积(降维)、3x3卷积(特征提取)、1x1卷积(升维)、跨层连接(identity shortcut)。

```python
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.conv3 = nn.Conv2d(channels, channels * self.expansion, kernel_size=1)
        self.bn3 = nn.BatchNorm2d(channels * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != channels * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, channels * self.expansion, kernel_size=1, stride=stride),
                nn.BatchNorm2d(channels * self.expansion)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        return F.relu(out)
```

## 2. PyTorch环境搭建

推荐使用conda创建独立的Python环境,避免包版本冲突。这是我验证过的稳定版本组合:

```bash
conda create -n resnet python=3.8
conda activate resnet
pip install torch==1.12.1 torchvision==0.13.1
pip install numpy pandas matplotlib tqdm
```

关键工具说明:torch.nn(神经网络层实现)、torch.optim(优化算法,如SGD、Adam)、torchvision.transforms(图像预处理)、torch.utils.data(数据加载与批处理)。

验证GPU是否可用:

```python
import torch
print(torch.cuda.is_available())  # 输出True表示GPU可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```
## 3. CIFAR10数据处理实战

CIFAR10包含6万张32x32彩色图片(5万训练、1万测试),共10个类别。处理流程如下:

### 3.1 数据集加载与增强

```python
from torchvision import datasets, transforms

# 定义增强策略
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.Resize(224),  # ResNet原始输入尺寸
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 加载数据集
train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)
```

### 3.2 数据可视化检查

```python
import matplotlib.pyplot as plt

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def imshow(img):
    img = img * 0.5 + 0.5  # 反归一化
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# 获取随机批次
dataiter = iter(train_loader)
images, labels = next(dataiter)

# 显示图像
imshow(torchvision.utils.make_grid(images[:4]))
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))
```
## 4. 完整ResNet50实现

### 4.1 网络结构搭建

```python
class ResNet50(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # 残差块组
        self.layer1 = self._make_layer(64, 3, stride=1)
        self.layer2 = self._make_layer(128, 4, stride=2)
        self.layer3 = self._make_layer(256, 6, stride=2)
        self.layer4 = self._make_layer(512, 3, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)

    def _make_layer(self, channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(Bottleneck(self.in_channels, channels, stride))
            self.in_channels = channels * Bottleneck.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
```

### 4.2 模型初始化技巧

```python
def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

model = ResNet50().to(device)
initialize_weights(model)
```
## 5. 训练与评估流程

### 5.1 训练配置

```python
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
```

### 5.2 训练循环

```python
def train(epoch):
    model.train()
    running_loss = 0.0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 100 == 99:
            print(f'Epoch: {epoch}, Batch: {batch_idx + 1}, Loss: {running_loss / 100:.3f}')
            running_loss = 0.0

def test():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    acc = 100. * correct / total
    print(f'Test Accuracy: {acc:.2f}%')
    return acc

for epoch in range(100):
    train(epoch)
    test()
    scheduler.step()
```
## 6. 模型转换与部署

### 6.1 PyTorch转ONNX

```python
dummy_input = torch.randn(1, 3, 224, 224).to(device)
torch.onnx.export(model, dummy_input, 'resnet50.onnx',
                  input_names=['input'], output_names=['output'],
                  dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}})
```

### 6.2 ONNX模型验证

```python
import onnxruntime as ort

ort_session = ort.InferenceSession('resnet50.onnx')
outputs = ort_session.run(None, {'input': dummy_input.cpu().numpy()})

# 对比原始模型输出
with torch.no_grad():
    torch_output = model(dummy_input)
print('Output difference:', np.max(np.abs(outputs[0] - torch_output.cpu().numpy())))
```

### 6.3 ONNX模型推理示例

```python
def preprocess_image(image_path):
    image = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform(image).unsqueeze(0).numpy()

def predict_onnx(image_path):
    input_data = preprocess_image(image_path)
    outputs = ort_session.run(None, {'input': input_data})
    pred = np.argmax(outputs[0])
    return classes[pred]

print(predict_onnx('test_cat.jpg'))  # 输出预测类别
```

## 7. 性能优化技巧

混合精度训练:

```python
scaler = torch.cuda.amp.GradScaler()
for inputs, targets in train_loader:
    inputs, targets = inputs.to(device), targets.to(device)
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```

梯度累积(小批量场景):

```python
accumulation_steps = 4
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(train_loader):
    inputs, targets = inputs.to(device), targets.to(device)
    outputs = model(inputs)
    loss = criterion(outputs, targets) / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

学习率预热:

```python
from torch.optim.lr_scheduler import LambdaLR

warmup_epochs = 5
scheduler = LambdaLR(optimizer,
                     lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs
                     if epoch < warmup_epochs else 0.1 ** (epoch // 30))
```
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2454829.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!