别再让AI模型‘学新忘旧’了:手把手教你用PyTorch搞定Continual Learning的灾难性遗忘
别再让AI模型‘学新忘旧’了手把手教你用PyTorch搞定Continual Learning的灾难性遗忘当你的猫狗分类模型刚学会识别虹猫蓝兔中的虹猫却突然忘记了普通家猫的样子——这就是典型的灾难性遗忘现象。作为算法工程师我们需要的不是每次遇到新数据就重新训练的笨模型而是能像人类一样持续积累知识的智能系统。本文将用PyTorch带你实现三种应对策略从最基础的Replay Buffer到最新的梯度约束方法。1. 环境准备与数据模拟首先需要构建一个能模拟真实场景的非独立同分布(Non-IID)数据集。我们以CIFAR-10为例将其拆分为5个连续任务import torch from torchvision import datasets, transforms from torch.utils.data import DataLoader, Subset # 定义数据转换 transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) # 加载完整数据集 full_dataset datasets.CIFAR10(root./data, trainTrue, downloadTrue, transformtransform) # 按类别划分任务每个任务2个类别 task_classes [[0,1], [2,3], [4,5], [6,7], [8,9]] task_datasets [] for classes in task_classes: idx [i for i, (_, label) in enumerate(full_dataset) if label in classes] task_datasets.append(Subset(full_dataset, idx))这种任务划分方式模拟了现实场景中数据按批次到达的特点。每个任务只包含部分类别且后续任务不会重复之前见过的类别数据。2. 基准模型与灾难性遗忘验证我们先实现一个简单的CNN基准模型观察其在连续学习中的表现import torch.nn as nn import torch.optim as optim class SimpleCNN(nn.Module): def __init__(self, num_classes10): super(SimpleCNN, self).__init__() self.features nn.Sequential( nn.Conv2d(3, 32, kernel_size3, padding1), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(32, 64, kernel_size3, padding1), nn.ReLU(), nn.MaxPool2d(2) ) self.classifier nn.Linear(64 * 8 * 8, num_classes) def forward(self, x): x self.features(x) x x.view(x.size(0), -1) return self.classifier(x)训练过程中我们会发现模型在新任务上的准确率迅速提升但在旧任务上的表现急剧下降任务顺序任务1准确率任务2准确率任务3准确率训练前随机猜测随机猜测随机猜测任务1后89.2%--任务2后32.1%85.7%-任务3后18.5%41.2%82.3%这种性能断崖式下跌正是灾难性遗忘的典型表现。接下来我们将实现三种解决方案。3. 解决方案一经验回放(Replay Buffer)经验回放是最直观的解决方案——保存部分旧数据与新数据一起训练。以下是PyTorch实现from collections import deque import random class ReplayBuffer: def __init__(self, buffer_size): self.buffer_size buffer_size self.buffer deque(maxlenbuffer_size) def add(self, data): for sample in data: self.buffer.append(sample) def sample(self, batch_size): return random.sample(self.buffer, min(len(self.buffer), batch_size)) # 在训练循环中使用 buffer ReplayBuffer(1000) # 存储1000个样本 for task_id, task_data in enumerate(task_datasets): # 添加新任务数据到缓冲区 buffer.add(task_data) # 训练时混合新旧数据 optimizer optim.Adam(model.parameters()) for epoch in range(10): # 获取当前任务数据 current_loader DataLoader(task_data, batch_size32, shuffleTrue) # 从缓冲区采样旧数据 old_data buffer.sample(256) # 采样256个旧样本 old_loader DataLoader(old_data, batch_size32) # 混合训练 for (new_x, new_y), (old_x, old_y) in zip(current_loader, old_loader): optimizer.zero_grad() loss criterion(model(new_x), new_y) criterion(model(old_x), old_y) loss.backward() optimizer.step()这种方法虽然简单但存在两个主要问题存储旧数据可能违反数据隐私要求当旧任务很多时缓冲区可能无法保存足够代表性的样本4. 解决方案二弹性权重固化(EWC)EWC通过约束重要参数的更新来保护旧知识。以下是实现关键步骤def compute_fisher_matrix(model, dataset, num_samples1000): fisher {} for name, param in model.named_parameters(): fisher[name] torch.zeros_like(param.data) loader DataLoader(dataset, batch_size1, shuffleTrue) model.eval() for i, (x, y) in enumerate(loader): if i num_samples: break model.zero_grad() output model(x) loss criterion(output, y) loss.backward() for name, param in model.named_parameters(): fisher[name] param.grad.data ** 2 / num_samples return fisher # 在训练新任务时添加EWC约束 def ewc_loss(model, fisher, prev_params, lambda_500): loss 0 for name, param in model.named_parameters(): loss (fisher[name] * (param - prev_params[name]) ** 2).sum() return lambda_ * loss # 训练循环 prev_params {n: p.clone().detach() for n, p in model.named_parameters()} fisher compute_fisher_matrix(model, old_task_data) for epoch in range(10): for x, y in current_task_loader: optimizer.zero_grad() output model(x) loss criterion(output, y) ewc_loss(model, fisher, prev_params) loss.backward() optimizer.step()EWC的关键在于计算Fisher信息矩阵识别重要参数在新任务训练时惩罚重要参数的剧烈变化超参数λ控制约束强度通常500-10005. 解决方案三梯度投影约束(GPM)GPM是较新的方法通过约束梯度方向来避免遗忘class GPMLayer(nn.Module): def __init__(self, input_dim, output_dim): super().__init__() self.weight nn.Parameter(torch.randn(output_dim, input_dim)) self.bias nn.Parameter(torch.zeros(output_dim)) self.memory None # 用于存储重要梯度方向 def project_grad(self): if self.memory is not None: for direction in self.memory: # 计算当前梯度在重要方向上的分量 grad self.weight.grad.view(-1) proj (grad direction) * direction # 移除可能干扰旧知识的梯度分量 self.weight.grad.data - proj.view_as(self.weight) def forward(self, x): return nn.functional.linear(x, self.weight, self.bias) # 在训练过程中记录重要梯度方向 def record_important_directions(model, dataloader): model.train() for x, y in dataloader: model.zero_grad() output model(x) loss criterion(output, y) loss.backward() for name, module in model.named_modules(): if isinstance(module, GPMLayer): grad module.weight.grad.view(-1) if module.memory is None: module.memory [grad / grad.norm()] else: # 只保留与现有方向正交的新方向 new_dir grad.clone() for d in module.memory: new_dir - (new_dir d) * d if new_dir.norm() 0.5: # 阈值过滤 module.memory.append(new_dir / new_dir.norm())GPM的优势在于不需要存储原始数据自动识别并保护对旧任务重要的参数空间方向计算开销小于EWC6. 综合对比与实战建议三种方法的性能对比如下方法准确率保持计算开销内存需求实现难度Replay Buffer★★★★☆★★☆☆☆★★★★★★★☆☆☆EWC★★★☆☆★★★★☆★★☆☆☆★★★★☆GPM★★★★☆★★★☆☆★★★☆☆★★★★★在实际项目中我的经验是当数据隐私要求不高且有足够存储时优先使用Replay Buffer对计算资源有限的项目EWC是较好的折中方案当任务数量很多且关系复杂时GPM表现更优一个实用的技巧是在模型中加入一个小型验证集定期测试所有旧任务的表现。当发现某个旧任务准确率下降超过阈值时可以触发针对性的复习训练。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2563637.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!