EMA滑动平均训练方式

news2026/4/2 12:35:49

1. EMA 介绍

首先该类实现, 使用timm ==0.6.11 版本;

Exponential Moving Average (EMA) for models in PyTorch.
目的:它旨在维护模型状态字典的移动平均值,包括参数和缓冲区。该技术通常用于训练方案,其中权重的平滑版本对于最佳性能至关重要。

1.1 v1 版本


class ModelEma:
    """ Model Exponential Moving Average (DEPRECATED)

    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This version is deprecated, it does not work with scripted models. Will be removed eventually.

    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    A smoothed version of the weights is necessary for some training schemes to perform well.
    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
    RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
    smoothing of weights to match results. Pay attention to the decay constant you are using
    relative to your update count per epoch.

    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
    disable validation of the EMA weights. Validation will have to be done manually in a separate
    process, or after the training stops converging.

    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """
    def __init__(self, model, decay=0.9999, device='', resume=''):
        # make a copy of the model for accumulating moving average of weights
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if device:
            self.ema.to(device=device)
        self.ema_has_module = hasattr(self.ema, 'module')
        if resume:
            self._load_checkpoint(resume)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def _load_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        assert isinstance(checkpoint, dict)
        if 'state_dict_ema' in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict_ema'].items():
                # ema model may have been wrapped by DataParallel, and need module prefix
                if self.ema_has_module:
                    name = 'module.' + k if not k.startswith('module') else k
                else:
                    name = k
                new_state_dict[name] = v
            self.ema.load_state_dict(new_state_dict)
            _logger.info("Loaded state_dict_ema")
        else:
            _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")

    def update(self, model):
        # correct a mismatch in state dict keys
        needs_module = hasattr(model, 'module') and not self.ema_has_module
        with torch.no_grad():
            msd = model.state_dict()
            for k, ema_v in self.ema.state_dict().items():
                if needs_module:
                    k = 'module.' + k
                model_v = msd[k].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)


Methods:方法:

__init__:通过创建所提供模型的副本、设置衰减率和设备放置来初始化 EMA 模型。模型设置为评估模式,并且其梯度被禁用。

_load_checkpoint :加载 EMA 模型的检查点。它处理由 DataParallel 包装器引起的状态字典命名约定中的潜在差异。

update
通过计算原始模型参数和当前 EMA 参数的加权平均值来更新 EMA 参数。

Features:特征:

  1. 可以为模型及其 EMA 对应项指定不同的设备。
  2. 处理由于 DataParallel 包装器导致的状态字典键不匹配。
  3. 由于与脚本模型不兼容v1版本被弃用

1.2 v2 版本

import logging
from collections import OrderedDict
from copy import deepcopy

import torch
import torch.nn as nn

_logger = logging.getLogger(__name__)

class ModelEmaV2(nn.Module):
    """ Model Exponential Moving Average V2

    Keep a moving average of everything in the model state_dict (parameters and buffers).
    V2 of this module is simpler, it does not match params/buffers based on name but simply
    iterates in order. It works with torchscript (JIT of full model).

    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    A smoothed version of the weights is necessary for some training schemes to perform well.
    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
    RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
    smoothing of weights to match results. Pay attention to the decay constant you are using
    relative to your update count per epoch.

    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
    disable validation of the EMA weights. Validation will have to be done manually in a separate
    process, or after the training stops converging.

    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """
    def __init__(self, model, decay=0.9999, device=None):
        super(ModelEmaV2, self).__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model): # 使用衰减率更新 EMA 参数
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):  # 直接将 EMA 参数设置为与提供的模型参数相同。
        self._update(model, update_fn=lambda e, m: m)

EmaV2版本:与 ModelEma 类似,但实现更简单。它还维护模型状态字典的移动平均值,并设计为与 torchscript(完整模型的 JIT)配合使用。

Methods:方法:

__init__:与 ModelEma 类似,但添加了对 super() 的调用来初始化 nn.Module 基类。

_update :更新 EMA 参数的辅助函数,以自定义更新函数作为参数。

update :使用衰减率更新 EMA 参数。

set :直接将 EMA 参数设置为与提供的模型参数相同。

Features:特征:

  1. 比 ModelEma 更简单、更直接的实现。
  2. 与torchscipt兼容。
  3. 根据参数的顺序而不是名称来匹配参数。
  • v1 版本与 v2版本之间的差异
    Differences差异:
  1. 设计复杂性: ModelEmaV2 更简单、更直接,避免了按名称匹配参数。

  2. 兼容性: ModelEmaV2 与 torchscript 兼容,与 ModelEma 不同。

  3. .参数匹配: ModelEma 按名称匹配参数和缓冲区,而 ModelEmaV2 根据参数和顺序进行匹配。

  4. 版本控制和用例: ModelEma 已被弃用,并且对于较新的训练方案(尤其是需要脚本的训练方案)而言不太受欢迎。

  5. 这两个类本质上用于相同的目的,但采用不同的方法,使得 ModelEmaV2 更适合利用脚本的现代 PyTorch 工作流程。

2. 使用方法

与 ModelEma 相比,在训练过程中使用 ModelEmaV2 涉及的方法略有不同。以下是有关如何将 ModelEmaV2 合并到训练循环中的指南,以及有关衰减参数的作用和预训练权重的使用的说明。

要在训练过程中使用 ModelEma V2 ,您应该将其集成到现有的训练循环中。以下是有关如何执行此操作的分步指南:

由于v1版本被弃用, 所以这里介绍使用 V2 版本;

2.1 初始化ema 类

初始化:定义模型后,使用您的模型作为参数初始化 ModelEmaV2 。根据您的需求设置 decay 参数。

model = YourModel()  # Replace with your model
ema = ModelEmaV2(model, decay=0.9999)
  • 设备配置:如果使用 GPU 等特定设备,请确保您的模型和 EMA 模型都移至该设备。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
ema.module.to(device)
  • 训练循环:在训练循环中,在每个反向传播步骤后更新 EMA 模型。

这里需要注意到的是 ,需要在每个反向传播 更新之后,才回去更新EMA 模型;

for epoch in range(num_epochs):
    for batch in dataloader:
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        ema.update(model)
        
  • 验证:使用EMA模型进行验证。由于平均权重,通常更适合预测。

在获取EMA 更新的权重之后,
EMA 模型的参数权重, 真正使用他的地方是在 推理阶段, 即 training 之后的 evaluate 阶段;

ema.module.eval()  # Set EMA model to evaluation mode
with torch.no_grad():
    for batch in validation_dataloader:
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = ema.module(inputs)  # Use EMA model for predictions
        # Compute validation metrics

  • 检查点:保存常规模型和 EMA 模型的状态字典。
torch.save({
    'model_state_dict': model.state_dict(),
    'ema_state_dict': ema.module.state_dict(),
    # ... other states like optimizer, epoch, etc.
}, 'checkpoint.pth')

  • 恢复训练:要从检查点恢复,请加载两个状态字典。
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
ema.module.load_state_dict(checkpoint['ema_state_dict'])
# Load other states

2.2 decay 参数的影响

ModelEmaV2 中的衰减参数起着至关重要的作用:

它确定移动平均线中当前模型参数相对于历史参数的权重。

  • 较高的衰减值(接近 1)赋予历史参数更大的权重,从而导致 EMA 模型权重的更新更平滑且更慢。
  • 较低的衰减值使 EMA 模型的权重对模型参数的近期变化更加敏感。

衰减值的选择取决于您的训练动态和训练步骤总数。常见的做法是从高衰减开始,然后随着时间的推移逐渐减少。

  • decay 参数;
    较高的衰减值(接近 1):当衰减参数设置为接近 1 时,EMA 模型会为较旧的(历史)参数赋予更多权重,而为最近更新的参数赋予较少权重。这使得 EMA 权重随着时间的推移变得更加平滑和更加稳定。平均权重响应新数据的变化更慢,这有利于减少噪声更新的影响。

较低的衰减值(远离 1):较低的衰减值导致 EMA 模型更加重视最近的模型更新。这使得 EMA 权重不太平滑,因为它们对模型参数的最新变化更加敏感。虽然这可以使 EMA 权重对数据的新趋势更加敏感,但也使它们更容易受到噪音和突然变化的影响。

总而言之,较高的衰减参数(接近 1)通过赋予历史数据更多权重来提高 EMA 模型权重的平滑度,从而导致权重更稳定但响应性较差。相反,较低的衰减值会降低平滑度,使权重对最近的变化更加敏感,但会牺牲稳定性。适当衰减值的选择取决于训练过程的具体要求和数据的性质。

使用 ModelEmaV2 时,在初始化 ModelEmaV2 之前将预训练的权重加载到原始模型中可能会很有帮助,特别是当您正在进行微调或有特定的起点时。

2.3 预训练权重

使用预先训练的权重:

  • 使用 ModelEmaV2 时,在初始化 ModelEmaV2 之前将预训练的权重加载到原始模型中可能会很有帮助,特别是当您正在进行微调或有特定的起点时。

  • 然后,EMA 模型将从这些权重的平滑版本开始,这可以导致更快的收敛和可能更好的最终性能,特别是在微调场景中。

  • 但是,如果您从头开始训练,则使用没有预训练权重的模型初始化 ModelEmaV2 也可以。 EMA 模型将随着训练的进展进行调整。

  • 总之, ModelEmaV2 用于维持模型权重的更平滑、更稳定的版本,这对于实现最佳性能至关重要,特别是在训练的后期阶段或微调场景中。衰减参数是控制应用平滑程度的关键。使用 ModelEmaV2 时,预训练权重可能很有用,但它们并不是绝对必要的,特别是在从头开始训练的场景中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1199531.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构与算法 | 第四章:字符串

本文参考网课为 数据结构与算法 1 第四章字符串,主讲人 张铭 、王腾蛟 、赵海燕 、宋国杰 、邹磊 、黄群。 本文使用IDE为 Clion,开发环境 C14。 更新:2023 / 11 / 12 数据结构与算法 | 第四章:字符串 字符串概念字符串字符字符…

Unbuntu安装、测试和卸载gcc11

GCC 可用于编译 C、C,本文介绍如何 Ubuntu 上安装 gcc11、测试和卸载它。 1. 在Ubuntu 上安装 gcc11 添加工具链存储库 sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test在 Ubuntu 上安装 gcc11 sudo apt install -y gcc-11验证 gcc11 版本 gcc-11 --v…

如何提升管理组织能力?

组织能力能力属于管理能力中的一部分,所以也称之为管理组织能力,组织是将人和事物的组合,有效的梳理和导向结果的能力。每个人都有组织能力,只是能力和效率上存在较大的差异。 一人的组织能力从学生时代就能体现出来,…

springboot高校全流程考勤系统-计算机毕设 附源码 27637

Springboot高校全流程考勤系统 摘 要 本文针对高校考勤等问题,对其进行研究分析,然后开发设计出高校全流程考勤系统以解决问题。高校全流程考勤系统系统主要功能模块包括:考勤签到、课程信息、考勤情况、申请记录列表等,系统功能设…

财务报告是什么

财务报告是什么 财务报告是企业对外提供的反映企业某一特定日期的财务状况和某一会计期间的经营成果、现金流量等会计信息的文件。 根据财务报告的定义,财务报告具有以下几层含义:一是财务报告应当是对外报告,其服务对象主要是投资者、债权人…

yolo系列报错(持续补充ing)

文章目录 export GIT_PYTHON_REFRESHquiet解决 没有pt权重文件解决 python文件路径报错解决 读取文件列名报错解决 导入不同文件夹出错解决 megengine没有安装解决然后你发现它竟然还没有用 export GIT_PYTHON_REFRESHquiet 设置环境变量 GIT_PYTHON_REFRESH ,这个…

postgresql实现job的六种方法

简介 在postgresql数据库中并没有想oracle那样的job功能,要想实现job调度,就需要借助于第三方。本人更为推荐kettle,pgagent这样的图形化界面,对于开发更为友好 优势劣势Linux 定时任务(crontab) 简单易用…

【强化学习】18 —— SAC( Soft Actor-Critic)

文章目录 前言最大熵强化学习不同动作空间下的最大熵强化学习基于能量的模型软价值函数最大熵策略 Soft Q-learningSoft Q-IterationSoft Q-Learning近似采样与SVGD伪代码 Soft Actor-Critic伪代码代码实践连续动作空间离散动作空间 参考与推荐 前言 之前的章节提到过在线策略…

macOS Big Sur(macos11版本)

macOS Big Sur是苹果推出的最新操作系统,具有以下特点: 全新的设计风格:Big Sur采用了全新的设计语言,包括更加圆润的窗口和控件、更加鲜明的色彩和更加简洁的界面。这种设计风格使得操作系统更加美观和易用。强大的性能表现&…

国际阿里云:云服务器灾备方案!!!

保障企业业务稳定、IT系统功能正常、数据安全十分重要,可以同时保障数据备份与系统、应用容灾的灾备解决方案应势而生,且发展迅速。ECS可使用快照、镜像进行备份。 灾备设计 快照备份 阿里云ECS可使用快照进行系统盘、数据盘的备份。目前,阿…

Ubuntu取消sudo的输入密码

Ubuntu最近要安装软件,每次sudo都要输入一次密码,感觉很麻烦,于是想能不能设置为不输入密码,在网上找了一下解决办法。 主要参考这篇文章: Ubuntu取消sudo时输入密码 上面这篇文章使用的是vim,但是按照博…

Kotlin基础——接口和类

接口 使用 : 表示继承关系&#xff0c;只能继承一个类&#xff0c;但可以实现多个接口override修饰符表示重写可以有默认方法&#xff0c;若父类的默认方法冲突&#xff0c;则需要子类重写&#xff0c;使用super<XXX>.xxx()调用某一父类方法 interface Focusable {fun …

免费录屏软件哪个好用?免费录屏软件排行榜

对于您的团队&#xff0c;屏幕录像机可以用于多种原因——从为您的网站创建教程到记录反复出现的技术问题&#xff0c;再到向您的营销团队发送快速说明而不是电子邮件。 此外&#xff0c;我们不能忘记产品演示和培训视频&#xff0c;它们可供您团队中的许多部门使用&#xff0…

matlab模糊控制文件m代码实现和基础理论

1、内容简介 略 15-可以交流、咨询、答疑 通过m代码来实现生成模糊文件fis文件 2、内容说明 模糊文件m代码实现和基础理论 matlab模糊控制文件m代码实现和基础理论 模糊文件、m代码和模糊基础理论 3、仿真分析 略 4、参考论文 略 链接&#xff1a;https://pan.baidu.co…

web前端开发第3次Dreamweave课堂练习/html练习代码《网页设计语言基础练习案例》

目标图片&#xff1a; 文字素材&#xff1a; 网页设计语言基础练习案例 ——几个从语义上和文字相关的标签 * h标签&#xff08;h1~h6&#xff09;&#xff1a;用来定义网页的标题&#xff0c;成对出现。 * p标签&#xff1a;用来设置网页的段落&#xff0c;成对出现。 * b…

NZ系列工具NZ08:图表添加标签工具

我的教程一共九套及VBA汉英手册一部&#xff0c;分为初级、中级、高级三大部分。是对VBA的系统讲解&#xff0c;从简单的入门&#xff0c;到数据库&#xff0c;到字典&#xff0c;到高级的网抓及类的应用。大家在学习的过程中可能会存在困惑&#xff0c;这么多知识点该如何组织…

【车载开发系列】AutoSar中的CANTP

【车载开发系列】AutoSar中的CANTP 【车载开发系列】AutoSar中的CANTP 【车载开发系列】AutoSar中的CANTP一. CANTP相关术语二. CANTP相关概念1&#xff09;单帧&#xff1a;SF(Single Frame)2&#xff09;首帧&#xff1a;FF(First Frame)3&#xff09;连续帧CF(Consecutive F…

腾讯云3年轻量2核4G5M服务器756元,抓紧数量不多

腾讯云轻量应用服务器特价是有新用户限制的&#xff0c;所以阿腾云建议大家选择3年期轻量应用服务器&#xff0c;一劳永逸&#xff0c;免去续费困扰。腾讯云轻量应用服务器3年可以选择2核2G4M和2核4G5M带宽&#xff0c;3年轻量2核2G4M服务器540元&#xff0c;2核4G5M轻量应用服…

十五、信号量

1、概述 (1)前面介绍的队列(queue)可以用于传输数据&#xff1a;在任务之间、任务和中断之间。 (2)有些时候我们只需要传递状态&#xff0c;并不需要传递具体的信息&#xff0c;比如&#xff1a; 我的事做完了&#xff0c;通知一下你。卖包子了、卖包子了&#xff0c;做好了…

python 中用opencv开发虚拟键盘------可以只选择一个单词不会出现一下选择多个

一. 介绍 OpenCV是最流行的计算机视觉任务库&#xff0c;它是用于机器学习、图像处理等的跨平台开源库&#xff0c;用于开发实时计算机视觉应用程序。 CVzone 是一个计算机视觉包&#xff0c;它使用 OpenCV 和 Media Pipe 库作为其核心&#xff0c;使我们易于运行&#xff0c…