Pytorch中nn.Linear使用方法

news2025/6/28 17:01:42

nn.Linear定义一个神经网络的线性层:

torch.nn.Linear(in_features,             # 输入的神经元个数
                out_features,            # 输出神经元个数
                bias=True                # 是否包含偏置
                )

nn.Linear其实就是对输入x_{n\times i}(n表示样本数量,i表示样本特征数)执行了一个线性变换,即:

Y_{n\times o } = X_{n\times i}W_{i\times o} + b

其中W矩阵是模型要学习的参数,b是1*O的向量偏置(即1行O列),n表示输入向量的个数(也可以理解为行数,比如一次输入100个样本数据,则n=100),i为每个样本的特征数,也可以理解为神经元的个数,O为输出样本的特征数,即输出神经元的个数。

from torch import nn
import torch

model = nn.Linear(3, 1)           # 每个样本输入特征数设置为3,输出特征数设置为1

input = torch.Tensor([2, 4, 6])   # 给一个样本,该样本有3个特征,这3个特征分别是2、4、6
output = model(input)

print("nn.Linear 输出大小:{}".format(output.shape))
print(output)
print("")

print("查看模型参数W和b的值")
# 查看模型参数
for param in model.parameters():
    print(param)

输出:
nn.Linear 输出大小:torch.Size([1])    #输出结果表示只有一个样本输出,且该样本只有一个特征值1
tensor([-0.7842], grad_fn=<AddBackward0>)

查看模型参数W和b的值
Parameter containing:
tensor([[ 0.2353, -0.5686,  0.1759]], requires_grad=True)
Parameter containing:
tensor([-0.0356], requires_grad=True)

可以看到,模型有4个参数,分别为W的三个权重和b的一个偏置。手动计算验证结果:

0.2353*2 + (-0.5686)*4 + 0.1759*6 + (-0.0356) = -0.7839999999999997

假设有5个输入样本A、B、C、D、E(即batch_size为5),每个样本的特征数量为3,定义线性层时,输入特征为3,所以in_feature=3,想让下一层的神经元个数为5,所以out_feature=5,则模型参数为:

model = nn.Linear(in_features=3, out_features=5, bias=True)

此时参数矩阵W大小为3行3列

from torch import nn
import torch

model = nn.Linear(3, 5)           # 每个样本输入特征数设置为3,输出特征数设置为1

input = torch.Tensor([[2, 4, 6],[8,10,12],[14,16,18],[20,22,24],[26,28,30]])   # 给一个样本,该样本有3个特征,这3个特征分别是2、4、6

print(input)

output = model(input)

print("nn.Linear 输出大小:{}".format(output.shape))
print(output)
print("")

print("查看模型参数W和b的值")
# 查看模型参数
for param in model.parameters():
    print(param)

输出:
tensor([[ 2.,  4.,  6.],
        [ 8., 10., 12.],
        [14., 16., 18.],
        [20., 22., 24.],
        [26., 28., 30.]])
nn.Linear 输出大小:torch.Size([5, 5])
tensor([[ -0.9616,  -0.9744,   2.6266,  -0.5605,  -4.2236],
        [ -1.7251,  -4.4417,   5.9969,  -1.3649, -11.0200],
        [ -2.4886,  -7.9090,   9.3673,  -2.1692, -17.8163],
        [ -3.2522, -11.3763,  12.7376,  -2.9736, -24.6127],
        [ -4.0157, -14.8436,  16.1079,  -3.7779, -31.4090]],
       grad_fn=<AddmmBackward>)

查看模型参数W和b的值
Parameter containing:
tensor([[ 0.0714,  0.1456, -0.3443],
        [-0.5098, -0.0893,  0.0211],
        [ 0.3489, -0.2682,  0.4811],
        [ 0.0768, -0.3863,  0.1755],
        [-0.2832, -0.4325, -0.4170]], requires_grad=True)
Parameter containing:
tensor([ 0.3789,  0.2753,  0.1153, -0.2216,  0.5748], requires_grad=True)

第一个样本特征为[2、4、6],输出为[ -0.9616,  -0.9744,   2.6266,  -0.5605,  -4.2236],验证过程如下:

%w是模型参数矩阵
w = [[ 0.0714,  0.1456, -0.3443],
     [-0.5098, -0.0893,  0.0211],
     [ 0.3489, -0.2682,  0.4811],
     [ 0.0768, -0.3863,  0.1755],
     [-0.2832, -0.4325, -0.4170]];
x = [2,4,6];
b = [0.3789,  0.2753,  0.1153, -0.2216,  0.5748];   %偏置向量
x*w'+b

输出:
 -0.9617   -0.9749    2.6269   -0.5602   -4.2236

第2个样本验证:

w = [[ 0.0714,  0.1456, -0.3443],
        [-0.5098, -0.0893,  0.0211],
        [ 0.3489, -0.2682,  0.4811],
        [ 0.0768, -0.3863,  0.1755],
        [-0.2832, -0.4325, -0.4170]];
x = [8,10,12];
b = [0.3789,  0.2753,  0.1153, -0.2216,  0.5748];
x*w'+b

输出:
-1.7255   -4.4429    5.9977   -1.3642  -11.0198

第3、4、5个样本的验证过程类似,从以上验证可以看出,所有样本共享参数矩阵W和偏置b

因为有5个样本,所以相当于依次进行了5次以上操作。

该操作重复了5次,每个样本重复一次:Y_{1\times 5}=X_{1\times 3}W_{3\times 5} + b_{1\times 5}

然后再将5个Y _{1 \times 5}叠加在一起,得到5*5的输出
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1583782.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

国产低代码工具,轻松搞定数据迁移

在日常的业务系统升级或者数据维护过程中&#xff0c;数据迁移是各个企业用户不得不面临的问题&#xff0c;尤其是数据迁移过程中要保障数据完整性、统一性和及时性&#xff0c;同时也需要注意源数据中的数据质量问题&#xff0c;比如缺失、无效、错误等问题&#xff0c;需要在…

Kubernetes中安装部署Nacos集群

目录 1、Nacos安装包的准备 1.1 下载安装包 1.2 解压安装包 1.3 修改配置文件 application.properties 1.4 bin目录下创建 docker-startup.sh 1.5 将nacos-server-1.2.1目录打包成nacos-server-1.2.1.tar.gz 2、 nacos镜像制作 2.1 Dockerfile文件编写 2.2 制作镜像…

单片机入门还能从51开始吗?

选择从51单片机开始入门还是直接学习基于ARM核或RISC核的单片机&#xff0c;取决于学习目标、项目需求以及个人兴趣。每种单片机都有其特定的优势和应用场景&#xff0c;了解它们的特点可以帮助你做出更合适的选择。 首先&#xff0c;我们说一下51单片机的优势&#xff1a; 成熟…

外包干了17天,技术倒退明显

先说情况&#xff0c;大专毕业&#xff0c;18年通过校招进入湖南某软件公司&#xff0c;干了接近6年的功能测试&#xff0c;今年年初&#xff0c;感觉自己不能够在这样下去了&#xff0c;长时间呆在一个舒适的环境会让一个人堕落&#xff01; 而我已经在一个企业干了四年的功能…

【石上星光】context,go的上下文存储并发控制之道

目录 1 引言2 What&#xff1f;3 How&#xff1f; 3.1 用法一、上下文数据存储3.2 用法二、并发控制 3.2.1 场景1 主动取消3.2.2 场景2 超时取消 3.3 用法三、创建一个空Context&#xff08;emptyCtx&#xff09; 4 Why&#xff1f; 4.1 go中的上下文思想 4.1.1 上下文是什么…

技术小课堂:100%CC防护是怎么实现的?

大家好&#xff0c;今天我们深入探讨的是如何有效地实现CC攻击的100%防护&#xff0c;以及传统防护手段存在的局限性和我们的定制化解决方案的优势。 传统的CC防护措施通常依赖于全局性的访问频率控制或在防火墙级别设置固定的访问次数限制。这种方式看似简单直接&#xff0c;…

安全大脑与盲人摸象

21世纪是数字科技和数字经济爆发的时代&#xff0c;互联网正从网状结构向类脑模型进行进化&#xff0c;出现了结构和覆盖范围庞大&#xff0c;能够适应不同技术环境、经济场景&#xff0c;跨地域、跨行业的类脑复杂巨型系统。如腾讯、Facebook等社交网络具备的神经网络特征&…

[方案实操|数据技术]数据要素十大创新模式(1):基于区块链的多模态数据交易服务平台

“ 区块链以其公开共享、去中心化、不可篡改、可追溯和不可抵赖等优势&#xff0c;吸引了包括金融业、医疗业和政府部门等众多利益相关方的极大兴趣&#xff0c;被认为是解决数据安全交换问题的合适方案。” 武汉东湖大数据科技股份有限公司凭借基于区块链的多模态数据交易服务…

交换机的基本原理与配置_实验案例一:交换机的初始配置

1、实验环境 实验用具包括一台Cisco交换机&#xff0c;一台PC&#xff0c;一根Console 线缆。 2、需求描述 如图5.17所示&#xff0c;实验案例一的配置需求如下。 通过PC连接并配置一台Cisco交换机。在交换机的各个配置模式之间切换。将交换机主机的名称改为BDON 3、推荐步…

OpenHarmony应用编译 - 如何在源码中编译复杂应用(4.0-Release)

文档环境 开发环境&#xff1a;Windows 11 编译环境&#xff1a;Ubuntu 22.04 开发板型号&#xff1a;DAYU 200&#xff08;RK3568&#xff09; 系统版本&#xff1a;OpenHarmony-4.0-Release 功能简介 在 OpenHarmony 系统中预安装应用的 hap 包会随系统编译打包到镜像中&a…

C语言—每日选择题—Day68

第一题 1、运行以下C语言代码&#xff0c;输出的结果是&#xff08;&#xff09; #include <stdio.h> int main() {char *str[3] {"stra", "strb", "strc"};char *p str[0];int i 0;while(i < 3){printf("%s ",p);i;} retur…

Path Aggregation Network for Instance Segmentation

PANet 摘要1. 引言2.相关工作3.框架 PANet 最初是为 proposal-based 实例分割框架提出来的&#xff0c;mask 是实例的掩码&#xff0c;覆盖了物体包含的所有像素&#xff0c;proposal 在目标检测领域是可能存在目标的区域。在实例分割中&#xff0c;首先利用RPN(Region Proposa…

【并发】第四篇 原子操作系列-AtomicInteger原子操作类详解

导航 一. 简介二. 源码分析三. 原子操作原理三. 实际用途1. 标志位2. 唯一标识生成器3. 计数器一. 简介 AtomicInteger是Java中提供的一种线程安全的原子操作类,用来实现对整数类型的原子操作。它可以在多线程环境下保证对整数的原子性操作,而不需要使用synchronized关键字或…

小样本计数网络FamNet(Learning To Count Everything)

小样本计数网络FamNet(Learning To Count Everything) 大多数计数方法都仅仅针对一类特定的物体&#xff0c;如人群计数、汽车计数、动物计数等。一些方法可以进行多类物体的计数&#xff0c;但是training set中的类别和test set中的类别必须是相同的。 为了增加计数方法的可拓…

CloudCompare——win11配置CloudComPy

CloudComPy配置 1 基本环境介绍2 安装Anaconda2.1 下载anaconda2.2 安装anaconda2.3 配置镜像源2.4 更改虚拟环境的默认创建位置2.5 其他问题2.5.1 激活自己创建的环境提示&#xff1a;系统找不到指定的路径2.5.2 InvalidVersionSpecError: Invalid version spec: 2.72.5.3 卸载…

Hibernate框架的搭建

Hibernate框架的搭建 分层体系结构与持久化 三层体系结构 分层体系结构 指的是将系统的组件分隔到不同的层中&#xff0c;每一层中的组件应保持内聚性&#xff1b; 每一层都应与它下面的各层保持松散耦合。 层与层之间存在自上而下的依赖关系&#xff0c;即上层组件会访问下…

【一学就会】(一)C++编译工具链——基于VSCode的CMake、make与g++简单理解与应用示例

目录 一、CMake、make与g 1、名词辨析 2、孰优孰劣 二、应用示例 1、工具类安装与配置 1&#xff09;VSCode安装与配置 2&#xff09;CMake下载与安装 3&#xff09;MinGW-W64下载与安装 A、科学上网法 B、无需科学上网法 4&#xff09;VSCode推荐插件 A、c/c编译环…

nandgame中的Tokenize(标记化)

题目说明&#xff1a; "Tokenize" "标记化"标记器预先配置为识别数字和符号 。请配置标记器以额外识别符号减号 - 和括号 ( 和 )。您可以编辑源代码区域中的代码以测试它的标记化。level help 我们将构建一种高级编程语言。 高级语言具有更加人性化和灵…

K8s-Ingress Nginx-Day 08

1. 什么是Ingress 官方文档&#xff1a;https://kubernetes.io/zh-cn/docs/concepts/services-networking/ingress/#what-is-ingress Ingress 是 kubernetes API 中的标准资源类型之一&#xff0c;主要是k8s官方在维护。 2. Ingress的作用 Ingress 提供从集群外部到集群内服务…

NAT转换是怎么工作的?

前言 对象: 服务器S&#xff0c;NAT设备&#xff0c;用户设备C1&#xff0c;用户设备C2 用户C1向服务器S发起一个HTTP请求&#xff0c;经过NAT转化&#xff0c;服务器收到并作出响应&#xff0c;用户C1收到响应。 问题来了&#xff0c;NAT是怎么知道这个响应是给用户C1而不是…