神经网络-万能近似定理的探索

news2025/9/16 18:36:48

神经网络-万能近似定理的探索

对于这个实验博主想说,其实真的很有必要去好好做一下,很重要的一个实验。

1.理论介绍

万能近似定理: ⼀个前馈神经⽹络如果具有线性层和⾄少⼀层具有 “挤压” 性质的激活函数(如 sigmoid 等),给定⽹络⾜够数量的隐藏单元,它可以以任意精度来近似任何从⼀个有限维空间到另⼀个有限维空间的 borel 可测函数。

在这里插入图片描述
在这里插入图片描述

我们可以通过两个 sigmoid 函数 (y = sigmoid(w⊤x + b)) ⽣成⼀个 tower,如图:
在这里插入图片描述
我们构造多个这样的 tower 近似任意函数:
在这里插入图片描述

2.数据集介绍

使用PyTorch库来执行数值计算,首先通过 torch.linspace 函数创建了一个从0到3的等差数列,其元素数量由变量 sample_num 决定。接着定义了一个多项式函数 function,它接受一个参数 x 并返回 x + x**3 + 3 的值。最后,代码通过将 x_data 作为输入应用 function 函数,计算并存储了每个数据点对应的函数值到 y_real 变量中,从而实现了对一系列数据点的生成。


sample_num=300
torch.manual_seed(10)

x_data=torch.linspace(0,3,sample_num)


def function(x):
  return x+x**3+3



y_real=function(x_data)

3.实验过程

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
所有代码如下:

#coding=gbk

import torch
from torch.autograd import Variable
from torch.utils import data
import matplotlib.pyplot as plt

neuron_num=200
batch=50
learn_rating=0.005
epoch=1000


#print(x_data)

sample_num=50
torch.manual_seed(10)

x_data=torch.linspace(0,3,sample_num)


def function(x):
  return x+x**3+3



y_real=function(x_data)
#print(y_real)

wz=torch.rand(neuron_num)
#print(wz)
w=torch.normal(0,1,size=(neuron_num,))
b=torch.normal(0,1,size=(neuron_num,))
w2=torch.normal(0,1,size=(neuron_num,))
b2=torch.normal(0,1,size=(1,))
#print(w,b)


def sampling(sample_num):
   
    #print(data_size)
    index_sequense=torch.randperm(sample_num)
    return index_sequense
def activation_function(x):
    return 1/(1+torch.sigmoid(-x))


def activation_function2(x):
   
   # print(x.size())
    a=torch.rand(x.size())
    for i in range(x.size(0)):
        if x[i]<=0:
            a[i]=0.0
        else:
            a[i]=x[i]
    return a
def hidden_layer(w,b,x):
   # print("x:",x)
    return activation_function2(w*x+b)

def  fully_connected_layer(w2,b2,x):
    return torch.dot(w2,x)+b2

def net(w,b,w2,b2,x):
    #print("w",w)
    #print("b",b)
    o1=hidden_layer(w,b,x)
    #print("o1",o1)
    #print("w2",w2)
    #print("b2",b2)
    o2=fully_connected_layer(w2,b2,o1)
 #   print("o2",o2)

    return o2

def get_grad(w,b,w2,b2,x,y_predict,y_real):
  

    o2=hidden_layer(w,b,x)

    l_of_w2=-(y_real-y_predict)*o2
    l_of_b2=-(y_real-y_predict)

    a=torch.rand(o2.size())
    for i in range(o2.size(0)):
        if o2[i]<=0:
            a[i]=0.0
        else:
            a[i]=1
    l_of_w=-(y_real-y_predict)*torch.dot(w2,a)*x
    l_of_b=-(y_real-y_predict)*torch.dot(w2,a)

    return l_of_w2,l_of_b2,l_of_w,l_of_b

def loss_function(y_predict,y_real):
    #print(y_predict,y_real)
    return torch.pow(y_predict-y_real,2)
loss_list=[]

torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train():
    global w,w2,b,b2
    index=0
    index_sequense=sampling(sample_num)
    for i in range(epoch):
        
        
   
        W_g=torch.zeros(neuron_num)
        b_g=torch.zeros(neuron_num)
        W2_g=torch.zeros(neuron_num)
        b2_g=torch.zeros(1)
        loss=torch.tensor([0.0])
           
        for k in range(batch):
              try:
                   # print("x",x_data[index],index)
                   # print("w",w)
                        y_predict=net(w,b,w2,b2,x_data[index_sequense[index]])

                        get_grad(w,b,w2,b2,x_data[index_sequense[index]],y_predict,y_real[index_sequense[index]])
                      
                        l_of_w2,l_of_b2,l_of_w,l_of_b= get_grad(w,b,w2,b2,x_data[index_sequense[index]],y_predict,y_real[index_sequense[index]])
                        W_g=W_g+l_of_w
                        b_g=b_g+l_of_b
                        b2_g=b2_g+l_of_b2
                        W2_g=W2_g+l_of_w2
                        loss=loss+loss_function(y_predict,y_real[index_sequense[index]])
                        index=index+1
              except:
                        index=0
                      #  index_sequense=sampling(sample_num)
        print("****************************loss is :",loss/batch)
        loss_list.append(loss/batch)
     
        W_g=W_g/batch
        b_g=b_g/batch
        b2_g=b2_g/batch
        W2_g=W2_g/batch
        w=w-learn_rating*W_g
        b=b-learn_rating*b_g
        w2=w2-learn_rating*W2_g
        b2=b2-learn_rating*b2_g
y_predict=net(w,b,w2,b2,x_data[0])
train()
y_predict=[]
for i in range(sample_num):
  y_predict.append(net(w,b,w2,b2,x_data[i]))

epoch_list=list(range(epoch))
plt.plot(epoch_list,loss_list,label='SGD')
plt.title("loss")
plt.legend()

plt.show()

plt.plot(x_data,y_real,label='real')
plt.plot(x_data,y_predict,label='predict')
plt.title(" Universal Theorem of Neural Networks")
plt.legend()
plt.show()
print(w,b,w2,b2)

4.总结

(1)相比较之下,tanh函数比sigmoid函数效果要好很多,但是如果,神经元数量足够,训练足够充分,两者效果会差不多。在训练不够充分,神经元不够多,学习率等条件限制下,tanh函数表现更好。
(2)神经元数量越多的情况下,模型性能在不断提升,但是学习越来越慢,如果采用随机梯度下降或者小批量梯度下降,最好增加神经元,否则,学习最终性能会受到很大限制。
(3)对于不同数量的神经元,最优学习率是在变化的。
(4)如果采用随机梯度下降或者小批量梯度下降,样本不能反应总体规律,那么网络要足够大,才能进行规律的总结,否则,最好采用批量梯度下降去处理问题。
(5)不同的超参数设置,最终学习的性能会有瓶颈,需要进行学习策略,超参数的调节,才能不断提高模型效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1823823.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于Matlab的细胞计数图像处理系统(GUI界面有报告) 【含Matlab源码 MX_003期】

简介&#xff1a; 本文旨在解决生物血细胞数目统计的挑战&#xff0c;提出了基于图像处理的综合方案。通过MATLAB平台&#xff0c;我们设计并实现了一套完整的细胞图像处理与分析流程。在预处理阶段&#xff0c;采用图像增强和阈值分割等方法&#xff0c;有效地提高了细胞图像的…

Swift开发——输出格式化字符

Swift语言是开发iOS和macOS等Apple计算机和移动设备系统应用程序的官方语言。Swift语言是一种类型安全的语言,语法优美自然,其程序从main.swift文件开始执行,程序代码按先后顺序执行,同一个工程的程序文件中的类和函数直接被main.swift文件调用,除了main.swift文件外,工程…

【数据挖掘-思考】分类和聚类

将芝麻和花生分开&#xff0c;是一个分类问题还是聚类问题? 显而易见的&#xff0c;在日常生活中&#xff0c;这是一个分类问题&#xff0c;在数据挖掘领域中&#xff0c;是否也是这样呢&#xff1f; 通义千问的回答&#xff1a; 在数据挖掘中&#xff0c;将芝麻和花生分开可以…

【C语言】14. qsort 的底层与模拟实现

一、回调函数 回调函数就是⼀个通过函数指针调用的函数。 把函数的指针&#xff08;地址&#xff09;作为参数传递给另⼀个函数&#xff0c;当这个指针被用来调用其所指向的函数时&#xff0c;被调用的函数就是回调函数。回调函数不是由该函数的实现方直接调用&#xff0c;而是…

思维+暴力,CF992D - Nastya and a Game

一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 992D - Nastya and a Game 二、解题报告 1、思路分析 这个题题目很吓人 因为看起来前缀和根本存不下&#xff0c;似乎没法算 这也提示我们似乎只需在小范围内枚举求解即可 题目的P / K SUM也保证了我们…

SQL中的UPDATE语句:别让你的数据“离家出走”

sql的update操作正式环境用的很少&#xff0c;但是在测试环境还是用的挺多的。 想象一下&#xff0c;你正在管理一个学校的数据库&#xff0c;其中有一个students表&#xff0c;记录着每个学生的信息。有一天&#xff0c;你接到通知说某个学生的年龄或成绩需要更新。这时&…

54.Python-web框架-Django-免费模板django-datta-able

1.Datta Able Django介绍 Detta Able Djiango是什么 Datta Able Django 是一个由AppSeed提供的开源Django管理面板&#xff0c;基于现代设计&#xff0c;为开发者提供了一流的功能和优雅的界面。它源自CodedThemes的高风格化Bootstrap 4模板——Datta Able Bootstrap Lite&…

云电脑有多好用?适合哪些人使用?

云电脑作为一种新型的计算模式&#xff0c;其应用场景广泛且多样&#xff0c;适合各类人群使用。云电脑适合什么人群使用&#xff1f;云电脑有哪些应用场景&#xff1f;有什么好的云电脑推荐&#xff1f;以下本文将详细探讨云电脑的主要应用场景及其适用人群的相关内容&#xf…

【YOLOv5/v7改进系列】改进池化层为RT-DETR的AIFI

一、导言 Real-Time DEtection TRansformer&#xff08;RT-DETR&#xff09;&#xff0c;是一种实时端到端目标检测器&#xff0c;克服了Non-Maximum Suppression&#xff08;NMS&#xff09;对速度和准确性的影响。通过设计高效的混合编码器和不确定性最小化查询选择&#xf…

缩窄route范围来提速本地打包的尝试

目录 为什么要缩窄route范围缩窄route的方式意外触发的重复构建重复构建的原因解决方案 为什么要缩窄route范围 对于一些大单页&#xff0c;单个router-view中可能包含上百个页面。但是开发的时候其实并不需要那么多调试那么多页面。 因此&#xff0c;为了节省不必要的打包和热…

【SpringBoot + Vue 尚庭公寓实战】地区信息管理接口实现(九)

【SpringBoot Vue 尚庭公寓实战】地区信息管理接口实现&#xff08;九&#xff09; 文章目录 【SpringBoot Vue 尚庭公寓实战】地区信息管理接口实现&#xff08;九&#xff09;1、业务说明2、数据逻辑模型3、接口实现3.1、查询省份信息列表3.2、根据省份ID查询城市信息列表3…

Python对象序列化库之dill使用详解

概要 在 Python 编程中,序列化(Serialization)和反序列化(Deserialization)是处理对象持久化和数据传输的常见任务。Python 提供了内置的 pickle 模块用于对象序列化,但它在处理复杂对象(如带有 lambda 函数、生成器和闭包的对象)时存在一定局限性。dill 库是 pickle …

本地运行大语言模型(LLMs)

用例 像PrivateGPT、llama.cpp、Ollama、GPT4All、llamafile 等项目的流行度凸显了本地&#xff08;在您自己的设备上&#xff09;运行大型语言模型&#xff08;LLMs&#xff09;的需求。 这至少有两个重要的好处&#xff1a; 1.隐私&#xff1a;您的数据不会发送给第三方&a…

【Java面试】十八、并发篇(中)

文章目录 1、什么是AQS2、ReentrantLock的实现原理2.1 原理2.2 其他补充点 3、synchronized和Lock有什么区别3.1 区别3.2 Demo代码3.3 signal方法的底层实现 4、死锁的产生与排查4.1 死锁产生的条件是什么4.2 死锁的排查 1、什么是AQS AQS&#xff0c;抽象队列同步器&#xff…

汇编:内联汇编和混合编程

C/C内联汇编 C/C 内联汇编&#xff08;Inline Assembly&#xff09;是一种在C或C代码中嵌入汇编语言指令的方法&#xff0c;以便在不离开C/C环境的情况下利用汇编语言的优势进行性能优化或执行特定的硬件操作。以下是一些详细的说明和示例&#xff0c;展示如何在C和C代码中使用…

springboot集成swagger、knife4j

1. 集成swagger2 1.1 引入依赖 <!-- https://mvnrepository.com/artifact/io.springfox/springfox-swagger2 --><dependency><groupId>io.springfox</groupId><artifactId>springfox-swagger2</artifactId><version>2.9.2</vers…

Redis和Docker

Redis 和 Docker 是两种不同的技术&#xff0c;它们各自解决不同的问题&#xff0c;但有时会一起使用以提供更高效和灵活的解决方案。 Redis 是一个开源的内存数据结构存储系统&#xff0c;可以用作数据库、缓存和消息代理。它设计为解决MySQL等关系型数据库在处理大量读写访问…

Python 植物大战僵尸游戏【含Python源码 MX_012期】

简介&#xff1a; "植物大战僵尸"&#xff08;Plants vs. Zombies&#xff09;是一款由PopCap Games开发的流行塔防游戏&#xff0c;最初于2009年发布。游戏的概念是在僵尸入侵的情境下&#xff0c;玩家通过种植不同种类的植物来保护他们的房屋免受僵尸的侵袭。在游…

重生之 SpringBoot3 入门保姆级学习(19、场景整合 CentOS7 Docker 的安装)

重生之 SpringBoot3 入门保姆级学习&#xff08;19、场景整合 CentOS7 Docker 的安装&#xff09; 6、场景整合6.1 Docker 6、场景整合 6.1 Docker 官网 https://docs.docker.com/查看自己的 CentOS配置 cat /etc/os-releaseStep 1: 安装必要的一些系统工具 sudo yum insta…

react 0至1 案例

/*** 导航 Tab 的渲染和操作** 1. 渲染导航 Tab 和高亮* 2. 评论列表排序* 最热 > 喜欢数量降序* 最新 > 创建时间降序* 1.点击记录当前type* 2.通过记录type和当前list中的type 匹配*/ import ./App.scss import avatar from ./images/bozai.png import {useState} …