深度学习之模型压缩三驾马车:基于ResNet18的模型剪枝实战(1)

news2025/6/6 14:56:39

一、背景:为什么需要模型剪枝?

随着深度学习的发展,模型参数量和计算量呈指数级增长。以ResNet18为例,其在ImageNet上的参数量约为1100万,虽然在服务器端运行流畅,但在移动端或嵌入式设备上部署时,内存和计算资源的限制使得直接使用大模型变得困难。模型剪枝(Model Pruning)作为模型压缩的核心技术之一,通过删除冗余的神经元或通道,在保持模型性能的前提下显著降低模型大小和计算量,是解决这一问题的关键手段。
在前面一篇文章我们也提到了模型压缩的一些基本定义和核心原理:《深度学习之模型压缩三驾马车:模型剪枝、模型量化、知识蒸馏》。

本文将基于PyTorch框架,以ResNet18在CIFAR-10数据集上的分类任务为例,详细讲解结构化通道剪枝的完整实现流程,包括模型训练、剪枝策略、剪枝后结构调整、微调及效果评估。

二、整体流程概览

本文代码的核心流程可总结为以下6步:

  1. 环境初始化与数据集加载
  2. 原始模型训练与评估
  3. 卷积层结构化剪枝(以conv1层为例)
  4. 剪枝后模型结构调整(BN层、残差下采样层等)
  5. 剪枝模型微调
  6. 剪枝前后模型效果对比
    特地说明:在这里选择conv1层作为例子,不是因为选择这个就会效果更好。

三、关键步骤代码解析

3.1 环境初始化与数据集准备

首先需要配置计算设备(GPU/CPU),并加载CIFAR-10数据集。CIFAR-10包含10类32x32的彩色图像,训练集5万张,测试集1万张。

def setup_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_dataset():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # 归一化到[-1,1]
    ])
    train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
    return train_dataset, test_dataset

3.2 原始模型训练

使用预训练的ResNet18模型,修改全连接层输出为10类(匹配CIFAR-10的类别数),并进行5轮训练:

def create_model(device):
    model = models.resnet18(pretrained=True)  # 加载ImageNet预训练权重
    model.fc = nn.Linear(512, 10)  # 修改输出层为10类
    return model.to(device)

def train_model(model, train_loader, criterion, optimizer, device, epochs=3):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")
    return model

3.3 结构化通道剪枝核心实现

本文重点是对卷积层进行结构化剪枝(按通道剪枝),具体步骤如下:

3.3.1 计算通道重要性

通过计算卷积核的L2范数评估通道重要性。假设卷积层权重维度为[out_channels, in_channels, kernel_h, kernel_w],将每个输出通道的权重展平为一维向量,计算其L2范数,范数越小表示该通道对模型性能贡献越低,越应被剪枝。

layer = dict(model.named_modules())[layer_name]  # 获取目标卷积层
weight = layer.weight.data
channel_norm = torch.norm(weight.view(weight.shape[0], -1), p=2, dim=1)  # 计算每个输出通道的L2范数
3.3.2 生成剪枝掩码

根据剪枝比例(如20%),选择范数最小的通道生成掩码:

num_channels = weight.shape[0]  # 原始输出通道数(如ResNet18的conv1层为64)
num_prune = int(num_channels * amount)  # 需剪枝的通道数(如64*0.2=12)
_, indices = torch.topk(channel_norm, k=num_prune, largest=False)  # 找到最不重要的12个通道

mask = torch.ones(num_channels, dtype=torch.bool)
mask[indices] = False  # 掩码:保留的通道标记为True(52个),剪枝的标记为False(12个)
3.3.3 替换卷积层

创建新的卷积层,仅保留掩码为True的通道:

new_conv = nn.Conv2d(
    in_channels=layer.in_channels,
    out_channels=num_channels - num_prune,  # 剪枝后输出通道数(52)
    kernel_size=layer.kernel_size,
    stride=layer.stride,
    padding=layer.padding,
    bias=layer.bias is not None
).to(device)  # 移动到模型所在设备

new_conv.weight.data = layer.weight.data[mask]  # 保留掩码为True的通道权重
if layer.bias is not None:
    new_conv.bias.data = layer.bias.data[mask]  # 偏置同理
3.3.4 关键:剪枝后结构调整

直接剪枝会导致后续层(如BN层、残差连接中的下采样层)的输入/输出通道不匹配,必须同步调整:

(1) 调整BN层
卷积层后通常接BN层,BN的num_features需与卷积输出通道数一致:

if 'conv1' in layer_name:
    bn1 = model.bn1
    new_bn1 = nn.BatchNorm2d(new_conv.out_channels).to(device)  # 新BN层通道数52
    with torch.no_grad():
        # 同步原始BN层的参数(仅保留未被剪枝的通道)
        new_bn1.weight.data = bn1.weight.data[mask].clone()
        new_bn1.bias.data = bn1.bias.data[mask].clone()
        new_bn1.running_mean.data = bn1.running_mean.data[mask].clone()
        new_bn1.running_var.data = bn1.running_var.data[mask].clone()
    model.bn1 = new_bn1

(2) 调整残差下采样层
ResNet的残差块(如layer1.0)中,若主路径的通道数被剪枝,需要通过1x1卷积的下采样层(downsample)匹配 shortcut 的通道数:

block = model.layer1[0]
if not hasattr(block, 'downsample') or block.downsample is None:
    # 原始无downsample,创建新的1x1卷积+BN
    downsample_conv = nn.Conv2d(
        in_channels=new_conv.out_channels,  # 52(剪枝后的conv1输出)
        out_channels=block.conv2.out_channels,  # 64(主路径conv2的输出)
        kernel_size=1,
        stride=1,
        bias=False
    ).to(device)
    torch.nn.init.kaiming_normal_(downsample_conv.weight, mode='fan_out', nonlinearity='relu')  # 初始化权重
    
    downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels).to(device)
    block.downsample = nn.Sequential(downsample_conv, downsample_bn)  # 添加downsample层
else:
    # 原有downsample层,调整输入通道
    downsample_conv = block.downsample[0]
    downsample_conv.in_channels = new_conv.out_channels  # 输入通道改为52
    downsample_conv.weight = nn.Parameter(downsample_conv.weight.data[:, mask, :, :].clone())  # 输入通道用掩码筛选

(3) 前向传播验证
调整后需验证模型能否正常前向传播,避免通道不匹配导致的错误:

with torch.no_grad():
    test_input = torch.randn(1, 3, 32, 32).to(device)  # 测试输入(B, C, H, W)
    try:
        model(test_input)
        print("✅ 前向传播验证通过")
    except Exception as e:
        print(f"❌ 验证失败: {str(e)}")
        raise

3.3的总结,直接上代码

def prune_conv_layer(model, layer_name, amount=0.2):
    # 获取模型当前所在设备
    device = next(model.parameters()).device  # 新增:获取设备
    
    layer = dict(model.named_modules())[layer_name]
    weight = layer.weight.data
    channel_norm = torch.norm(weight.view(weight.shape[0], -1), p=2, dim=1)
    
    num_channels = weight.shape[0]  # 原始通道数(如 64)
    num_prune = int(num_channels * amount)
    _, indices = torch.topk(channel_norm, k=num_prune, largest=False)
    
    mask = torch.ones(num_channels, dtype=torch.bool)
    mask[indices] = False  # 生成剪枝掩码(长度 64,52 个 True)
    
    new_conv = nn.Conv2d(
        in_channels=layer.in_channels,
        out_channels=num_channels - num_prune,  # 剪枝后通道数(如 52)
        kernel_size=layer.kernel_size,
        stride=layer.stride,
        padding=layer.padding,
        bias=layer.bias is not None
    )
    new_conv = new_conv.to(device)  # 新增:移动到模型所在设备
    
    new_conv.weight.data = layer.weight.data[mask]  # 保留 mask 为 True 的通道
    if layer.bias is not None:
        new_conv.bias.data = layer.bias.data[mask]
    
    # 替换原始卷积层
    parent_name, sep, name = layer_name.rpartition('.')
    parent = model.get_submodule(parent_name)
    setattr(parent, name, new_conv)

    if 'conv1' in layer_name:
        # 1. 更新与 conv1 直接关联的 BN1 层
        bn1 = model.bn1
        new_bn1 = nn.BatchNorm2d(new_conv.out_channels)  # 新 BN 层通道数 52
        new_bn1 = new_bn1.to(device)  # 新增:移动到模型所在设备
        with torch.no_grad():
            new_bn1.weight.data = bn1.weight.data[mask].clone()
            new_bn1.bias.data = bn1.bias.data[mask].clone()
            new_bn1.running_mean.data = bn1.running_mean.data[mask].clone()
            new_bn1.running_var.data = bn1.running_var.data[mask].clone()
        model.bn1 = new_bn1

        # 2. 处理残差连接中的 downsample(关键修正:添加缺失的 downsample)
        block = model.layer1[0]
        if not hasattr(block, 'downsample') or block.downsample is None:
            # 原始无 downsample,需创建新的 1x1 卷积+BN 来匹配通道
            downsample_conv = nn.Conv2d(
                in_channels=new_conv.out_channels,  # 52
                out_channels=block.conv2.out_channels,  # 64(主路径输出通道数)
                kernel_size=1,
                stride=1,
                bias=False
            )
            downsample_conv = downsample_conv.to(device)  # 新增:移动到模型所在设备
            # 初始化 1x1 卷积权重(这里简单复制原模型可能的统计量,实际可根据需求调整)
            torch.nn.init.kaiming_normal_(downsample_conv.weight, mode='fan_out', nonlinearity='relu')
            
            downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels)
            downsample_bn = downsample_bn.to(device)  # 新增:移动到模型所在设备
            with torch.no_grad():
                # 初始化 BN 参数(可保持默认,或根据原模型统计量调整)
                downsample_bn.weight.fill_(1.0)
                downsample_bn.bias.zero_()
                downsample_bn.running_mean.zero_()
                downsample_bn.running_var.fill_(1.0)
            
            block.downsample = nn.Sequential(downsample_conv, downsample_bn)
            print("✅ 为 layer1.0 添加新的 downsample 层")
        else:
            # 原有 downsample 层,调整输入通道
            downsample_conv = block.downsample[0]
            downsample_conv.in_channels = new_conv.out_channels  # 输入通道调整为 52
            downsample_conv.weight = nn.Parameter(downsample_conv.weight.data[:, mask, :, :].clone())  # 输入通道用 mask 筛选
            downsample_conv = downsample_conv.to(device)  # 新增:移动到模型所在设备
            
            downsample_bn = block.downsample[1]
            new_downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels)
            new_downsample_bn = new_downsample_bn.to(device)  # 新增:移动到模型所在设备
            with torch.no_grad():
                new_downsample_bn.weight.data = downsample_bn.weight.data.clone()
                new_downsample_bn.bias.data = downsample_bn.bias.data.clone()
                new_downsample_bn.running_mean.data = downsample_bn.running_mean.data.clone()
                new_downsample_bn.running_var.data = downsample_bn.running_var.data.clone()
            block.downsample[1] = new_downsample_bn

        # 3. 同步 layer1.0.conv1 的输入通道(保持原有逻辑)
        next_convs = ['layer1.0.conv1']
        for conv_path in next_convs:
            try:
                conv = model.get_submodule(conv_path)
                if conv.in_channels != new_conv.out_channels:
                    print(f"同步输入通道: {conv.in_channels}{new_conv.out_channels}")
                    conv.in_channels = new_conv.out_channels
                    conv.weight = nn.Parameter(conv.weight.data[:, mask, :, :].clone())
                    conv = conv.to(device)  # 新增:移动到模型所在设备
            except AttributeError as e:
                print(f"⚠️ 卷积层调整失败: {conv_path} ({str(e)})")

        # 验证前向传播
        with torch.no_grad():
            test_input = torch.randn(1, 3, 32, 32).to(device)  # 确保测试输入也在相同设备
            try:
                model(test_input)
                print("✅ 前向传播验证通过")
            except Exception as e:
                print(f"❌ 验证失败: {str(e)}")
                raise

    return model

3.4 剪枝模型微调

剪枝后模型的部分参数被删除,需要通过微调恢复性能。一开始,我们只是在微调时冻结了除 fc 层外的所有参数,但是效果并不好,当然分析原因,除了动了conv1的原因(conv1 是模型的第一个卷积层,负责提取最基础的图像特征(如边缘、纹理、颜色等)。这些底层特征对后续所有层的特征提取至关重要。),最重要的是裁剪后,需要对裁剪的层进行微调,确保参数适应新的特征维度。
微调时冻结了除 fc 层外的所有参数的代码和结果:

for name, param in pruned_model.named_parameters():
        if 'fc' not in name:
            param.requires_grad = False
    optimizer = optim.Adam(pruned_model.fc.parameters(), lr=0.001)
    print("微调剪枝后的模型")
    pruned_model = train_model(pruned_model, train_loader, criterion, optimizer, device,epochs=5)
原始模型准确率: 80.07%
剪枝后模型准确率: 37.80%

可以看到这个相差很大
本文选择解冻被剪枝的层(如conv1bn1)及相关层(如layer1.0.conv1downsample)进行参数更新:

print("开始微调剪枝后的模型")
for name, param in pruned_model.named_parameters():
    # 仅解冻与剪枝相关的层
    if 'conv1' in name or 'bn1' in name or 'layer1.0.conv1' in name or 'layer1.0.downsample' in name or 'fc' in name:
        param.requires_grad = True
    else:
        param.requires_grad = False
optimizer = optim.Adam(filter(lambda p: p.requires_grad, pruned_model.parameters()), lr=0.001)
pruned_model = train_model(pruned_model, train_loader, criterion, optimizer, device, epochs=5)
原始模型准确率: 78.94%
剪枝后模型准确率:  81.30%

重新微调了裁剪后的层后,结果有了很大改变。

四、实验结果与分析

通过代码中的evaluate_model函数评估剪枝前后的模型准确率:

def evaluate_model(model, device, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    acc = 100 * correct / total
    return acc

假设原始模型准确率为88.5%,剪枝20%通道后(模型大小降低约20%),通过微调可恢复至87.2%,验证了剪枝策略的有效性。

五、总结与改进方向

本文实现了基于通道L2范数的结构化剪枝,重点解决了剪枝后模型结构不一致的问题(如BN层、残差下采样层的调整),并通过微调恢复了模型性能。
在这个例子中,仅裁剪 conv1 层的影响
仅裁剪 conv1 层对模型的影响极大,原因如下:

  • 底层特征的重要性 : conv1 输出的是最基础的图像特征,所有后续层的特征均基于此生成。裁剪 conv1 会直接限制后续所有层的特征表达能力。
  • 结构连锁反应 : conv1 的输出通道减少会触发 bn1 、 layer1.0.conv1 、 downsample 等多个模块的调整,任何一个模块的调整失误(如通道数不匹配、参数初始化不当)都会导致整体性能下降。
    实际应用中可从以下方向改进:

模型裁剪通常优先选择 中间层(如ResNet的 layer2 、 layer3 ) ,而非底层或顶层,原因如下:

  • 底层(如 conv1 ) :负责基础特征提取,裁剪后特征损失大,对性能影响显著。
  • 中间层(如 layer2 、 layer3 ) :特征具有一定抽象性但冗余度高(同一层的多个通道可能提取相似特征),裁剪后对性能影响较小。
  • 顶层(如 fc 层) :负责分类决策,参数密度高但冗余度低,裁剪易导致分类能力下降。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2401783.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网络测试实战:金融数据传输的生死时速

阅读原文 7.4 网络测试实战--数据传输:当毫秒决定百万盈亏 你的交易指令为何总是慢人一步? 在2020年"原油宝"事件中,中行原油宝产品因为数据传输延迟导致客户未能及时平仓,最终亏损超过90亿元。这个血淋淋的案例揭示了…

数据库系统概论(十四)详细讲解SQL中空值的处理

数据库系统概论(十四)详细讲解SQL中空值的处理 前言一、什么是空值?二、空值是怎么产生的?1. 插入数据时主动留空2. 更新数据时设置为空3. 外连接查询时自然出现 三、如何判断空值?例子:查“漏填数据的学生…

【信创-k8s】海光/兆芯+银河麒麟V10离线部署k8s1.31.8+kubesphere4.1.3

❝ KubeSphere V4已经开源半年多,而且v4.1.3也已经出来了,修复了众多bug。介于V4优秀的LuBan架构,核心组件非常少,资源占用也显著降低,同时带来众多功能和便利性。我们决定与时俱进,使用1.30版本的Kubernet…

一台电脑联网如何共享另一台电脑?网线方式

前言 公司内网一个人只能申请一个账号和一个主机设备;会检测MAC地址;如果有两台设备,另一台就没有网;因为是联想老电脑,共享热点用不了,但是有一根网线,现在解决网线方式共享网络; …

MacroDroid安卓版:自动化操作,让生活更智能

在智能手机的日常使用中,我们常常会遇到一些重复性的任务,如定时开启或关闭Wi-Fi、自动回复消息、根据位置调整音量等。这些任务虽然简单,但频繁操作会让人感到繁琐。MacroDroid安卓版正是为了解决这些问题而设计的,它是一款功能强…

力提示(force prompting)的新方法

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

【Redis实战:缓存与消息队列的应用】

在现代互联网开发中,Redis 作为一款高性能的内存数据库,广泛应用于缓存和消息队列等场景。本文将深入探讨 Redis 在这两个领域的应用,并通过代码示例比较两个流行的框架(Redis 和 RabbitMQ)的特点与适用场景&#xff0…

实验设计与分析(第6版,Montgomery著,傅珏生译) 第10章拟合回归模型10.9节思考题10.12 R语言解题

本文是实验设计与分析&#xff08;第6版&#xff0c;Montgomery著&#xff0c;傅珏生译) 第10章拟合回归模型10.9节思考题10.12 R语言解题。主要涉及线性回归、回归的显著性、残差分析。 10-12 vial <- seq(1, 12, 1) Viscosity <- c(26,24,175,160,163,55,62,100,26,30…

告别局域网:实现NASCab云可云远程自由访问

文章目录 前言1. 检查NASCab本地端口2. Qindows安装Cpolar3. 配置NASCab远程地址4. 远程访问NASCab小结 5. 固定NASCab公网地址6. 固定地址访问NASCab 前言 在数字化生活日益普及的今天&#xff0c;拥有一个属于自己的私有云存储&#xff08;如NASCab云可云&#xff09;已成为…

Python实现markdown文件转word

1.markdown内容如下&#xff1a; 2.转换后的内容如下&#xff1a; 3.附上代码&#xff1a; import argparse import os from markdown import markdown from bs4 import BeautifulSoup from docx import Document from docx.shared import Inches from docx.enum.text import …

NLP学习路线图(十七):主题模型(LDA)

在浩瀚的文本海洋中航行&#xff0c;人类大脑天然具备发现主题的能力——翻阅几份报纸&#xff0c;我们迅速辨别出"政治"、"体育"、"科技"等板块&#xff1b;浏览社交媒体&#xff0c;我们下意识区分出美食分享、旅行见闻或科技测评。但机器如何…

综采工作面电控4X型铜头连接器 conm/4x100s

综采工作面作为现代化煤矿生产的核心区域&#xff0c;其设备运行的稳定性和安全性直接关系到整个矿井的生产效率。在综采工作面的电气控制系统中&#xff0c;电控连接器扮演着至关重要的角色&#xff0c;而4X型铜头连接器CONM/4X100S作为其中的关键部件&#xff0c;其性能优劣直…

用ApiFox MCP一键生成接口文档,做接口测试

日常开发过程中&#xff0c;尤其是针对长期维护的老旧项目&#xff0c;许多开发者都会遇到一系列相同的困扰&#xff1a;由于项目早期缺乏严格的开发规范和接口管理策略&#xff0c;导致接口文档缺失&#xff0c;甚至连基本的接口说明都难以找到。此外&#xff0c;由于缺乏规范…

在compose中的Canvas用kotlin显示多数据波形闪烁的问题

在compose中的Canvas显示多数据波形闪烁的问题&#xff1a;当在Canvas多组记录波形数组时&#xff0c;从第一组开始记录多次显示&#xff0c;如图&#xff0c;当再次回到第一次记录位置再显示时&#xff0c;波形出现闪烁。 原码如下&#xff1a; data class DcWaveForm(var b…

【学习笔记】MIME

文章目录 1. 引言2. MIME 构成Content-Type&#xff08;内容类型&#xff09;Content-Transfer-Encoding&#xff08;传输编码&#xff09;Multipart&#xff08;多部分&#xff09; 3. 常见 MIME 类型 1. 引言 早期的电子邮件只能发送 ASCII 文本&#xff0c;无法直接传输二进…

单北斗定位芯片AT9880B

AT9880B 是面向北斗卫星导航系统的单模接收机单芯片&#xff08;SOC&#xff09;&#xff0c;内部集成射频前端、数字基带处理单元、北斗多频信号处理引擎及电源管理模块&#xff0c;支持北斗二号与三号系统的 B1I、B1C、B2I、B3I、B2a、B2b 频点信号接收。 主要特征 支持北斗二…

旅游微信小程序制作指南

想创建旅游微信小程序吗&#xff1f;知道旅游业企业怎么打造自己的小程序吗&#xff1f;这里有零基础小白也能学会的教程&#xff0c;教你快速制作旅游类微信小程序&#xff01; 旅游行业能不能开发微信小程序呢&#xff1f;答案是肯定的。微信小程序对旅游企业来说可是个宝&am…

Ubuntu ifconfig 查不到ens33网卡

BUG&#xff1a;ifconfig查看网络配置信息&#xff1a; 终端输入以下命令&#xff1a; sudo service network-manager stop sudo rm /var/lib/NetworkManager/NetworkManager.state sudo service network-manager start - service network - manager stop &#xff1a;停止…

【python深度学习】Day 45 Tensorboard使用介绍

知识点&#xff1a; tensorboard的发展历史和原理tensorboard的常见操作tensorboard在cifar上的实战&#xff1a;MLP和CNN模型 效果展示如下&#xff0c;很适合拿去组会汇报撑页数&#xff1a; 作业&#xff1a;对resnet18在cifar10上采用微调策略下&#xff0c;用tensorboard监…

【图像处理入门】5. 形态学处理:腐蚀、膨胀与图像的形状雕琢

摘要 形态学处理是基于图像形状特征的处理技术,在图像分析中扮演着关键角色。本文将深入讲解腐蚀、膨胀、开闭运算等形态学操作的原理,结合OpenCV代码展示其在去除噪声、提取边缘、分割图像等场景的应用,带你掌握通过结构元素雕琢图像形状的核心技巧。 一、形态学处理:基…