逆天了!用Numpy开发深度学习框架,透视神经网络训练过程

news2025/7/10 20:55:49

哈喽,大家好。

今天给大家分享一个非常牛逼的开源项目,用Numpy开发了一个深度学习框架,语法与 Pytorch 基本一致。

今天以一个简单的卷积神经网络为例,分析神经网络训练过程中,涉及的前向传播、反向传播、参数优化等核心步骤的源码。

使用的数据集和代码已经打包好,文末有获取方式。

1. 准备工作

先准备好数据和代码。

1.1 搭建网络

首先,下载框架源码,地址:https://github.com/duma-repo/PyDyNet

git clone https://github.com/duma-repo/PyDyNet.git

搭建LeNet卷积神经网络,训练三分类模型。

PyDyNet目录直接创建代码文件即可。

from pydynet import nn

class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        self.sigmoid = nn.Sigmoid()
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 3)

    def forward(self, x):
        x = self.conv1(x)
        x = self.sigmoid(x)
        x = self.avg_pool(x)
        
        x = self.conv2(x)
        x = self.sigmoid(x)
        x = self.avg_pool(x)
        
        x = x.reshape(x.shape[0], -1)
        
        x = self.fc1(x)
        x = self.sigmoid(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        x = self.fc3(x)
        
        return x

可以看到,网络的定义与Pytorch语法完全一样。

我提供的源代码里,提供了 summary 函数可以打印网络结构。

1.2 准备数据

训练数据使用Fanshion-MNIST数据集,它包含10个类别的图片,每个类别 6k 张。

为了加快训练,我只抽取了前3个类别,共1.8w张训练图片,做一个三分类模型。

1.3 模型训练

import pydynet
from pydynet import nn
from pydynet import optim

lr, num_epochs = 0.9, 10
optimizer = optim.SGD(net.parameters(),
                              lr=lr)
loss = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    net.train()
    for i, (X, y) in enumerate(train_iter):
        optimizer.zero_grad()
        y_hat = net(X)
        l = loss(y_hat, y)
        l.backward()
        optimizer.step()

        with pydynet.no_grad():
            metric.add(l.numpy() * X.shape[0],
                       accuracy(y_hat, y),
                       X.shape[0])

训练代码也跟Pytorch一样。

下面重点要做的就是深入模型训练的源码,来学习模型训练的原理。

2. train、no_grad和eval

模型开始训练前,会调用net.train

def train(self, mode: bool = True):
    set_grad_enabled(mode)
    self.set_module_state(mode)

可以看到,它会将grad(梯度)设置成True,之后创建的Tensor是可以带梯度的。Tensor带上梯度后,便会将其放入计算图中,等待求导计算梯度。

而下面的with no_grad(): 代码

class no_grad:
    def __enter__(self) -> None:
        self.prev = is_grad_enable()
        set_grad_enabled(False)

会将grad(梯度)设置成False,这样之后创建的Tensor不会放到计算图中,自然也不需要计算梯度,可以加快推理。

我们经常在Pytorch中看到net.eval()的用法,我们也顺便看一下它的源码。

def eval(self):
    return self.train(False)

可以看到,它直接调用train(False)来关闭梯度,效果与no_grad()类似。

所以,一般在训练前调用train打开梯度。训练后,调用eval关闭梯度,方便快速推理。

2. 前向传播

前向传播除了计算类别概率外,最最重要的一件事是按照前传顺序,将网络中的 tensor 组织成计算图,目的是为了反向传播时计算每个tensor的梯度。

tensor在神经网络中,不止用来存储数据,还用计算梯度、存储梯度。

以第一层卷积操作为例,来查看如何生成计算图。

def conv2d(x: tensor.Tensor,
           kernel: tensor.Tensor,
           padding: int = 0,
           stride: int = 1):
    '''二维卷积函数
    '''
    N, _, _, _ = x.shape
    out_channels, _, kernel_size, _ = kernel.shape
    pad_x = __pad2d(x, padding)
    col = __im2col2d(pad_x, kernel_size, stride)
    out_h, out_w = col.shape[-2:]
    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N * out_h * out_w, -1)
    col_filter = kernel.reshape(out_channels, -1).T
    out = col @ col_filter
    return out.reshape(N, out_h, out_w, -1).transpose(0, 3, 1, 2)

x是输入的图片,不需要记录梯度。kernel是卷积核的权重,需要计算梯度。

所以,pad_x = __pad2d(x, padding) 生成的新的tensor也是不带梯度的,因此也不需要加入计算图中。

kernel.reshape(out_channels, -1)产生的tensor则是需要计算梯度,也需要加入计算图中。

下面看看加入的过程:

def reshape(self, *new_shape):
    return reshape(self, new_shape)

class reshape(UnaryOperator):
    '''
    张量形状变换算子,在Tensor中进行重载

    Parameters
    ----------
    new_shape : tuple
        变换后的形状,用法同NumPy
    '''
    def __init__(self, x: Tensor, new_shape: tuple) -> None:
        self.new_shape = new_shape
        super().__init__(x)

    def forward(self, x: Tensor) -> np.ndarray:
        return x.data.reshape(self.new_shape)

    def grad_fn(self, x: Tensor, grad: np.ndarray) -> np.ndarray:
        return grad.reshape(x.shape)

reshape函数会返回一个reshape类对象,reshape类继承了UnaryOperator类,并在__init__函数中,调用了父类初始化函数。

class UnaryOperator(Tensor):
    def __init__(self, x: Tensor) -> None:
        if not isinstance(x, Tensor):
            x = Tensor(x)
        self.device = x.device
        super().__init__(
            data=self.forward(x),
            device=x.device,
            # 这里 requires_grad 为 True
            requires_grad=is_grad_enable() and x.requires_grad,
        )

UnaryOperator类继承了Tensor类,所以reshape对象也是一个tensor

UnaryOperator__init__函数中,调用Tensor的初始化函数,并且传入的requires_grad参数是True,代表需要计算梯度。

requires_grad的计算代码为is_grad_enable() and x.requires_gradis_grad_enable()已经被train设置为True,而x是卷积核,它的requires_grad也是True

class Tensor:
    def __init__(
        self,
        data: Any,
        dtype=None,
        device: Union[Device, int, str, None] = None,
        requires_grad: bool = False,
    ) -> None:
        if self.requires_grad:
            # 不需要求梯度的节点不出现在动态计算图中
            Graph.add_node(self)

最终在Tensor类的初始化方法中,调用Graph.add_node(self)将当前tensor加入到计算图中。

同理,下面使用requires_grad=Truetensor常见出来的新tensor都会放到计算图中。

经过一次卷积操作,计算图中会增加 6 个节点。

3. 反向传播

一次前向传播完成后,从计算图中最后一个节点开始,从后往前进行反向传播。

l = loss(y_hat, y)
l.backward()

经过前向网络一层层传播,最终传到了损失张量l

l为起点,从前向后传播,就可计算计算图中每个节点的梯度。

backward的核心代码如下:

def backward(self, retain_graph: bool = False):

    for node in Graph.node_list[y_id::-1]:
        grad = node.grad
        for last in [l for l in node.last if l.requires_grad]:
            add_grad = node.grad_fn(last, grad)
            
            last.grad += add_grad

Graph.node_list[y_id::-1]将计算图倒序排。

node是前向传播时放入计算图中的每个tensor

node.last 是生成当前tensor的直接父节点。

调用node.grad_fn计算梯度,并反向传给它的父节点。

grad_fn其实就是Tensor的求导公式,如:

class pow(BinaryOperator):
    '''
    幂运算算子,在Tensor类中进行重载

    See also
    --------
    add : 加法算子
    '''
    def grad_fn(self, node: Tensor, grad: np.ndarray) -> np.ndarray:
        if node is self.last[0]:
            return (self.data * self.last[1].data / node.data) * grad

return后的代码其实就是幂函数求导公式。

假设y=x^2x的导数为2x

4. 更新参数

反向传播计算梯度后,便可以调用优化器,更新模型参数。

l.backward()
optimizer.step()

本次训练我们用梯度下降SGD算法优化参数,更新过程如下:

def step(self):
    for i in range(len(self.params)):
        grad = self.params[i].grad + self.weight_decay * self.params[i].data
        self.v[i] *= self.momentum
        self.v[i] += self.lr * grad
        self.params[i].data -= self.v[i]
        if self.nesterov:
            self.params[i].data -= self.lr * grad

self.params是整个网络的权重,初始化SGD时传进去的。

step函数最核心的两行代码,self.v[i] += self.lr * grad 和 self.params[i].data -= self.v[i],用当前参数 - 学习速率 * 梯度更新当前参数

这是机器学习的基础内容了,我们应该很熟悉了。

一次模型训练的完整过程大致就串完了,大家可以设置打印语句,或者通过DEBUG的方式跟踪每一行代码的执行过程,这样可以更了解模型的训练过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/108886.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

制作圣诞帽其实特简单(附 Python 代码)

圣诞将至,虽然咱不过这洋节,但是热闹还是要凑一下的,相信已经有很多圣诞帽相关的周边在流传了,今天咱们就自己动手,给头像增加一个圣诞帽。 文章目录基础知识准备数字图像图像通道ROI和mask矩阵(Numpy&…

BaseAdapter实现的投票案例

BaseAdapter实现的投票案例 1.知识补充 android:descendantFocusability"blocksDescendants",关键是让谁先去获取焦点beforeDescendants:viewgroup会优先其子类控件而获取到焦点afterDescendants:viewgroup只有当其子类控件不需要获…

Flink Process Function

处理函数: ProcessFunction: 含有状态流的特性 处理函数面对的是数据流中的最基本元素: 数据事件 event, 状态 state, 时间 time 文章目录1.基本处理函数 ProcessFunction1.1 处理函数的功能和使用1.2 ProcessFunction 解析2.处理函数的分类2.1 按键分区处理函数 KeyedProces…

LaTex期刊模板下载与使用

1 LaTex期刊模板下载与使用 接上文介绍了LaTex的下载安装和基本语法使用规则。 上文地址:科研人快速入门LaTex到日常使用,下载安装配置,语法使用说明等 一般来说,LaTeX主要用在论文提交,书籍排版过程中,提…

Kubernetes:Pod

文章目录1、Pod 定义2、Pod 使用2.1、init 容器2.2、容器生命周期处理函数2.3、容器的探测2.3.1、探测机制2.3.2、探测结果2.3.3、探测类型startupProbereadinessProbelivenessProbe2.3.4、案例2.4、测试代码3、Pod 的部署3.1、Deployment3.2、DaemonSets3.3、静态 pod4、参考p…

我国金属包装行业企业数量下降 经济效益整体表现不佳 但亏损额减少

根据观研报告网发布的《中国金属包装市场发展趋势研究与未来投资预测报告(2022-2029年)》显示,金属包装是指采用金属薄板,针对不同用途制作的各种不同形式的薄壁包装容器,相较于其它包装,金属包装因为其材质特性,比一般…

3DEXPERIENCE平台2023新功能揭秘!Governance云端数据管理解决方案

3DEXPERIENCE平台更新版本已经与大家见面,今天众联亿诚与大家分享Governance新功能。 多年来,我们一直在寻找SOLIDWORKS数据管理的更优解决方案。但就是感觉很艰难,硬件投资是昂贵的,实施是资源密集型的,更重要的是&a…

【TypeScript】TS入门(一)

🐱个人主页:不叫猫先生 🙋‍♂️作者简介:前端领域新星创作者、华为云享专家、阿里云专家博主,专注于前端各领域技术,共同学习共同进步,一起加油呀! 💫系列专栏&#xff…

Hook原理

对于会Hook的人来说,Hook其实也就那么回事.对于没有Hook过的人来说,会感觉Hook很高大上(其实也没毛病). 那么今天我们就来探讨一些Hook的原理是什么. 我认为任何Hook都可以分为以下三步(简称WFH): 需要Hook的是什么,在哪里(后面简称Where). 寻找到Hook的地方.(后面简称Find)…

JavaScript基础(15)_数组

对象分为三种:内建对象、宿主对象、自定义对象。 内建对象 内建对象是指由ECMAScript事先提供的、不依赖于宿主环境的对象,这些对象在程序运行之前就已经存在,并可以直接在程序中任何地方任何时候拿来使用。常见的内建对象可以直接通过new调…

【JavaEE】Servlet

努力经营当下,直至未来明朗! 文章目录【Servlet】1.0Servlet概述写一个Servlet程序1. 创建项目2. 引入Servlet依赖3. 创建目录结构4. 编写代码5. 打包程序6. 部署程序7. 验证程序【Servlet 2.0】访问出错【小结】追求想要的一定很酷! 【Serv…

docker rootless安装

rootless 简介 rootless模式允许以非root用户身份运行Docker守护程序和容器,以减轻守护程序和容器运行时中的潜在漏洞。只要满足先决条件,即使在Docker守护程序安装期间,无根模式也不需要root特权。无根模式是Docker Engine v19.03中引入的一…

【俄罗斯方块】单机游戏-微信小程序项目开发入门

这是一个仿俄罗斯方块小游戏的微信小程序,只需要写一小段代码就实现出来了,有兴趣的同学完全可以自己动手开发,来看看实现过程是怎样的呢,边写边做,一起来回忆小时候玩过的经典俄罗斯方块游戏吧。 文章目录创建小程序页…

certbot生成证书,配置nginx,利用脚本自动续期

踩了大量坑,做下记录。以下适用于博主本人,但是未必会适用于所有人 单域名与泛域名证书生成 sudo certbot certonly --standalone --email 邮箱 -d 域名# 单域名certbot certonly --preferred-challenges dns --manual -d *.baidu.com(修改这里) --ser…

【檀越剑指大厂—Springboot】Springboot高阶

一.整体介绍 1.什么是 Springboot? Springboot 是一个全新的框架,简化 Spring 的初始搭建和开发过程,使用了特定的方式来进行配置,让开发人员不再需要定义样板化的配置。此框架不需要配置 xml,依赖于 maven 这样的构建系统。 …

嵌入式分享合集125

一、多层板PCB设计中电源平面相对地平面要进行内缩? 有一些人绘制的PCB,在GND层和电源层会进行一定程度的内缩设计,那么大家有没有想过为什么要内缩呢。 需要搞清楚这个问题,我们需要来先了解一个知识点,那就是“20H”…

matlab 功率谱分析

谱分析介绍 谱分析是一种用于研究函数的数学方法。在数学中,谱分析的基本概念是将函数分解成不同的频率成分,以便更好地理解其行为。这些频率成分可以表示为正弦或余弦函数的级数和,称为谱线。 谱分析常用于信号处理、音频信息处理和图像处…

Windows系统增强优化工具

计算机系统优化的作用很多,它可以清理WINDOWS临时文件夹中的临时文件,释放硬盘空间;可以清理注册表里的垃圾文件,减少系统错误的产生;它还能加快开机速度,阻止一些程序开机自动执行;还可以加快上…

数据也能开口说话?这次汇报,老板疯狂给我点赞

年底了,大家的工作汇报进行得怎么样了? 是不是少不了各种数据?饼图、柱形图、条形图、折线图、散点图有没有充斥在你的 PPT 中? 我们出版社的数据统计一般截止到 12 月中下旬,所以前两天,我已经做完了年终…

白话说Java虚拟机原理系列【第三章】:类加载器详解

文章目录jvm.dllBootstrapLoader:装载系统类ExtClassLoader:装载扩展类AppClassLoader:装载自定义类双亲委派模型类加载器加载类的方式类加载器特性类加载器加载字节码到JVM的过程自定义/第三方类加载器类加载器加载字节码到哪?Cl…