神经网络代码入门解析

news2025/5/13 2:39:18

神经网络代码入门解析

import torch
import matplotlib.pyplot as plt

import random


def create_data(w, b, data_num):  # 数据生成
    x = torch.normal(0, 1, (data_num, len(w)))
    y = torch.matmul(x, w) + b  # 矩阵相乘再加b

    noise = torch.normal(0, 0.01, y.shape)  # 为y添加噪声
    y += noise

    return x, y

num = 500

true_w = torch.tensor([8.1, 2, 2, 4])
true_b = 1.1


X, Y = create_data(true_w, true_b, num)

# plt.scatter(X[:, 3], Y, 1)  # 画散点图 对X取全部的行的第三列,标签Y,点大小
# plt.show()

def data_provider(data, label, batchsize):  # 每次取batchsize个数据
    length = len(label)
    indices = list(range(length))
    # 这里需要把数据打乱
    random.shuffle(indices)

    for each in range(0, length, batchsize):
        get_indices = indices[each: each+batchsize]
        get_data = data[get_indices]
        get_label = label[get_indices]

        yield get_data, get_label  # 有存档点的return

batchsize = 16
# for batch_x, batch_y in data_provider(X, Y, batchsize):
#     print(batch_x, batch_y)
#     break

# 定义模型
def fun(x, w, b):
    pred_y = torch.matmul(x, w) + b
    return pred_y

# 定义loss
def maeLoss(pre_y, y):
    return torch.sum(abs(pre_y-y))/len(y)

# sgd(梯度下降)
def sgd(paras, lr):
    with torch.no_grad():  # 这部分代码不计算梯度
        for para in paras:
            para -= para.grad * lr  # 不能写成 para = para - paras.grad * lr !!!! 这句相当于要创建一个新的para,会导致报错
            para.grad.zero_()  # 将使用过的梯度归零


lr = 0.01
w_0 = torch.normal(0, 0.01, true_w.shape, requires_grad=True)
b_0 = torch.tensor(0.01, requires_grad=True)
print(w_0, b_0)

epochs = 50
for epoch in range(epochs):
    data_loss = 0
    for batch_x, batch_y in data_provider(X, Y, batchsize):
        pred_y = fun(batch_x, w_0, b_0)
        loss = maeLoss(pred_y, batch_y)
        loss.backward()
        sgd([w_0, b_0], lr)
        data_loss += loss

    print("epoch %03d: loss: %.6f" % (epoch, data_loss))

print("真实函数值:", true_w, true_b)
print("训练得到的函数值:", w_0, b_0)


idx = 0
plt.plot(X[:, idx].detach().numpy(), X[:, idx].detach().numpy()*w_0[idx].detach().numpy()+b_0.detach().numpy())
plt.scatter(X[:, idx].detach().numpy(), Y, 1)
plt.show()

逐步分析代码

1.数据生成

image-20250301120222530

首先设计一个函数create_data,提供我们所需要的数据集的x与y

def create_data(w, b, data_num):  # 数据生成
    x = torch.normal(0, 1, (data_num, len(w)))  # 生成特征数据,形状为 (data_num, len(w))
    y = torch.matmul(x, w) + b  # 计算目标值 y = x * w + b

    noise = torch.normal(0, 0.01, y.shape)  # 生成噪声,形状与 y 相同
    y += noise  # 为 y 添加噪声,模拟真实数据中的随机误差

    return x, y
  • torch.normal() 生成一个张量

    • torch.normal(0, 1, (data_num, len(w))):生成一个形状为 (data_num, len(w)) 的张量,其中的元素是从均值为 0、标准差为 1 的正态分布中随机采样的。
  • torch.matmul() 让矩阵相乘

    matmul: matrix multiply

  • 再使用torch.normal()生成一个张量,添加到y上,相当于为y添加了随机的噪声

    噪声的引入是为了模拟真实数据中的随机误差,使生成的数据更接近现实场景。

2.设计一个数据加载器

def data_provider(data, label, batchsize):  # 每次取 batchsize 个数据
    length = len(label)
    indices = list(range(length))
    random.shuffle(indices)  # 打乱数据顺序,避免模型学习到顺序特征

    for each in range(0, length, batchsize):
        get_indices = indices[each: each+batchsize]  # 获取当前批次的索引
        get_data = data[get_indices]  # 获取当前批次的数据
        get_label = label[get_indices]  # 获取当前批次的标签

        yield get_data, get_label  # 返回当前批次的数据和标签

data_provider可以分批提供数据,并通过yield来返回已实现记忆功能

首先把list y顺序打乱,这样就相当于从生成的训练集y中随机读取,若不打乱数据,可能造成训练结果的不理想

打乱数据可以避免模型在训练过程中学习到数据的顺序特征,从而提高模型的泛化能力。

之后分段遍历打乱的y,返回对应的局部的数据集来给神经网络进行训练

3.定义模型函数

image-20250301122853184

def fun(x, w, b):
    pred_y = torch.matmul(x, w) + b  # 计算预测值 y = x * w + b
    return pred_y

fun(x, w, b) 是一个线性模型,形式为 y = x * w + b,其中 x 是输入特征,w 是权重,b 是偏置。

4.定义Loss函数

image-20250301122958888

def maeLoss(pre_y, y):
    return torch.sum(abs(pre_y - y)) / len(y)  # 计算平均绝对误差 (MAE)
  • maeLoss 是平均绝对误差(Mean Absolute Error, MAE),它计算预测值 pre_y 和真实值 y 之间的绝对误差的平均值。
  • 公式为:MAE = (1/n) * Σ|pre_y - y|,其中 n 是样本数量。

5.梯度下降sgd函数

# sgd(梯度下降)
def sgd(paras, lr):
    with torch.no_grad():  # 这部分代码不计算梯度
        for para in paras:
            para -= para.grad * lr  # 不能写成 para = para - paras.grad * lr !!!! 这句相当于要创建一个新的para,会导致报错
            para.grad.zero_()  # 将使用过的梯度归零

这里需要使用torch.no_grad()来避免重复计算梯度

image-20250301123531781

在前向过程中已经累计过一次梯度了,如果在梯度下降过程中又累计了梯度,那么就会造成不必要的麻烦

PyTorch 会累积梯度,如果不手动清零,梯度会不断累积,导致参数更新错误。

para -= para.grad * lr就是将参数w修正的过程(w=w-(dy^/dw)*learningRate)

torch.no_grad() 是一个上下文管理器,用于禁用梯度计算。在参数更新时,禁用梯度计算可以避免不必要的计算和内存占用。

5.开始训练

epochs = 50
for epoch in range(epochs):
    data_loss = 0
    num_batches = len(Y) // batchsize  # 计算批次数量

    for batch_x, batch_y in data_provider(X, Y, batchsize):
        pred_y = fun(batch_x, w_0, b_0)  # 前向传播
        loss = maeLoss(pred_y, batch_y)  # 计算损失
        loss.backward()  # 反向传播
        sgd([w_0, b_0], lr)  # 更新参数
        data_loss += loss.item()  # 累积损失

    print("epoch %03d: loss: %.6f" % (epoch, data_loss / num_batches))  # 打印平均损失

先定义一个训练轮次epochs=50,表示训练50轮

在每轮训练中将loss记录下来,以此评价训练的效果

首先用data_provider来获取数据集中随机的一部分

接着传入相应数据给模型函数,通过前向传播获得预测y值pred_y

调用Loss计算函数,获取这次的loss,再通过反向传播loss.backward()计算梯度

loss.backward() 是反向传播的核心步骤,用于计算损失函数对模型参数的梯度。

再通过梯度下降sgd([w_0, b_0], lr)来更新模型的参数

最终将这组数据的loss累加到这轮数据的loss中

6.结果绘制

idx = 0
plt.plot(X[:, idx].detach().numpy(), X[:, idx].detach().numpy() * w_0[idx].detach().numpy() + b_0.detach().numpy())  # 绘制预测直线
plt.scatter(X[:, idx].detach().numpy(), Y, 1)  # 绘制真实数据点
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2308553.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

TCP/IP 5层协议簇:网络层(IP数据包的格式、路由器原理)

目录 1. TCP/IP 5层协议簇 2. IP 三层包头协议 3. 路由器原理 4. 交换机和路由的对比 1. TCP/IP 5层协议簇 如下: 2. IP 三层包头协议 数据包如下:IP包头不是固定的,每一个数字是一个bit 其中数据部分是上层的内容,IP包头最…

echarts柱状图不是完全铺满容器,左右两边有空白

目录 处理前:echarts柱状图不是完全铺满容器,左右两边有空白处理前:通过调整 grid 组件配置处理后效果修改代码:1. 调整 grid 组件配置原理解决办法 2. 处理 xAxis 的 boundaryGap 属性原理解决办法 3. 调整 barMaxWidth 和 barMi…

ArcGIS Pro技巧实战:高效矢量化天地图地表覆盖图

在地理信息系统(GIS)领域,地表覆盖图的矢量化是一项至关重要的任务。天地图作为中国国家级的地理信息服务平台,提供了丰富且详尽的地表覆盖数据。然而,这些数据通常以栅格格式存在,不利于进行空间分析和数据…

西门子S7-1200比较指令

西门子S7-1200 PLC比较指令学习笔记 一、比较指令的作用 核心功能:用于比较两个数值的大小或相等性,结果为布尔值(True/False)。典型应用: 触发条件控制(如温度超过阈值启动报警)数据筛选&…

【AD】3-6 层次原理图

自上而下 1.放置-页面符号,并设置属性 2.放置-端口 可通过如下设置将自动生成关掉 3.放置-添加图纸入口,并创建图纸 自下而上 1.子图的原理图页设计 设计资原理图,复制网络标签,智能粘贴未PORT 2.新建主图原理图 创建框…

精品整理-2025 DeepSeek核心技术解析与实践资料合集(24份)

2025 DeepSeek核心技术解析与实践资料合集,共24份。 2025 DeepSeek 火爆背后的核心技术:知识蒸馏技术.pdf 2025 DeepSeek-R1详细解读:DeepSeek-R1-Zero和DeepSeek-R1分析.pdf 2025 DeepSeek-V3三个关键模块详细解读:MLAMoEMTP.pd…

【三维分割】LangSplat: 3D Language Gaussian Splatting(CVPR 2024 highlight)

论文:https://arxiv.org/pdf/2312.16084 代码:https://github.com/minghanqin/LangSplat 文章目录 一、3D language field二、回顾 Language Fields的挑战三、使用SAM学习层次结构语义四、Language Fields 的 3DGS五、开放词汇查询(Open-voca…

【HarmonyOS Next】鸿蒙应用折叠屏设备适配方案

【HarmonyOS Next】鸿蒙应用折叠屏设备适配方案 一、前言 目前应用上架华为AGC平台,都会被要求适配折叠屏设备。目前华为系列的折叠屏手机,有华为 Mate系列(左右折叠,华为 Mate XT三折叠),华为Pocket 系列…

数据库基础二(数据库安装配置)

打开MySQL官网进行安装包的下载 https://www.mysql.com/ 接着找到适用于windows的版本 下载版本 直接点击下载即可 接下来对应的内容分别是: 1:安装所有 MySQL 数据库需要的产品; 2:仅使用 MySQL 数据库的服务器; 3&a…

HumanPro逼真角色皮肤面部动画Blender插件V1.1版

https://www.youtube.com/watch?vnmV_jzgpIPM 本插件是关于HumanPro逼真角色皮肤面部动画Blender插件V1.1版,大小:2.9 MB,支持Blender 4.0 - 4.3版软件,支持Win系统,语言:英语。RRCG分享 HumanPro 是一款…

基于javaweb的SSM+Maven幼儿园管理系统设计和实现(源码+文档+部署讲解)

技术范围:SpringBoot、Vue、SSM、HLMT、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、小程序、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:免费功能设计、开题报告、任务书、中期检查PPT、系统功能实现、代码编写、论文编写和辅导、论…

PyTorch内存优化的10种策略总结:在有限资源环境下高效训练模型

在大规模深度学习模型训练过程中,GPU内存容量往往成为制约因素,尤其是在训练大型语言模型(LLM)和视觉Transformer等现代架构时。由于大多数研究者和开发者无法使用配备海量GPU内存的高端计算集群,因此掌握有效的内存优化技术变得尤为关键。本…

【湖北省计算机信息系统集成协会主办,多高校支持 | ACM出版,EI检索,往届已见刊检索】第二届边缘计算与并行、分布式计算国际学术会议(ECPDC 2025)

第二届边缘计算与并行、分布式计算国际学术会议(ECPDC 2025)将于2025年4月11日至13日在中国武汉盛大召开。本次会议旨在为边缘计算、并行计算及分布式计算领域的研究人员、学者和行业专家提供一个高水平的学术交流平台。 随着物联网、云计算和大数据技术…

硬件工程师入门教程

1.欧姆定律 测电压并联使用万用表测电流串联使用万用表,红入黑出 2.电阻的阻值识别 直插电阻 贴片电阻 3.电阻的功率 4.电阻的限流作用 限流电阻阻值的计算 单位换算关系 5.电阻的分流功能 6.电阻的分压功能 7.电容 电容简单来说是两块不连通的导体加上中间的绝…

性能测试监控工具jmeter+grafana

1、什么是性能测试监控体系? 为什么要有监控体系? 原因: 1、项目-日益复杂(内部除了代码外,还有中间件,数据库) 2、一个系统,背后可能有多个软/硬件组合支撑,影响性能的因…

DeepSeek如何快速开发PDF转Word软件

一、引言 如今,在线工具的普及让PDF转Word成为了一个常见需求,常见的PDF转Word工具有收费的WPS,免费的有PDFGear(详见:PDFGear:一款免费的PDF编辑、格式转化软件-CSDN博客),以及在线工具SmallP…

目标检测——数据处理

1. Mosaic 数据增强 Mosaic 数据增强步骤: (1). 选择四个图像: 从数据集中随机选择四张图像。这四张图像是用来组合成一个新图像的基础。 (2) 确定拼接位置: 设计一个新的画布(输入size的2倍),在指定范围内找出一个随机点(如…

基于springboot+vue的拖恒ERP-物资管理

开发语言:Java框架:springbootJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:…

spring结合mybatis多租户实现单库分表

实现单库分表-水平拆分 思路:student表数据量大,所以将其进行分表处理。一共有三个分表,分别是student0,student1,student2,在新增数据的时候,根据请求头中的meta-tenant参数决定数据存在哪张表…

YoloV8改进策略:Block改进|CBlock,Transformer式的卷积结构|即插即用

摘要 论文标题: SparseViT: Nonsemantics-Centered, Parameter-Efficient Image Manipulation Localization through Spare-Coding Transformer 论文链接: https://arxiv.org/pdf/2412.14598 官方GitHub: https://github.com/scu-zjz/SparseViT 这段代码出自SparseViT ,代码如…