pytorch小记(二十):深入解析 PyTorch 的 `torch.randn_like`:原理、参数与实战示例

news2025/5/21 4:51:15

pytorch小记(二十):深入解析 PyTorch 的 `torch.randn_like`:原理、参数与实战示例

    • 一、函数签名与参数详解
    • 二、`torch.randn_like` vs `torch.randn`
    • 三、基础示例
    • 四、进阶用法与参数覆盖
      • 4.1 覆盖数据类型(dtype)
      • 4.2 覆盖设备(device)
      • 4.3 开启梯度追踪(requires\_grad)
      • 4.4 覆盖内存格式(memory\_format)
    • 五、典型应用场景
      • 1. 给模型参数添加噪声
      • 2. 数据增强:图像高斯噪声
      • 3. 扩散模型(DDPM)中的噪声采样
    • 六、多种等价写法
    • 七、小结


在深度学习模型中,我们经常需要在已有张量的基础上生成与之「同形状」「同设备」「同或不同数据类型」的随机噪声,用于参数扰动、数据增强、扩散模型等场景。PyTorch 为我们提供了一个高效便捷的工具——torch.randn_like,它能一步完成上述需求。本文将从函数定义、参数详解、典型应用场景,到进阶用法,全面剖析 torch.randn_like,并通过丰富示例帮助你快速上手。


一、函数签名与参数详解

torch.randn_like(
    input: Tensor,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = False,
    memory_format: Optional[torch.memory_format] = None
) → Tensor
  • input(必选)
    源张量,randn_like 会读取它的 .shape.dtype.device.layout、以及 memory_format(如果未显式指定覆盖项)。

  • dtype(可选)
    生成张量的数据类型,如 torch.float32torch.int64 等。若不指定,则继承 input.dtype

  • device(可选)
    指定在 CPU 还是 GPU 上创建新张量,如 "cpu""cuda:0"。若不指定,则继承 input.device

  • requires_grad(可选)
    是否对新张量开启梯度追踪,默认为 False

  • 其他

    • layout:张量内存布局,一般使用默认;
    • memory_format:指定内存格式,如 torch.contiguous_format

二、torch.randn_like vs torch.randn

方法参数优点
torch.randn(size)必须手动传入 size、可选传入 dtypedevice简单直观,适合只关心形状的场景
torch.randn_like(input)自动继承 input.shapedtypedevicelayout 等属性减少样板代码,保证输出张量与输入环境一致

三、基础示例

import torch

# 1. 构造一个形状为 (2, 3) 的零张量
x = torch.zeros(2, 3)
print("x:", x.shape, x.dtype, x.device)
# x: torch.Size([2, 3]) torch.float32 cpu

# 2. 生成与 x 同形状同属性的标准正态随机张量
noise = torch.randn_like(x)
print("noise:", noise)
# 示例输出:
# tensor([[-0.1245,  0.5487, -0.3221],
#         [ 0.8723, -1.0054,  0.0392]])
  • 新张量 noisex 的形状、数据类型、设备保持一致。

四、进阶用法与参数覆盖

4.1 覆盖数据类型(dtype)

# 强制生成 float64 类型
noise_fp64 = torch.randn_like(x, dtype=torch.float64)
print(noise_fp64.dtype)  # torch.float64

4.2 覆盖设备(device)

if torch.cuda.is_available():
    noise_gpu = torch.randn_like(x, device=torch.device('cuda:0'))
    print(noise_gpu.device)  # cuda:0

4.3 开启梯度追踪(requires_grad)

noise_grad = torch.randn_like(x, requires_grad=True)
print(noise_grad.requires_grad)  # True

4.4 覆盖内存格式(memory_format)

noise_contig = torch.randn_like(x, memory_format=torch.contiguous_format)
# 通常无需显式指定,除非对内存布局有特殊需求

五、典型应用场景

1. 给模型参数添加噪声

在对抗训练、参数平滑或元学习中,需要对权重做微小扰动:

import torch.nn as nn

class NoisyLinear(nn.Linear):
    def forward(self, input):
        # 为权重张量添加微小高斯噪声
        weight_noise = torch.randn_like(self.weight) * 0.01
        return nn.functional.linear(input, self.weight + weight_noise, self.bias)

layer = NoisyLinear(128, 64)
x = torch.randn(32, 128)
out = layer(x)  # 前向过程中,自动生成同形状噪声

2. 数据增强:图像高斯噪声

对图像 Batch 注入随机噪声,提升模型鲁棒性:

# 假设 images 形状为 [B, C, H, W]
images = torch.randn(16, 3, 224, 224)  # 示例输入
noise_std = 0.1

noisy_images = images + torch.randn_like(images) * noise_std
# 这样可以保证噪声形状 / dtype / device 与 images 完全一致

3. 扩散模型(DDPM)中的噪声采样

在扩散模型中,需要不断向数据添加标准正态噪声,且噪声张量形状与数据完全对齐:

def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)
    # 根据时间步 t 计算噪声比例等后续操作...
    return x_start * alpha_t[t] + noise * beta_t[t]

六、多种等价写法

  • tensor.long()tensor.to(torch.int64)
  • tensor.type(torch.float32) 等方法,均可对已有张量做类型转换,与 randn_like 结合时常用于进一步处理。

七、小结

  1. 功能torch.randn_like 快速生成与指定张量同形状、同设备的标准正态分布随机张量。
  2. 参数覆盖:可选 dtypedevicerequires_gradmemory_format 等,灵活适配各种需求。
  3. 典型场景:参数扰动、数据增强、扩散模型、随机索引等。
  4. 最佳实践:在不关心形状等属性细节时,用 randn_like 省去 boilerplate;在需要覆盖属性时,通过关键字参数一次性完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2380448.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

8-游戏详情制作(Navigation组件)

1.1 需求 使用Navigation实现游戏主详情视图,从瀑布流容器中的游戏项(游戏中心首页-游戏瀑布流列表)点击游戏后进入游戏详情页,从游戏详情页可以返回游戏列表主页。 1.2 界面原型 从瀑布流组件进入: 游戏详情&#…

Unity引擎源码-物理系统详解-其二

继续我们关于Unity的物理系统的源码阅读,不过这一次我们的目标是PhysX引擎——这个Unity写了一堆脚本来调用API的实际用C写成的底层物理引擎。 Github的地址如下:NVIDIA-Omniverse/PhysX: NVIDIA PhysX SDK (github.com) 下载后发现由三个文件组成&…

1.3.3 数据共享、汇聚和使用中的安全目标

探索数据共享、汇聚与使用中的安全目标 在当今数字化时代,数据的价值愈发凸显,数据共享、汇聚与使用成为了推动业务发展、促进创新的重要环节。然而,在这一过程中,数据安全至关重要,我们需要明确并保障保密性、完整性…

【Docker】Docker安装Redis

目录 1.下载镜像 1.1查看下载的镜像 2.创建挂载目录 3.创建容器并启动 4.测试连接 1.下载镜像 根据指令下载镜像文件 docker pull redis#上面指令是下载最新,如需下载指定版本可带版本号 docker pull redis:xxx 响应内容: 1.1查看下载的镜像 下载完…

Oc语言学习 —— Foundation框架总结

1、NSString类 我们对一个NSString对象赋值的方法是直接将字符串常量赋给对象,例如:NSString *str "hello"; 因为我们的NSString是不可变的,所以我们只能通过一些方法来在我们原来的字符串后面追加或初始化我们的字符串来间接修改…

LWIP的Socket接口

Socket接口简介 类似于文件操作的一种网络连接接口,通常将其称之为“套接字”。lwIP的Socket接口兼容BSD Socket接口,但只实现完整Socket的部分功能 netconn是对RAW的封装 Socket是对netconn的封装 SOCKET结构体 struct sockaddr { u8_t sa_len; /* 长…

Better Faster Large Language Models via Multi-token Prediction 原理

目录 模型结构: Memory-efficient implementation: 实验: 1. 在大规模模型上效果显著: 2. 在不同类型任务上的效果: 为什么MLP对效果有提升的几点猜测: 1. 并非所有token对生成质量的影响相同 2. 关…

Spring的Validation,这是一套基于注解的权限校验框架

为了保证数据的正确性、完整性,作为一名后端开发工程师,不能仅仅依靠前端来校验数据,还需要对接口请求的参数进行后端的校验。 controller 全局异常处理器 在项目中添加一个全局异常处理器,处理校验异常 RestControllerAdvice p…

MySQL - 如何突破单库性能瓶颈

数据库服务器硬件优化 我们来看看对数据库所在的服务器是如何进行优化的,服务器是数据库的宿主,其性能直接影响了数据库的性能,所以服务器的优化也是数据库优化的第一步。 数据库服务器通常是从 CPU、内存、磁盘三个角度进行硬件优化的&…

apisix透传客户端真实IP(real-ip插件)

文章目录 apisix透传客户端真实IP需求和背景apisix real-ip插件为什么需要 trusted_addresses?安全架构的最佳实践 示例场景apisix界面配置 apisix透传客户端真实IP 需求和背景 当 APISIX 前端有其他反向代理(如 Nginx、HAProxy、云厂商的 LB&#xff…

Oracle 数据库的默认隔离级别

Oracle 数据库的默认隔离级别 默认隔离级别:READ COMMITTED Oracle 默认使用 读已提交(READ COMMITTED) 隔离级别,这是大多数OLTP(在线事务处理)系统的标准选择。 官方文档 https://docs.oracle.com/en/database/oracle/oracle-database/19/cncpt/da…

代码随想录算法训练营第六十四天| 图论9—卡码网47. 参加科学大会,94. 城市间货物运输 I

每日被新算法方式轰炸的一天,今天是dijkstra(堆优化版)以及Bellman_ford ,尝试理解中,属于是只能照着代码大概说一下在干嘛。 47. 参加科学大会 https://kamacoder.com/problempage.php?pid1047 dijkstra&#xff08…

开启健康生活的多元养生之道

健康养生是一门值得终身学习的学问,在追求健康的道路上,除了常见方法,还有许多容易被忽视却同样重要的角度。掌握这些多元养生之道,能让我们的生活更健康、更有品质。​ 室内环境的健康不容忽视。定期清洁空调滤网,避…

【Vite】前端开发服务器的配置

定义一些开发服务器的行为和代理规则 服务器的基本配置 server: {host: true, // 监听所有网络地址port: 8081, // 使用8081端口open: true, // 启动时自动打开浏览器cors: true // 启用CORS跨域支持 } 代理配置 proxy: {/api: {target: https://…

Spring Security与Spring Boot集成原理

Spring Security依赖的是过滤器机制,首先是web容器例如tomcat作为独立的产品,本身有自己的一套过滤器机制用来处理请求,那么如何将tomcat接收到的请求转入到Spring Security的处理逻辑呢?spring充分采用了tomcat的拓展机制提供了t…

VScode各文件转化为PDF的方法

文章目录 代码.py文件.ipynb文本和代码夹杂的文件方法 1:使用 VS Code 插件(推荐)步骤 1:安装必要插件步骤 2:安装 `nbconvert`步骤 3:间接导出(HTML → PDF)本文遇见了系列错误:解决方案:问题原因步骤 1:降级 Jinja2 至兼容版本步骤 2:确保 nbconvert 版本兼容替代…

Vue3学习(组合式API——Watch侦听器、watchEffect()详解)

目录 一、Watch侦听器。 (1)侦听单个数据。 (2)侦听多个数据。(数组写法?!) (3)immediate参数。(立即执行回调) (3)deep参数。(深层监…

【node.js】安装与配置

个人主页:Guiat 归属专栏:node.js 文章目录 1. Node.js简介1.1 Node.js的特点1.2 Node.js架构 2. Node.js安装2.1 下载和安装方法2.1.1 Windows安装2.1.2 macOS安装2.1.3 Linux安装 2.2 使用NVM安装和管理Node.js版本2.2.1 安装NVM2.2.2 使用NVM管理Node…

《AI大模型应知应会100篇》第62篇:TypeChat——类型安全的大模型编程框架

第62篇:TypeChat——类型安全的大模型编程框架 摘要 在构建 AI 应用时,一个常见的痛点是大语言模型(LLM)输出的不确定性与格式不一致问题。开发者往往需要手动解析、校验和处理模型返回的内容,这不仅增加了开发成本&a…

EdgeShard:通过协作边缘计算实现高效的 LLM 推理

(2024-05-23) EdgeShard: Efficient LLM Inference via Collaborative Edge Computing (EdgeShard:通过协作边缘计算实现高效的 LLM 推理) 作者: Mingjin Zhang; Jiannong Cao; Xiaoming Shen; Zeyang Cui;期刊: (发表日期: 2024-05-23)期刊分区:本地链接: Zhang 等 - 2024 …