Pytorch平均池化nn.AvgPool2d()使用记录

news2025/7/19 12:38:43

【pytorch官方文档】:https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html?highlight=avgpool2d#torch.nn.AvgPool2d

torch.nn.AvgPool2d()

作用

在由多通道组成的输入特征中进行2D平均池化计算

函数

torch.nn.AvgPool2d(kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None)

参数

Args:
    kernel_size: 滑窗(池化核)大小
    stride: 滑窗的移动步长, 默认值为kernel_size
    padding: 在输入信号两侧的隐式零填充数量
    ceil_mode: 决定计算输出的形状时是向上取整还是向下取整, 默认为False(向下取整)
    count_include_pad: 在平均池化计算中是否包含零填充, 默认为True(包含零填充)
    divisor_override: 如果指定了, 它将被作为平均池化计算中的除数, 否则将使用池化区域的大小作为平均池化计算的除数

公式

代码实例

假设输入特征为S,输出特征为D

情况一

ceil_mode=False, count_include_pad=True(计算时包含零填充)

import torch
import torch.nn as nn
import numpy as np


# 生成一个形状为1*1*3*3的张量
x1 = np.array([
              [1,2,3],
              [4,5,6],
              [7,8,9]
            ])
x1 = torch.from_numpy(x1).float()
x1 = x1.unsqueeze(0).unsqueeze(0)

# 实例化二维平均池化
avgpool1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False, count_include_pad=True)
y1 = avgpool1(x1)
print(y1)

# 打印结果
'''
tensor([[[[1.3333, 1.7778],
          [2.6667, 3.1111]]]])
'''

计算过程:

输出形状= floor[(3 - 3 + 2) / 2] + 1 = 2,

D[1,1] = (0+0+0+0+1+2+0+4+5) / 9 = 1.3333,

D[1,2] = (0+0+0+2+3+0+5+6+0) / 9 = 1.7778,

D[2,1] = (0+4+5+0+7+8+0+0+0) / 9 = 2.6667,

D[2,2] = (5+6+0+8+9+0+0+0+0) / 9 = 3.1111.

情况二

ceil_mode=False, count_include_pad=False(计算时不包含零填充)

avgpool2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False, count_include_pad=False)

y2 = avgpool2(x1)
print(y2)

# 打印结果
'''
tensor([[[[3., 4.],
          [6., 7.]]]])
'''

计算过程:

输出形状= floor[(3 - 3 + 2) / 2] + 1 = 2,

D[1,1] = (1+2+4+5) / 4 = 3,

D[1,2] = (2+3+5+6) / 4 = 4,

D[2,1] = (4+5+7+8) / 4 = 6,

D[2,2] = (5+6+8+9) / 4 = 7.

情况三

ceil_mode=False, count_include_pad=False, divisor_override=2(将计算平均池化时的除数指定为2)

avgpool3 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False, count_include_pad=False, divisor_override=2)

y3 = avgpool3(x1)
print(y3)

# 打印结果
'''
tensor([[[[ 6.,  8.],
          [12., 14.]]]])
'''

计算过程:

输出形状= floor[(3 - 3 + 2) / 2] + 1 = 2,

D[1,1] = (1+2+4+5) / 2 = 6,

D[1,2] = (2+3+5+6) / 2 = 8,

D[2,1] = (4+5+7+8) / 2 = 12,

D[2,2] = (5+6+8+9) / 2 = 14.

情况四

ceil_mode=True, count_include_pad=True, divisor_override=None(在计算输出的形状时向上取整)

x2 = np.array([
              [1,2,3,4],
              [5,6,7,8],
              [9,10,11,12],
              [13,14,15,16]
              ])
x2 = torch.from_numpy(x2).reshape(1,1,4,4).float()
avgpool4 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)
y4 = avgpool4(x2)
print(y4)

# 打印结果
'''
tensor([[[[ 1.5556,  3.3333,  2.0000],
          [ 6.3333, 11.0000,  6.0000],
          [ 4.5000,  7.5000,  4.0000]]]])
'''

计算过程:

输出形状 = ceil[(4 - 3 + 2) / 2] + 1 = 3,

D[1,1] = (0+0+0+0+1+2+0+5+6) / 9 = 1.5556,

D[1,2] = (0+0+0+2+3+4+6+7+8) / 9 = 3.3333,

D[1,3] = (0+0+4+0+8+0) / 6 = 2,

D[2,1] = (0+5+6+0+9+10+0+13+14) / 9 = 6.3333,

D[2,2] = (6+7+8+10+11+12+14+15+16) / 9 = 11,

D[2,3] = (8+0+12+0+16+0) / 6 = 6,

D[3,1] = (0+13+14+0+0+0) / 6 = 4.5,

D[3,2] = (14+15+16+0+0+0) / 6 = 7.5,

D[3,3] = (16+0+0+0) / 4 = 4.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/362167.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

在 ubuntu 中切换使用不同版本的 python

引言有时我们不得不在同一台 ubuntu 中使用不同版本的 python 环境。本文的介绍就是可以在 ubuntu 上同时安装几个不同版本的 python,然后你可以随时指定当前要使用的 python 版本。步骤检查当前的 python 版本$ python3 --version python 3.6.8我的版本是 3.6.8假设…

Renegade:基于MPC+Bulletproofs构建的anonymous DEX

1. 引言 白皮书见: Renegade Whitepaper: Protocol Specification, v0.6 开源代码见: https://github.com/renegade-fi/renegade(Renegade p2p网络每个节点的核心网络和密码逻辑)https://github.com/renegade-fi/mpc-bulletpr…

OSPF(开放式最短路径优先协议)、ACL(访问控制列表)、NAT

一、OSPF -- (开放式最短路径优先协议) 基于组播更新 --- 224.0.0.5 224.0.0.6 1、协议类型:无类别链路状态的IGP协议 无类别:带精确掩码链路状态:不共享路由,共享拓扑(共享LSA)…

Windows平台Python编程必会模块之pywin32

在Windows平台上,从原来使用C/C编写原生EXE程序,到使用Python编写一些常用脚本程序,成熟的模块的使用使得编程效率大大提高了。 不过,python模块虽多,也不可能满足开发者的所有需求。而且,模块为了便于使用…

产品未出 百度朋友圈“开演”

ChatGPT这股AI龙卷风刮到国内时,人们齐刷刷望向百度,这家在国内对AI投入最高的公司最终出手了,大模型新项目文心一言(ERNIE Bot)将在3月正式亮相,对标微软投资的ChatGPT。 文心一言产品未出,百…

[python入门㊿] - python如何打断点

目录 ❤ 什么是bug(缺陷) ❤ python代码的调试方式 ❤ 使用 pdb 进行调试 测试代码示例 利用 pdb 调试 退出 debug debug 过程中打印变量 停止 debug 继续执行程序 debug 过程中显示代码 使用函数的例子 对函数进行 debug 在调试的时候动态改变值 ❤ 使用 PyC…

el-cascader v-model 绑定值改变了,但是界面没变化

查了很多资料,解决办法各异,但以下两个没有用 (1)this.$forceUpdate()强制更新渲染,没用。 (2)使用v-if和this.ifPanel false去控制el-cascader的显示,目的也是重新渲染&#xff…

原生小程序中模板自定义组件事件

封装request.js请求文件目的:优化代码结构以及后期项目版本迭代和维护方便,提升代码的执行速度。假设:在原生page中使用基本写法创建ajax请求//发送请求了wx.request({url:"",method:"",data:"",success(res){//写业务操做…

数据分片(mycat)

1. 数据分片概念: 1.1. 分库分表 什么是分库分表: 将存放在一台数据库服务器中的数据,按照特定方式(指的是程序开发的算法)进行拆分,分散存放到多台数据库服务器中,以达到分散单台服务器负载的…

Vue使用distpicker插件实现省市级下拉框三级联动

前言 这几天做项目,想着用一个全国省市区插件,之前就知道有几种,比如通过JSON文件生成对应的区域下拉框,element-china-are插件,包括distpicker插件 今天主要介绍的是如何使用distpicker插件实现省市级三联跳动 官网…

2023年100道最新Android面试题,常见面试题及答案汇总

除了需要掌握牢固的专业技术之外,还需要刷更多的面试去在众多的面试者中杀出重围。小编特意整理了100道Android面试题,送给大家,希望大家都能顺利通过面试,拿下高薪。赶紧拿去吧~~文末有答案Q1.组件化和arouter原理Q2.自定义view&…

钣金行业mes解决方案,缩短产品在制周期

钣金加工行业具有多品种、小批量离散制造行业的典型特点。一些常见的下料车间、备料车间、冲压车间、冲剪生产线等。一般来说,核心业务是钣金加工的生产单位。 一般来说,与大规模生产相比,这种生产方式效率低、成本高,自动化难度…

ur3+robotiq ft sensor+robotiq 2f 140配置gazebo仿真环境

ur3robotiq ft sensorrobotiq 2f 140配置gazebo仿真环境 搭建环境: ubuntu: 20.04 ros: Nonetic sensor: robotiq_ft300 gripper: robotiq_2f_140_gripper UR: UR3 通过上一篇博客配置好ur3、力传感器和robotiq夹爪的rviz仿真环境后,现在来配置一下对…

【读书笔记】《深入浅出数据分析》第一章 分解数据

阅读第一章后,觉得本章重点不是在“分解数据”上,而是在对分析流程,分析步骤的引导。 1,确定问题 当业务方或者leader给你提诉求时,往往都是会比较模糊,他们会简单的说下诉求,然后给你一些数据…

Spark介绍

1、Spark是什么?类似与Hadoop的MapReduce的计算框架,基于map和reduce实现分布式计算,对比MapReduce可有效减少落盘次数,增加效率.任务之间通信交互不需要落盘,仅在shuffle时需要重新将数据排序分区落盘.Spark的缓存功能更加高效,特别是在SparkSQL中,一般是以列式存…

学习.NET MAUI Blazor(六)、基于OpenAI接口的伪ChatGPT

ChatGPT不用介绍了。自从1月份开始到现在,火的不得了。网络上也充斥着各种教程,甚至还有号称是ChatGPT国内版的。那么ChatGPT到底有么有开放的API接口,那些打着ChatGPT的应用到底是如何实现的呢? 其实,国内环境虽然无法…

day49【代码随想录】动态规划之最长公共子序列、不相交的线、最大子序和、判断子序列

文章目录前言一、最长公共子序列(力扣1143)二、不相交的线(力扣1035)三、最大子序和(力扣53)四、判断子序列(力扣392)前言 1、最长公共子序列 2、不相交的线 3、最大子序和 4、判断…

C++012-C++一维数组

文章目录C012-C一维数组一维数组目标一维数组定义一维数组初始化一维数组输入输出题目描述 车厢货物**需要查看指定车厢的货物****倒着输出**题目描述 与指定数字相同的数的个数排序选择排序实现题目描述 成绩排名成绩第一名和最后一名在线练习:总结C012-C一维数组 …

推荐几个好玩的AI工具和办公效率网站!

1.copymonkey.ai CopyMonkey可以帮助用户生成和优化亚马逊商品列表文案,分析竞争对手,帮助推动销售。 2.bertha.ai Bertha可以帮助用户使用AI创建高质量的营销文案和图片,同时,帮助客户优化业务流程,提高效率。 3.cr…

UDP端口转发

sokit是一个开源项目,是一个TCP / UDP 测试工具,用来接收,发送,转发TCP或UDP数据包。 项目地址: http://code.google.com/p/sokit/、https://github.com/sinpolib/sokit。 中文版下载地址:https://download.csdn.net/download/android_cai_niao/874728…