DAY33 简单神经网络

news2025/5/31 20:03:06

你需要自行了解下MLP的概念。

你需要知道

  1. 梯度下降的思想
  2. 激活函数的作用
  3. 损失函数的作用
  4. 优化器
  5. 神经网络的概念

神经网络由于内部比较灵活,所以封装的比较浅,可以对模型做非常多的改进,而不像机器学习三行代码固定。

1. 神经网络的概念 (The concept of neural networks)

您可以把神经网络想象成一个由许多相互连接的“神经元”(neurons)组成的系统,它模仿了人类大脑处理信息的方式。这些神经元组织成层(layers):

  • 输入层 (Input Layer): 接收原始数据,比如一张图片的像素值,或者鸢尾花数据集中的花瓣、花萼的长度和宽度。
  • 隐藏层 (Hidden Layers): 在输入层和输出层之间,负责进行大部分的计算和特征提取。一个神经网络可以有多个隐藏层。层数越多,网络通常能学习更复杂的模式。
  • 输出层 (Output Layer): 输出最终结果,比如图片的分类(猫或狗),或者鸢尾花的种类。

每个连接都有一个相关的“权重”(weight),这个权重决定了前一个神经元对后一个神经元的影响程度。训练神经网络的过程,就是不断调整这些权重,使得网络能够对输入数据做出正确的预测。

2. 梯度下降的思想 (The idea of gradient descent)

梯度下降是一种优化算法,用于寻找函数(比如损失函数,后面会讲到)的最小值。想象一下你在一座山上,想要走到山谷的最低点,但周围有大雾,你看不清路。梯度下降就像你每走一步都选择当前位置最陡峭的下坡方向,这样就能最快到达谷底。

在神经网络中,“山”就是损失函数,我们要找的“谷底”就是损失函数的最小值点,对应的神经网络参数(主要是权重)就是最优的参数。梯度下降通过计算损失函数对于每个参数的“梯度”(可以理解为斜率或变化率),然后沿着梯度的反方向更新参数,从而逐渐减小损失。学习率(learning rate)是梯度下降中的一个重要超参数,它控制了每一步更新参数的幅度。

3. 损失函数的作用 (The role of loss functions)

损失函数(Loss Function),也叫代价函数(Cost Function),用来衡量神经网络模型预测结果与真实标签之间的差异。它的输出值(损失值)越大,表示模型的预测越不准确;损失值越小,表示模型预测越准确。

训练神经网络的目标就是最小化损失函数。通过梯度下降等优化算法,我们不断调整网络的权重,使得损失值越来越小。

常见的损失函数有:

  • 均方误差 (Mean Squared Error, MSE): 常用于回归问题(预测连续值)。
  • 交叉熵损失 (Cross-Entropy Loss): 常用于分类问题(预测离散类别)。您在代码中看到的 nn.CrossEntropyLoss() 就是这个。

4. 激活函数的作用 (The role of activation functions)

激活函数(Activation Function)被应用于神经网络中每个神经元的输出。它们的主要作用是给神经网络引入非线性(non-linearity)。

如果不用激活函数,或者只用线性激活函数(比如f(x)=x),那么无论神经网络有多少层,它本质上都只是一个线性模型,无法学习复杂的数据模式。而非线性激活函数使得神经网络能够拟合各种复杂的非线性关系。

常见的激活函数有:

  • Sigmoid: 将输出压缩到0和1之间,常用于二分类问题的输出层。
  • Tanh (双曲正切): 将输出压缩到-1和1之间。
  • ReLU (Rectified Linear Unit): 计算公式是 f(x) = max(0, x)。它非常简单高效,是目前最常用的激活函数之一。您在代码中看到的 nn.ReLU() 就是这个。
  • Softmax: 常用于多分类问题的输出层,它能将输出层的原始分数转换成概率分布,使得所有输出类别的概率和为1。

5. 优化器 (Optimizers)

优化器(Optimizer)是实现梯度下降思想的具体算法。它根据损失函数计算出的梯度来更新神经网络的权重,目标是最小化损失函数。

除了基本的随机梯度下降(Stochastic Gradient Descent, SGD,您代码中用的 optim.SGD),还有许多更高级的优化器,它们在SGD的基础上做了一些改进,试图更快、更稳定地找到损失函数的最小值。

常见的优化器有:

  • SGD (Stochastic Gradient Descent): 基本的梯度下降。
  • Momentum: 在SGD的基础上引入了动量,可以加速收敛并减少震荡。
  • AdaGrad (Adaptive Gradient): 对不同参数使用不同的学习率。
  • RMSprop (Root Mean Square Propagation): 也是自适应学习率的一种。
  • Adam (Adaptive Moment Estimation): 结合了Momentum和RMSprop的优点,是目前非常流行和常用的优化器之一。您代码中注释掉的 optim.Adam 就是这个。

简单来说,神经网络通过损失函数来衡量预测的好坏,通过梯度下降的思想,在优化器的帮助下,不断调整神经元之间的连接权重,而激活函数则赋予了网络学习复杂模式的能力。

数据的准备

# 仍然用4特征,3分类的鸢尾花数据集作为我们今天的数据集
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import numpy as np

# 加载鸢尾花数据集
iris = load_iris()
X = iris.data  # 特征数据
y = iris.target  # 标签数据
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 打印下尺寸
print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)

归一化数据

# 归一化数据,神经网络对于输入数据的尺寸敏感,归一化是最常见的处理方式
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test) #确保训练集和测试集是相同的缩放

改变成张量

# 将数据转换为 PyTorch 张量,因为 PyTorch 使用张量进行训练
# y_train和y_test是整数,所以需要转化为long类型,如果是float32,会输出1.0 0.0
X_train = torch.FloatTensor(X_train)
y_train = torch.LongTensor(y_train)
X_test = torch.FloatTensor(X_test)
y_test = torch.LongTensor(y_test)

模型架构定义

定义一个简单的全连接神经网络模型,包含一个输入层、一个隐藏层和一个输出层。

定义层数+定义前向传播顺序

class MLP(nn.Module): # 定义一个多层感知机(MLP)模型,继承父类nn.Module
    def __init__(self): # 初始化函数
        super(MLP, self).__init__() # 调用父类的初始化函数
 # 前三行是八股文,后面的是自定义的

        self.fc1 = nn.Linear(4, 10)  # 输入层到隐藏层
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 3)  # 隐藏层到输出层
# 输出层不需要激活函数,因为后面会用到交叉熵函数cross_entropy,交叉熵函数内部有softmax函数,会把输出转化为概率

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

# 实例化模型
model = MLP()

或者

    # def forward(self,x): #前向传播
    #     x=torch.relu(self.fc1(x)) #激活函数
    #     x=self.fc2(x) #输出层不需要激活函数,因为后面会用到交叉熵函数cross_entropy
    #     return x

定义损失函数和优化器

# 分类问题使用交叉熵损失函数
criterion = nn.CrossEntropyLoss()

# 使用随机梯度下降优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)

# # 使用自适应学习率的化器
# optimizer = optim.Adam(model.parameters(), lr=0.001)

开始循环训练

实际上在训练的时候,可以同时观察每个epoch训练完后测试集的表现:测试集的loss和准确度

# 训练模型
num_epochs = 20000 # 训练的轮数

# 用于存储每个 epoch 的损失值
losses = []

for epoch in range(num_epochs): # range是从0开始,所以epoch是从0开始
    # 前向传播
    outputs = model.forward(X_train)   # 显式调用forward函数
    # outputs = model(X_train)  # 常见写法隐式调用forward函数,其实是用了model类的__call__方法
    loss = criterion(outputs, y_train) # output是模型预测值,y_train是真实标签

    # 反向传播和优化
    optimizer.zero_grad() #梯度清零,因为PyTorch会累积梯度,所以每次迭代需要清零,梯度累计是那种小的bitchsize模拟大的bitchsize
    loss.backward() # 反向传播计算梯度
    optimizer.step() # 更新参数

    # 记录损失值
    losses.append(loss.item())

    # 打印训练信息
    if (epoch + 1) % 100 == 0: # range是从0开始,所以epoch+1是从当前epoch开始,每100个epoch打印一次
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

如果你重新运行上面这段训练循环,模型参数、优化器状态和梯度会继续保留,导致训练结果叠加,模型参数和优化器状态(如动量、学习率等)不会被重置。这会导致训练从之前的状态继续,而不是从头开始
不会重置

import matplotlib.pyplot as plt
# 可视化损失曲线
plt.plot(range(num_epochs), losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.show()

在这里插入图片描述
浙大疏锦行-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2391901.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OBOO鸥柏丨2025年鸿蒙生态+国产操作系统触摸屏查询一体机核心股

在信创产业蓬勃发展的当下,OBOO鸥柏积极响应纯国产化号召,推出基于华为鸿蒙HarmonyOS操作系统的触摸屏查询一体机及室内外场景广告液晶显示屏一体机上市,OBOO鸥柏品牌旗下显示产品均采用国产芯片,接入终端控制端需支持安卓Windows…

【观成科技】Ymir勒索软件组织窃密木马RustyStealer加密通信分析

1.概述 Ymir勒索软件首次发现于2024年7月,采用ChaCha20加密算法对受害者文件进行加密,加密文件后缀为.6C5oy2dVr6。在攻击过程中,Ymir勒索组织利用名为RustyStealer的窃密木马获取受害企业的账号凭证,为后续横向移动和权限提升奠…

Vuer开源程序 是一个轻量级的可视化工具包,用于与动态 3D 和机器人数据进行交互。它支持 VR 和 AR,可以在移动设备上运行。

​一、软件介绍 文末提供程序和源码下载 Vuer开源程序 是一个轻量级的可视化工具包,用于与动态 3D 和机器人数据进行交互。它支持 VR 和 AR,可以在移动设备上运行。 二、Our features include: 我们的功能包括: light-weight and performa…

短视频一键搬运 v1.7.1|短视频无水印下载 一键去重

短视频一键搬运是一款全自动智能处理软件,专为短视频创作者设计。它自带去水印、改MD5码、视频去重、视频编辑等功能,能够高效处理大量视频,解放双手并降低成本。该软件支持从多个短视频平台无缝提取视频并去除水印,同时检测敏感词…

海上石油钻井平台人员安全管控解决方案

一、行业挑战与需求分析 海上钻井平台面临复杂环境风险(如易燃易爆、金属干扰、极端气象)和人员管理难题(如定位模糊、应急响应延迟)。传统RFID或蓝牙定位技术存在精度不足(1-5米)、抗干扰能力差等问题&am…

TEASER-plusplu Windows Mingw编译

编译记录: 1.下载该库 v2.0 链接1:https://github.com/MIT-SPARK/TEASER-plusplus 连接2:https://github.com/MIT-SPARK/TEASER-plusplus/releases 2.下载 googletest 链接:https://github.com/google/googletest/releases?page2…

tryhackme——Data Exfiltration

文章目录 一、网络拓扑二、数据泄露分类2.1 传统数据泄露2.2 C2通信2.3 隧道 三、隧道3.1 Exfiltration using TCP socket3.2 Exfiltration using SSH3.3 Exfiltrate using HTTP(S)HTTP隧道 3.4 Exfiltration using ICMP3.4.1 ICMP数据包结构3.4.2 MSF实现ICMP数据泄露3.4.3 IC…

阿里云服务器采用crontab定时任务使acme.sh全自动化申请续签免费SSL证书,并部署在Linux宝塔网站和雷池WAF

阿里云服务器安装Linux宝塔面板用于部署网站,又安装了雷池WAF用于防护网站,网站访问正常。可以参考文章:Linux服务器安装Linux宝塔面板并部署wordpress网站以及雷池WAF 本文介绍使用 acme.sh 通过 DNS API 全自动申请和续签免费Let’s Encry…

【华为鸿蒙电脑】首款鸿蒙电脑发布:MateBook Fold 非凡大师 MateBook Pro,擎云星河计划启动

文章目录 前言一、HUAWEI MateBook Fold 非凡大师(一)非凡设计(二)非凡显示(三)非凡科技(四)非凡系统(五)非凡体验 二、HUAWEI MateBook Pro三、预热&#xf…

SpringBoot Controller接收参数方式, @RequestMapping

一. 通过原始的HttpServletRequest对象获取请求参数 二. 通过Spring提供的RequestParam注解,将请求参数绑定给方法参数 三. 如果请求参数名与形参变量名相同,直接定义方法形参即可接收。(省略RequestParam) 四. JSON格式的请求参数(POST、PUT) 主要在PO…

端午节互动网站

端午节互动网站 项目介绍 这是一个基于 Vue 3 Vite 开发的端午节主题互动网站,旨在通过有趣的交互方式展示中国传统端午节文化。网站包含三个主要功能模块:端午节介绍、互动包粽子游戏和龙舟竞赛游戏。 预览网站:https://duanwujiekuaile…

react-color-palette源码解析

项目中用到了react-color-palette组件,以前对第三方组件都是不求甚解,这次想了解一下其实现细节。 简介 react-color-palette 是一个用于创建颜色调色板的 React 组件。它提供了一个简单易用的接口,让开发者可以轻松地创建和管理颜色调色板。…

在 Ubuntu 上安装 NVM (Node Version Manager) 的步骤

NVM (Node Version Manager) 是一个用于管理多个 Node.js 版本的工具,它允许您在同一台设备上安装、切换和管理不同版本的 Node.js。以下是在 Ubuntu 上安装 NVM 的详细步骤: 安装前准备 可先在windows上安装ubuntu 参考链接:https://blog.…

重温经典算法——插入排序

版权声明 本文原创作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhl 基本原理 插入排序是一种基于元素逐步插入的简单排序算法,其核心思想是将待排序序列分为已排序和未排序两部分,每次从未排序部分取出第一个元素&…

塔能科技:为多行业工厂量身定制精准节能方案

在当今追求可持续发展的时代,工厂能耗精准节能成为众多企业关注的焦点。塔能科技凭借先进的技术和丰富的经验,服务于广泛的行业客户,其中55.5%来自世界500强和上市公司。针对不同行业工厂的特点和需求,塔能提供了一系列行之有效的…

【实证分析】上市公司全要素生产率+5种测算方式(1999-2024年)

上市公司的全要素生产率(TFP)衡量企业在资本、劳动及中间投入之外,通过技术进步、管理效率和规模效应等因素提升产出的能力。与单纯的劳动生产率或资本生产率不同,TFP综合反映了企业创新能力、资源配置效率和组织优化水平&#xf…

弥散制氧机工作机制:高原低氧环境的氧浓度重构技术

弥散制氧机通过空气分离与智能扩散技术,将氧气均匀分布于封闭或半封闭空间,实现环境氧浓度的主动调控。其核心在于 “分子筛吸附动态均布智能反馈” 的协同作用机制,为高原、矿井、医疗等场景提供系统性氧环境解决方案。 一、空气分离&#x…

[Python] 避免 PyPDF2 写入 PDF 出现黑框问题:基于语言自动匹配系统字体的解决方案

在使用 Python 操作 PDF 文件时,尤其是在处理中文、日语等非拉丁字符语言时,常常会遇到一个令人头疼的问题——文字变成“黑框”或“方块”,这通常是由于缺少合适的字体支持所致。本文将介绍一种自动选择系统字体的方式,结合 PyPDF2 模块解决此类问题。 一、问题背景:黑框…

《基于Keepalived+LVS+Web+NFS的高可用集群搭建》

目 录 1 项目概述 1.1 项目背景 1.2 项目功能 2 项目的部署 2.1 部署环境介绍 2.2 项目的拓扑结构 2.3 项目环境调试 2.4 项目的部署 2.5 项目功能的验证 2.6 项目对应服务使用的日志 3 项目的注意事项 3.1 常见问题与解决方案 3.2 项目适用背…

时间序列预测算法中的预测概率化笔记

文章目录 1 预测概率化的前情提要2 预测概率化的代码示例3 预测概率化在实际商业应用场景探索3.1 智能库存与供应链优化 1 预测概率化的前情提要 笔者看到【行业SOTA,京东首个自研十亿级时序大模型揭秘】提到: 预测概率化组件:由于大部分纯时…