PyTorch 中cumprod函数计算张量沿指定维度的累积乘积详解和代码示例

news2025/6/8 12:18:35

torch.cumprod 是 PyTorch 中用于 计算张量沿指定维度的累积乘积(cumulative product) 的函数。


1、函数原型

torch.cumprod(input, dim, *, dtype=None, out=None) → Tensor

参数说明:

参数说明
input输入张量
dim累积乘积的维度
dtype可选:指定输出类型(默认与输入类型相同)
out可选:输出张量(用于 inplace)

2、功能说明

对于指定维度 dim,返回一个张量,其中每个元素是该位置及之前所有元素的乘积。


3、示例代码

示例 1:一维张量

import torch

x = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
y = torch.cumprod(x, dim=0)
print("输入:", x)
print("累积乘积:", y)

输出:

输入: tensor([1., 2., 3., 4.])
累积乘积: tensor([ 1.,  2.,  6., 24.])

示例 2:二维张量,沿 dim=0(列)

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]], dtype=torch.float32)

y = torch.cumprod(x, dim=0)
print(y)

输出:

tensor([[  1.,   2.,   3.],
        [  4.,  10.,  18.],
        [ 28.,  80., 162.]])

计算过程解释(逐列):

  • 第 1 列: [1, 4, 7][1, 1×4=4, 4×7=28]
  • 第 2 列: [2, 5, 8][2, 2×5=10, 10×8=80]
  • 第 3 列: [3, 6, 9][3, 3×6=18, 18×9=162]

示例 3:使用 dtype 强制类型

x = torch.tensor([1, 2, 3], dtype=torch.int32)
y = torch.cumprod(x, dim=0, dtype=torch.float32)
print(y)

输出:

tensor([1., 2., 6.])

4、综合应用示例

下面是一个完整的示例,展示了 torch.cumprod 在神经网络训练中如何用于 前向传播中累积权重乘积的计算。这种用法常见于:

  • 路径权重乘积模型(Path Weight Product Models)
  • 自定义神经网络结构中累积乘积(如神经ODE、概率模型)

4.1、示例背景

假设我们有一个网络结构:每一层只有一个权重因子,我们要计算所有权重乘积作为 forward 输出的一部分。


4.2、示例代码:累积权重乘积的自定义网络

import torch
import torch.nn as nn

class CumprodNet(nn.Module):
    def __init__(self, num_layers):
        super(CumprodNet, self).__init__()
        # 每层一个标量权重参数,初始化为 0.9 左右
        self.weights = nn.Parameter(torch.rand(num_layers) * 0.2 + 0.9)

    def forward(self, x):
        # 假设 x 是输入标量或批量张量
        # 计算权重的累积乘积
        path_weights = torch.cumprod(self.weights, dim=0)
        
        # 将每层的路径加权输出加总
        outputs = torch.stack([x * pw for pw in path_weights], dim=0)
        return outputs.sum(dim=0), path_weights  # 返回结果和路径乘积向量

# 初始化模型
model = CumprodNet(num_layers=4)

# 输入张量(可批量)
x = torch.tensor([1.0], requires_grad=True)

# 前向传播
output, path_weights = model(x)

# 打印结果
print("权重参数:", model.weights.data)
print("累积乘积:", path_weights)
print("最终输出:", output)

# 反向传播
output.backward()
print("输入梯度:", x.grad)

4.3、输出说明(示例)

假设 self.weights = [0.91, 0.95, 1.01, 1.05]

cumprod 将计算:

[0.91,
 0.91 × 0.95 = 0.8645,
 0.8645 × 1.01 = 0.8731,
 0.8731 × 1.05 ≈ 0.9167]

然后每个都乘上输入 x,最后加总作为最终输出。


4.4、应用场景

  1. 路径加权神经网络
  2. 可学习的指数衰减控制
  3. 自定义 RNN、深层残差控制器中的动态路径参数建模
  4. 强化学习中的路径概率分布建模(Policy Gradient)

5、注意事项

  • cumprod 会在指定维度上,按顺序相乘;
  • 输入中如果有 0,后续的所有乘积都会变为 0
  • 常用于概率连乘、对数空间建模前的准备步骤(比如前向链式法则)。

6、与相关函数对比

函数功能
torch.cumsum累加和
torch.cumprod累乘积
torch.prod所有元素乘积(非逐步)
torch.cummax / cummin累积最大/最小值

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2404119.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据通信与计算机网络——数字传输

主要内容 数字到数字转换 线路编码 线路编码方案 块编码 扰动 模拟到数字转换 脉冲码调制(PCM) Delta调制(DM) 传输模式 并行传输 串行传输 一、数字到数字转换 将数字数据转换为数字信号涉及三种技术: 线…

黄柏基因组-小檗碱生物合成的趋同进化-文献精读142

Convergent evolution of berberine biosynthesis 小檗碱生物合成的趋同进化 摘要 小檗碱是一种有效的抗菌和抗糖尿病生物碱,主要从不同植物谱系中提取,特别是从小檗属(毛茛目,早期分支的真双子叶植物)和黄柏属&…

前端杂货铺——TodoList

个人简介 👀个人主页: 前端杂货铺 🙋‍♂️学习方向: 主攻前端方向,正逐渐往全干发展 📃个人状态: 研发工程师,现效力于中国工业软件事业 🚀人生格言: 积跬步…

Spring Boot SSE流式输出+AI消息持久化升级实践:从粗暴到优雅的跃迁

在 AI 应用落地过程中,我们常常需要将用户和 AI 的对话以“完整上下文”的形式持久化到数据库中。但当 AI 回复非常长,甚至接近上万字时,传统的单条消息保存机制就会出问题。 在本篇文章中,我将深入讲解一次实际项目中对 对话持久…

Model Context Protocol (MCP) 是一个前沿框架

微软发布了 Model Context Protocol (MCP) 课程:mcp-for-beginners。 Model Context Protocol (MCP) 是一个前沿框架,涵盖 C#、Java、JavaScript、TypeScript 和 Python 等主流编程语言,规范 AI 模型与客户端应用之间的交互。 MCP 课程结构 …

内容力重塑品牌增长:开源AI大模型驱动下的智能名片与S2B2C商城赋能抖音生态种草范式

摘要:内容力已成为抖音生态中品牌差异化竞争的核心能力,通过有价值、强共鸣的内容实现产品"种草"与转化闭环。本文基于"开源AI大模型AI智能名片S2B2C商城小程序源码"技术架构,提出"技术赋能内容"的新型种草范式…

手机号在网状态查询接口如何用PHP实现调用?

一、什么是手机号在网状态查询接口 通过精准探测手机号的状态,帮助平台减少此类问题的发生,提供更个性化的服务或进行地域性营销 二、应用场景 1. 金融风控 通过运营商在网态查询接口,金融机构可以核验贷款申请人的手机状态,拦…

【Java微服务组件】分布式协调P4-一文打通Redisson:从API实战到分布式锁核心源码剖析

欢迎来到啾啾的博客🐱。 记录学习点滴。分享工作思考和实用技巧,偶尔也分享一些杂谈💬。 有很多很多不足的地方,欢迎评论交流,感谢您的阅读和评论😄。 目录 引言Redisson基本信息Redisson网站 Redisson应用…

一个简单的德劳内三角剖分实现

德劳内(Delaunay)三角剖分是一种经典的将点集进行三角网格化预处理的手段,在NavMesh、随机地牢生成等场景下都有应用。 具体内容百度一大堆,就不介绍了。 比较知名的算法是Bowyer-Watson算法,也就是逐点插入法。 下雨闲…

C#子线程更新主线程UI及委托回调使用示例

1.声明线程方法 2.线程中传入对象 3.声明委托与使用 声明委托对象 委托作为参数传入方法 4.在线程中传入委托 5.调用传入的委托

使用VuePress2.X构建个人知识博客,并且用个人域名部署到GitHub Pages中

使用VuePress2.X构建个人知识博客,并且用个人域名部署到GitHub Pages中 什么是VuePress VuePress 是一个以 Markdown 为中心的静态网站生成器。你可以使用 Markdown 来书写内容(如文档、博客等),然后 VuePress 会帮助你生成一个…

手写Promise.all

前言 之前在看远方os大佬直播的时候看到有让手写的Promise.all的问题,然后心血来潮自己准备手写一个 开始 首先,我们需要明确原本js提供的Promise.all的特性 Promise.all返回的是一个Promise如果传入的数据中有一个reject即整个all返回的就是reject&…

2025年6月|注意力机制|面向精度与推理速度提升的YOLOv8模型结构优化研究:融合ACmix的自研改进方案

版本: 8.3.143(Ultralytics YOLOv8框架) ACmix模块原理 在目标检测任务中,小目标(如裂缝、瑕疵、零件边缘等)由于其尺寸较小、纹理信息稀疏,通常更容易受到图像中复杂背景或噪声的干扰,从而导致漏检或误检…

利用qcustomplot绘制曲线图

本文详细介绍了qcustomplot绘制曲线图的流程,一段代码一段代码运行看效果。通过阅读本文,读者可以了解到每一项怎么用代码进行配置,进而实现自己想要的图表效果。(本文只针对曲线图) 1 最简单的图形(入门&…

【基础算法】枚举(普通枚举、二进制枚举)

文章目录 一、普通枚举1. 铺地毯(1) 解题思路(2) 代码实现 2. 回文日期(1) 解题思路思路一:暴力枚举思路二:枚举年份思路三:枚举月日 (2) 代码实现 3. 扫雷(2) 解题思路(2) 代码实现 二、二进制枚举1. 子集(1) 解题思路(2) 代码实现 2. 费解的…

智能对联网页小程序的仓颉之旅

#传统楹联遇上AI智能体:我的Cangjie Magic开发纪实 引言:一场跨越千年的数字对话 "云对雨,雪对风,晚照对晴空"。昨天晚上星空璀璨,当我用仓颉语言写下第一个智能对联网页小程序的Agent DSL代码时&#xff0…

Python分形几何可视化—— 复数迭代、L系统与生物分形模拟

Python分形几何可视化—— 复数迭代、L系统与生物分形模拟 本节将深入探索分形几何的奇妙世界,实现Mandelbrot集生成器和L系统分形树工具,并通过肺部血管分形案例展示分形在医学领域的应用。我们将使用Python的NumPy进行高效计算,结合Matplo…

【超详细】英伟达Jetson Orin NX-YOLOv8配置与TensorRT测试

文章主要内容如下: 1、基础运行环境配置 2、Torch-GPU安装 3、ultralytics环境配置 4、Onnx及TensorRT导出详解 5、YOLOv8推理耗时分析 基础库版本:jetpack5.1.3, torch-gpu2.1.0, torchvision0.16.0, ultralytics8.3.146 设备的软件开发包基础信息 需…

Go语言学习-->项目中引用第三方库方式

Go语言学习–>项目中引用第三方库方式 1 执行 go mod tidy 分析引入的依赖有没有正常放在go.mod里面 找到依赖的包会自动下载到本地 并添加在go.mod里面 执行结果: 2 执行go get XXXX(库的名字)

每日Prompt:云朵猫

提示词 仰视,城镇的天空,一片形似猫咪的云朵,用黑色的简笔画,勾勒出猫咪的形状,可爱,俏皮,极简