ResNet50实战:用Fruits-360数据集训练自己的水果分类模型(附完整代码)
ResNet50实战用Fruits-360数据集训练自己的水果分类模型附完整代码在计算机视觉领域图像分类是最基础也最实用的任务之一。无论是工业质检、医疗影像分析还是零售商品识别都需要可靠的分类模型作为支撑。而水果分类作为典型的细粒度图像识别问题对模型的特征提取能力提出了更高要求。本文将带您从零开始使用经典的ResNet50架构和Fruits-360数据集构建一个专业级的水果分类系统。1. 环境准备与数据探索1.1 开发环境配置首先需要搭建支持GPU加速的深度学习环境。推荐使用conda创建独立的Python环境conda create -n fruit_classifier python3.8 conda activate fruit_classifier pip install torch torchvision pillow pandas tqdm matplotlib对于硬件配置建议至少满足GPUNVIDIA GTX 1060 6GB或更高内存16GB以上存储空间至少20GB可用空间数据集解压后约14GB1.2 Fruits-360数据集解析Fruits-360是一个专业的水果蔬菜图像数据集包含131种不同类别的水果和蔬菜超过8.3万张高质量图像统一背景所有样本都在白色背景下拍摄多角度拍摄每个水果都有不同角度的照片数据集目录结构如下Fruits-360/ ├── Training/ │ ├── Apple Braeburn/ │ ├── Apple Crimson Snow/ │ └── ...其他类别 └── Test/ ├── Apple Braeburn/ ├── Apple Crimson Snow/ └── ...其他类别提示数据集可从Kaggle直接下载解压后确保Training和Test目录位于同一父目录下2. 数据预处理与增强2.1 自定义数据加载器使用PyTorch的Dataset类创建自定义数据加载器from torchvision import transforms from torch.utils.data import Dataset from PIL import Image import os class FruitsDataset(Dataset): def __init__(self, root_dir, transformNone, trainTrue): self.root_dir os.path.join(root_dir, Training if train else Test) self.transform transform self.classes sorted(os.listdir(self.root_dir)) self.class_to_idx {cls: i for i, cls in enumerate(self.classes)} self.images self._load_images() def _load_images(self): images [] for cls in self.classes: cls_dir os.path.join(self.root_dir, cls) for img_name in os.listdir(cls_dir): img_path os.path.join(cls_dir, img_name) images.append((img_path, self.class_to_idx[cls])) return images def __len__(self): return len(self.images) def __getitem__(self, idx): img_path, label self.images[idx] image Image.open(img_path).convert(RGB) if self.transform: image self.transform(image) return image, label2.2 数据增强策略针对水果分类任务设计以下增强策略train_transform transforms.Compose([ transforms.Resize(256), transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.RandomRotation(15), transforms.ColorJitter(brightness0.2, contrast0.2, saturation0.2), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) val_transform transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ])关键增强技术说明增强技术参数设置作用RandomResizedCrop224x224模拟不同拍摄距离ColorJitterbrightness0.2应对光照变化RandomRotation15度增强角度不变性3. ResNet50模型定制3.1 模型架构调整加载预训练ResNet50并修改最后一层import torch.nn as nn from torchvision import models class FruitResNet(nn.Module): def __init__(self, num_classes131): super(FruitResNet, self).__init__() self.base_model models.resnet50(pretrainedTrue) in_features self.base_model.fc.in_features self.base_model.fc nn.Linear(in_features, num_classes) def forward(self, x): return self.base_model(x)3.2 迁移学习策略采用分阶段训练方法冻结阶段只训练最后的全连接层微调阶段解冻所有层进行整体微调def set_parameter_requires_grad(model, feature_extracting): if feature_extracting: for param in model.parameters(): param.requires_grad False model FruitResNet(num_classes131) set_parameter_requires_grad(model, feature_extractingTrue) # 仅优化最后一层 optimizer torch.optim.Adam(model.fc.parameters(), lr0.001)4. 模型训练与评估4.1 训练循环实现完整的训练流程包含以下关键组件from tqdm import tqdm def train_model(model, dataloaders, criterion, optimizer, num_epochs25): best_acc 0.0 for epoch in range(num_epochs): print(fEpoch {epoch}/{num_epochs-1}) print(- * 10) for phase in [train, val]: if phase train: model.train() else: model.eval() running_loss 0.0 running_corrects 0 for inputs, labels in tqdm(dataloaders[phase]): inputs inputs.to(device) labels labels.to(device) optimizer.zero_grad() with torch.set_grad_enabled(phase train): outputs model(inputs) _, preds torch.max(outputs, 1) loss criterion(outputs, labels) if phase train: loss.backward() optimizer.step() running_loss loss.item() * inputs.size(0) running_corrects torch.sum(preds labels.data) epoch_loss running_loss / len(dataloaders[phase].dataset) epoch_acc running_corrects.double() / len(dataloaders[phase].dataset) print(f{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}) if phase val and epoch_acc best_acc: best_acc epoch_acc torch.save(model.state_dict(), best_model.pth) return model4.2 学习率调度策略采用余弦退火学习率调度from torch.optim.lr_scheduler import CosineAnnealingLR optimizer torch.optim.Adam(model.parameters(), lr0.001) scheduler CosineAnnealingLR(optimizer, T_max10, eta_min1e-5)典型训练过程参数配置参数值说明Batch Size32平衡内存和稳定性初始学习率0.001微调常用初始值Epochs50包含冻结和解冻阶段权重衰减1e-4防止过拟合5. 模型部署与优化5.1 模型量化与加速使用TorchScript导出优化后的模型model.eval() example_input torch.rand(1, 3, 224, 224).to(device) traced_script_module torch.jit.trace(model, example_input) traced_script_module.save(fruit_classifier.pt)5.2 构建推理API简单的Flask服务端实现from flask import Flask, request, jsonify from PIL import Image import io app Flask(__name__) model torch.jit.load(fruit_classifier.pt, map_locationcpu) model.eval() app.route(/predict, methods[POST]) def predict(): if file not in request.files: return jsonify({error: no file uploaded}), 400 file request.files[file].read() image Image.open(io.BytesIO(file)).convert(RGB) image val_transform(image).unsqueeze(0) with torch.no_grad(): output model(image) _, predicted torch.max(output, 1) return jsonify({class: classes[predicted.item()]}) if __name__ __main__: app.run(host0.0.0.0, port5000)5.3 性能优化技巧实际部署时可考虑以下优化TensorRT加速将模型转换为TensorRT引擎ONNX导出实现跨平台部署量化压缩8位整数量化减小模型体积缓存机制对高频访问类别实现结果缓存在NVIDIA T4 GPU上的性能对比优化方式推理延迟(ms)模型大小(MB)原始模型45.298.7TensorRT12.683.4INT8量化8.324.96. 常见问题与解决方案6.1 类别不平衡处理Fruits-360中各类别样本数量差异较大可采用以下策略from torch.utils.data import WeightedRandomSampler class_counts [len(os.listdir(fFruits-360/Training/{cls})) for cls in classes] class_weights 1. / torch.tensor(class_counts, dtypetorch.float) sample_weights class_weights[labels] sampler WeightedRandomSampler( weightssample_weights, num_sampleslen(sample_weights), replacementTrue )6.2 过拟合应对方案当验证集准确率停滞时可尝试增加正则化optimizer torch.optim.Adam(model.parameters(), lr0.001, weight_decay1e-4)早停机制patience 5 best_acc 0.0 epochs_no_improve 0 if val_acc best_acc: best_acc val_acc epochs_no_improve 0 else: epochs_no_improve 1 if epochs_no_improve patience: print(Early stopping!) break标签平滑criterion nn.CrossEntropyLoss(label_smoothing0.1)6.3 模型解释性分析使用Grad-CAM可视化模型关注区域from torchcam.methods import GradCAM cam_extractor GradCAM(model, base_model.layer4.2) with torch.no_grad(): out model(input_tensor) activation_map cam_extractor(out.squeeze(0).argmax().item(), out) # 叠加原始图像 result overlay_mask( to_pil_image(input_tensor.squeeze(0)), to_pil_image(activation_map[0].squeeze(0), modeF), alpha0.5 )在实际项目中我们发现模型对水果的纹理特征如苹果的条纹和形状轮廓如香蕉的弯曲度最为敏感。通过可视化分析可以验证模型是否学习了有意义的特征而非依赖背景等无关信息。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2461544.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!