PyTorch 2.8 实战案例:快速训练一个图像分类模型(附代码)
PyTorch 2.8 实战案例快速训练一个图像分类模型附代码1. 引言图像分类是计算机视觉领域最基础也最实用的任务之一。无论是识别猫狗照片、检测医学影像还是分析卫星图像都需要可靠的分类模型作为基础。本文将带你使用PyTorch 2.8快速搭建并训练一个图像分类模型整个过程只需要不到100行代码。通过本教程你将学会如何用PyTorch加载和处理图像数据如何构建一个简单的卷积神经网络(CNN)如何设置训练流程并监控模型表现如何保存和加载训练好的模型即使你是PyTorch新手也能在30分钟内完成整个流程。我们将使用CIFAR-10数据集它包含6万张32x32的彩色图片分为10个类别飞机、汽车、鸟等。2. 环境准备与数据加载2.1 安装PyTorch 2.8确保你已经安装了PyTorch 2.8环境。可以使用以下命令检查版本import torch print(torch.__version__) # 应该输出2.8.x如果没有安装可以通过pip快速安装pip install torch torchvision -i https://pypi.tuna.tsinghua.edu.cn/simple2.2 加载CIFAR-10数据集PyTorch的torchvision库提供了便捷的数据集加载功能import torchvision import torchvision.transforms as transforms # 定义数据预处理 transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) # 加载训练集和测试集 trainset torchvision.datasets.CIFAR10(root./data, trainTrue, downloadTrue, transformtransform) trainloader torch.utils.data.DataLoader(trainset, batch_size32, shuffleTrue, num_workers2) testset torchvision.datasets.CIFAR10(root./data, trainFalse, downloadTrue, transformtransform) testloader torch.utils.data.DataLoader(testset, batch_size32, shuffleFalse, num_workers2) classes (plane, car, bird, cat, deer, dog, frog, horse, ship, truck)这段代码会自动下载CIFAR-10数据集约170MB并将其分为训练集5万张和测试集1万张。我们使用了简单的数据归一化处理。3. 构建CNN模型3.1 定义网络结构我们将构建一个简单的CNN模型包含3个卷积层和2个全连接层import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): def __init__(self): super().__init__() self.conv1 nn.Conv2d(3, 32, 3, padding1) self.conv2 nn.Conv2d(32, 64, 3, padding1) self.conv3 nn.Conv2d(64, 128, 3, padding1) self.pool nn.MaxPool2d(2, 2) self.fc1 nn.Linear(128 * 4 * 4, 512) self.fc2 nn.Linear(512, 10) def forward(self, x): x self.pool(F.relu(self.conv1(x))) x self.pool(F.relu(self.conv2(x))) x self.pool(F.relu(self.conv3(x))) x torch.flatten(x, 1) # 展平所有维度除了batch x F.relu(self.fc1(x)) x self.fc2(x) return x net Net()这个网络结构虽然简单但已经足够在CIFAR-10上取得不错的效果。关键点包括使用ReLU激活函数增加非线性每层卷积后接最大池化层减少尺寸最后两个全连接层完成分类3.2 检查模型结构我们可以打印模型结构确保各层参数正确print(net)输出将显示各层的详细配置包括输入/输出通道数、核大小等。4. 训练模型4.1 设置损失函数和优化器我们使用交叉熵损失和Adam优化器import torch.optim as optim criterion nn.CrossEntropyLoss() optimizer optim.Adam(net.parameters(), lr0.001)Adam优化器通常比传统的SGD表现更好特别是在初始学习阶段。4.2 训练循环下面是完整的训练代码for epoch in range(10): # 训练10个epoch running_loss 0.0 for i, data in enumerate(trainloader, 0): # 获取输入数据 inputs, labels data # 梯度清零 optimizer.zero_grad() # 前向传播 反向传播 优化 outputs net(inputs) loss criterion(outputs, labels) loss.backward() optimizer.step() # 打印统计信息 running_loss loss.item() if i % 500 499: # 每500个batch打印一次 print(f[{epoch 1}, {i 1:5d}] loss: {running_loss / 500:.3f}) running_loss 0.0 print(Finished Training)这段代码会遍历整个训练集10次10个epoch每次处理32张图片batch_size32计算损失并反向传播更新权重定期打印损失值方便监控训练进度4.3 保存训练好的模型训练完成后我们可以保存模型权重供以后使用PATH ./cifar_net.pth torch.save(net.state_dict(), PATH)5. 测试模型性能5.1 加载测试集评估让我们看看模型在测试集上的表现correct 0 total 0 with torch.no_grad(): for data in testloader: images, labels data outputs net(images) _, predicted torch.max(outputs.data, 1) total labels.size(0) correct (predicted labels).sum().item() print(fAccuracy on 10,000 test images: {100 * correct / total:.2f}%)经过10个epoch的训练这个简单模型在测试集上的准确率应该能达到约75%。虽然不算很高但对于一个只有几层的网络来说已经不错了。5.2 查看各类别的准确率我们可以进一步分析模型在各个类别上的表现class_correct list(0. for i in range(10)) class_total list(0. for i in range(10)) with torch.no_grad(): for data in testloader: images, labels data outputs net(images) _, predicted torch.max(outputs, 1) c (predicted labels).squeeze() for i in range(len(labels)): label labels[i] class_correct[label] c[i].item() class_total[label] 1 for i in range(10): print(fAccuracy of {classes[i]:5s}: {100 * class_correct[i] / class_total[i]:.2f}%)你会发现某些类别如汽车、船舶的准确率较高而其他类别如猫、狗可能较低这是因为动物类别的图片通常有更多变化。6. 总结与改进建议通过本教程我们完成了一个完整的PyTorch图像分类项目从数据加载到模型训练和评估。虽然我们的简单模型已经能达到75%的准确率但还有很大的提升空间更深的网络结构尝试ResNet、EfficientNet等现代架构数据增强添加随机裁剪、翻转等增强方法学习率调度使用学习率衰减策略更长时间训练增加epoch数量正则化技术添加Dropout或权重衰减PyTorch 2.8提供了丰富的工具和优化使得深度学习模型的开发和训练变得更加高效。你可以基于这个基础项目继续探索更复杂的计算机视觉任务。获取更多AI镜像想探索更多AI镜像和应用场景访问 CSDN星图镜像广场提供丰富的预置镜像覆盖大模型推理、图像生成、视频生成、模型微调等多个领域支持一键部署。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2470396.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!