2.10.批量归一化

news2025/6/16 18:22:25

批量归一化

​ 损失出现在最后,所以后面的层训练比较快,而数据在最底部,则:

  • 底部的层训练较慢
  • 底部层一变化,所有都会跟着变化
  • 最后的层需要重新学习多次

​ 最后导致收敛变慢。

​ 或许我们可以通过固定输出和梯度的特定分布,即均值和方差在一定范围内,来进行优化,以提高数据和损失的稳定性

1.批量归一化

​ 固定小批量里面的均值和方差
μ B = 1 ∣ B ∣ ∑ i ∈ B x i σ B 2 = 1 ∣ B ∣ ∑ i ∈ B ( x i − μ B ) 2 + ϵ ( 加一个小数值避免为 0 ) \mu_B=\frac {1}{|B|}\sum_{i\in B} x_i\\ \sigma^2_B = \frac {1}{|B|}\sum_{i\in B}(x_i -\mu _B)^2 +\epsilon (加一个小数值避免为0) μB=B1iBxiσB2=B1iB(xiμB)2+ϵ(加一个小数值避免为0)
​ 再做额外的调整(可学习的参数 γ ( 方差 ) , β ( 均值 ) \gamma(方差),\beta (均值) γ(方差),β(均值))
x i + 1 = B N ( x i ) = γ x i − μ b σ B + β x_{i+1} = BN(x_i) = \gamma\frac{x_i -\mu_b}{\sigma _B}+\beta xi+1=BN(xi)=γσBxiμb+β

2.批量归一化层

​ 可学习的参数 γ , β \gamma,\beta γ,β

​ 作用在全连接和卷积层的输出上,激活函数前;或全连接层和卷积层输入上

​ 对全连接层,作用再特征维;对卷积层,作用在通道维

​ 批量归一化是线性变换

2.1 全连接层

通常,我们将批量规范化层置于全连接层中的仿射变换和激活函数之间。 设全连接层的输入为x,权重参数和偏置参数分别为𝑊和𝑏,激活函数为𝜙,批量规范化的运算符为BN。 那么,使用批量规范化的全连接层的输出的计算详情如下:
h = ϕ ( B N ( W x + b ) ) h=\phi (BN(Wx+b)) h=ϕ(BN(Wx+b))

2.2 卷积层

​ 对于卷积层,我们可以在卷积层之后和非线性激活函数之前应用批量规范化。

​ 当卷积有多个输出通道时,我们需要对这些通道的“每个”输出执行批量规范化,每个通道都有自己的拉伸(scale)和偏移(shift)参数,这两个参数都是标量。

​ 假设我们的小批量包含𝑚个样本,并且对于每个通道,卷积的输出具有高度𝑝和宽度𝑞。 那么对于卷积层,我们在每个输出通道的𝑚⋅𝑝⋅𝑞个元素上同时执行每个批量规范化。

​ 因此,在计算平均值和方差时,我们会收集所有空间位置的值,然后在给定通道内应用相同的均值和方差,以便在每个空间位置对值进行规范化。

批量归一化到底在做什么

​ 可能是通过在每个小批量里加入噪音来控制模型复杂度,因此没必要和丢弃法混合使用(也只是可能)

在这里插入图片描述

μ B 、 σ B \mu_B、\sigma_B μBσB是每一批次的均值和方差,其实每一批次都不一样,比较随机。

批量归一化可以加速收敛速度,但一般不改变模型精度。学习率就可以比较大了

3.代码实现

import torch
from torch import nn
from d2l import torch as d2l

'''从零实现'''


def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # moving_mean和moving_var 可近似认为是全局上的均值和方差,eps是方差小数值
    # momentum 用于更新均值和方差,通常是0.9或一个固定的数字
    # 通过is_grad_enabled来判断当前模式是训练模式还是预测模式
    if not torch.is_grad_enabled():
        # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
        # 预测模式一般只有一个样本,所以需要用全局的均值和方差
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0)  # 按行求均值 1*n的向量
            var = ((X - mean) ** 2).mean(dim=0)  # 按行求方差, 1*n的向量
        else:
            # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
            # 这里我们需要保持X的形状以便后面可以做广播运算
            mean = X.mean(dim=(0, 2, 3), keepdim=True)  # dim:0批量大小,1输入输出通道,2高,3宽
            # 则需要求出没一行的均值,最终是1 * n * 1 *1 形状
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)  # 同理
        # 训练模式下,用当前的均值和方差做标准化
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # 更新移动平均的均值和方差,动量更新,i影响i+1
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # 缩放和移位
    return Y, moving_mean.data, moving_var.data


class BatchNorm(nn.Module):
    # 批量归一化层
    # num_features:完全连接层的输出数量或卷积层的输出通道数。
    # num_dims:2表示完全连接层,4表示卷积层
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0
        self.gamma = nn.Parameter(torch.ones(shape))  # gamma不能初始化为0,不然一乘全是0没办法学习了
        self.beta = nn.Parameter(torch.zeros(shape))
        # 非模型参数的变量初始化为0和1
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)  # 初始化为0,1正态分布

    def forward(self, X):
        # 如果X不在内存上,将moving_mean和moving_var
        # 复制到X所在显存上
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # 保存更新过的moving_mean和moving_var
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9)
        return Y


'''应用在LeNet上'''
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(16 * 4 * 4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),
    nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),
    nn.Linear(84, 10))

'''简洁实现,使用nn.BatchNorm2d,只需要输入通道数作为参数'''
net2 = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),
    nn.Linear(84, 10))

lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
d2l.plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1955689.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

古文:李密《陈情表》

原文 臣密言:臣以险衅,夙遭闵凶。生孩六月,慈父见背;行年四岁,舅夺母志。祖母刘愍臣孤弱,躬亲抚养。臣少多疾病,九岁不行,零丁孤苦,至于成立。既无伯叔,终鲜…

说说你对redis的理解

数据结构 String:缓存对象、常规计数、分布式锁、共享session信息 hash:(包含键值对的无序散列表) list:消息队列 set:聚合计算、点赞、公共关注、抽奖活动 zset:(格式key、val…

【Streamlit学习笔记】Streamlit-ECharts热力图tooltip提示信息拓展

Streamlit-ECharts Streamlit-ECharts是一个Streamlit组件,用于在Python应用程序中展示ECharts图表。ECharts是一个由百度开发的JavaScript数据可视化库Apache ECharts 安装模块库 pip install streamlitpip install streamlit-echarts绘制热力图展示 在基础热力…

【强化学习的数学原理】课程笔记--5(值函数近似,策略梯度方法)

目录 值函数近似一个例子TD 算法的值函数近似形式Sarsa, Q-learning 的值函数近似形式Deep Q-learningexperience replay 策略梯度方法(Policy Gradient)Policy Gradient 的目标函数目标函数 1目标函数 2两种目标函数的同一性 Policy Gradient 目标函数的…

18967 六一儿童节

这个问题可以使用贪心算法来解决。我们可以先将孩子们的需求和巧���力的重量都进行排序,然后从最大的需求开始,找到能满足这个需求的最大的巧克力,将其分给这个孩子。然后继续处理下一个需求,直…

基于微信小程序+SpringBoot+Vue的自助点餐系统(带1w+文档)

基于微信小程序SpringBootVue的自助点餐系统(带1w文档) 基于微信小程序SpringBootVue的自助点餐系统(带1w文档) 基于微信小程序的自助点餐系统前后台分离,让商品订单,用户反馈信息,商品信息等相关信息集中在后台让管理员管理,让用…

【进程间通信机制】管道和 FIFO、信号、消息队列、信号量、共享内存、套接字(Socket)

进程详细剖析,移步:https://blog.csdn.net/Thmos_vader/article/details/140750535 进程间通信 前文介绍:如何通过 fork()或 vfork()创建子进程,以及在子进程中通过 exec()函数执行一个新的程序; 谓进程间通信指的是…

考题相似度 AI 分析 API 数据接口

考题相似度 AI 分析 API 数据接口 基于 AI 的相似度评估,专有 AI 模型,包含评估详情 。 1. 产品功能 基于自有专业模型进行 AI 智能分析;提供详细的相似度评分和结果描述;高效的模型分析性能;全接口支持 HTTPS&#…

乐鑫ESP32-H2设备联网芯片,集成多种安全功能方案,启明云端乐鑫代理商

在数字化浪潮的推动下,物联网正以前所未有的速度融入我们的日常生活。然而,随着设备的激增,安全问题也日益成为公众关注的焦点。 乐鑫ESP32-H2致力于为所有开发者提供高性价比的安全解决方案,这款芯片经过专门设计以集成多种安全…

【时时三省】unity test 测试框架 下载

目录 1,unity test 测试框架介绍 2,源码下载 3,目录架构 4,git for window 下载安装方法: 1,unity test 测试框架介绍 Unity是一个用于C语言的轻量级单元测试框架。它由Throw The Switch团队开发&#…

工作很难受,还要不要继续留在职场上?

先说结论:我非常赞同大家离开职场 虽然小编现实的工作是有关于人力资源的,高级点叫做猎头,低俗点讲就叫“人贩子” 原因可能和其他人不太一样,大家自行理解 1.现在的社会资源太少,“蛋糕”太小 大家要明白最重要的…

TVL 破 3 亿美元的 Pencils Protocol,缘何具备持续盈利的能力?

Pencils Protocol 是行业内首个 DeFi Auction 的一站式聚合收益平台,其不仅支持 LaucnhPad、Staking、杠杆挖矿等系列功能,并有望成为 Scroll 生态重要的流动性枢纽,其目前正在基于该体系为 LRT 赛道赋能,目前在质押端不仅支持 ST…

【公式】因果卷积神经网络公式与应用解析

因果卷积神经网络公式与应用解析 因果卷积神经网络的核心作用 因果卷积神经网络(Temporal Convolutional Network, TCN)是一种专为时间序列预测而设计的网络结构。它通过因果卷积层,能够有效地处理时间序列数据,捕捉时间序列中的…

mediasoup simulcast实现说明

一. 前言 二. 空间可伸缩与时间可伸缩 三. mediasoup simulcast实现代码分析 1. 推流客户端开启 simulcast 2. mediasoup服务端接收simulcast流 3. mediasoup服务端转发流数据给消费者 a. SimulcastConsumer类声明 b. 获取预估码率,切换SimulcastConsumer的目…

大脑自组织神经网络通俗讲解

大脑自组织神经网络的核心概念 大脑自组织神经网络,是指大脑中的神经元通过自组织的方式形成复杂的网络结构,从而实现信息的处理和存储。这一过程涉及到神经元的生长、连接和重塑,是大脑学习和记忆的基础。其核心公式涉及神经网络的权重更新…

优化算法:2.粒子群算法(PSO)及Python实现

一、定义 粒子群算法(Particle Swarm Optimization,PSO)是一种模拟鸟群觅食行为的优化算法。想象一群鸟在寻找食物,每只鸟都在尝试找到食物最多的位置。它们通过互相交流信息,逐渐向食物最多的地方聚集。PSO就是基于这…

探索HTTPx:Python中的HTTP客户端新选择

文章目录 探索HTTPx:Python中的HTTP客户端新选择背景什么是HTTPx?安装HTTPx简单的库函数使用方法发送GET请求发送POST请求设置超时使用代理处理Cookies 应用场景异步请求连接池管理重试机制 常见问题与解决方案问题1:超时错误问题2&#xff1…

ROS getting started

文章目录 前言一、认识ROS提供的命令行工具nodestopicsservicesparametersactionsrqt_console, rqt_graph批量启动多个节点recorde and playc基础pub-sub 1.5 ROS2和fastdds1 改变订阅模式2 xml配置3 指定xml位置4 talker/listener通过发现服务器发送topic5 ros2 检视6 远程fas…

Natutre Methods|单细胞+空间转录,值得去复现的开源单细胞分析pipeline

肺癌是全球第二大最常见的癌症,也是癌症相关死亡的主要原因。肿瘤生态系统具有多种免疫细胞类型。尤其是髓系细胞,髓系细胞普遍存在,并且在促进疾病方面发挥着众所周知的作用。该篇通过单细胞和空间转录组学分析了 25 名未经治疗的腺癌和鳞状…