pytorch从零开始搭建神经网络

news2025/7/15 1:21:50

目录

基本流程

一、数据处理

二、模型搭建

三、定义代价函数&优化器

四、训练

附录

nn.Sequential

nn.Module

model.train() 和 model.eval() 

损失

图神经网络


基本流程

《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili

1. 数据预处理(Dataset、Dataloader)

2. 模型搭建(nn.Module)

3. 损失&优化(loss、optimizer)

4. 训练(forward、backward)

一、数据处理

对于数据处理,最为简单的⽅式就是将数据组织成为⼀个 。

但许多训练需要⽤到mini-batch,直 接组织成Tensor不便于我们操作。

pytorch为我们提供了DatasetDataloader两个类来方便的构建。

torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers)

 

二、模型搭建

搭建一个简易的神经网络

除了采用pytorch自动梯度的方法来搭建神经网络,还可以通过构建一个继承了torch.nn.Module的新类,来完成forward和backward的重写。

# 神经网络搭建
import torch
from torch.autograd import Varible 
batch_n = 100 
hidden_layer = 100 
input_data = 1000
output_data = 10 

class Model(torch.nn.Module):
	def __init__(self):
		super(Model,self).__init__()
		
	def forward(self,input,w1,w2):
		x = torch.mm(input,w1)
		x = torch.clamp(x,min = 0)
		x = torch.mm(x,w2)
		
 	def backward(self):
 		pass

model = Model()

#训练
x = Variable(torch.randn(batch_n,input_data))

一点一点地看:

import torch

dtype = torch.float
device = torch.device("cpu")

N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6

tensor 写一个粗糙版本(后面陆陆续续用pytorch提供的方法)

for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    if t % 100 == 99:
        print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

三、定义代价函数&优化器

Autograd

for t in range(500):
    y_pred = x.mm(w1).clamp(min=0).mm(w2)

    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())

    loss.backward()

    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        w1.grad.zero_()
        w2.grad.zero_()

对于需要计算导数的变量(w1和w2)创建时设定requires_grad=True,之后对于由它们参与计算的变量(例如loss),可以使用loss.backward()函数求出loss对所有requires_grad=True的变量的梯度,保存在w1.grad和w2.grad中。

在迭代w1和w2后,即使用完w1.grad和w2.grad后,使用zero_函数清空梯度。
 

nn

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
for t in range(500):
    y_pred = model(x)

    loss = loss_fn(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

    model.zero_grad()

    loss.backward()

    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

optim

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    y_pred = model(x)

    loss = loss_fn(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

    optimizer.zero_grad()

    loss.backward()

    optimizer.step()

四、训练

迭代进行训练以及测试,其中训练的函数train里就保存了进行梯度下降求解的方法

# 定义训练函数,需要
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    # 从数据加载器中读取batch(一次读取多少张,即批次数),X(图片数据),y(图片真实标签)。
    for batch, (X, y) in enumerate(dataloader):
        # 将数据存到显卡
        X, y = X.to(device), y.to(device)

        # 得到预测的结果pred
        pred = model(X)

        # 计算预测的误差
        # print(pred,y)
        loss = loss_fn(pred, y)

        # 反向传播,更新模型参数
        optimizer.zero_grad() #梯度清零
        loss.backward() #反向传播
        optimizer.step() #更新参数

        # 每训练10次,输出一次当前信息
        if batch % 10 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

设置为测试模型并设置不计算梯度,进行测试数据集的加载,判断预测值与实际标签是否一致,统一正确信息个数

# 将模型转为验证模式
model.eval()
# 测试时模型参数不用更新,所以no_gard()
with torch.no_grad():
    # 加载数据加载器,得到里面的X(图片数据)和y(真实标签)
    for X, y in dataloader:
        加载数据
        pred = model(X)#进行预测
        # 预测值pred和真实值y的对比
        test_loss += loss_fn(pred, y).item()
        # 统计预测正确的个数
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()#返回相应维度的最大值的索引
test_loss /= size
correct /= size
print(f"correct = {correct}, Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


 


附录


mark一下很有用的博客

pytorch代码编写入门 - 简书

推荐给大家!Pytorch编写代码基本步骤思想 - 知乎

用pytorch实现神经网络_徽先生的博客-CSDN博客_pytorch 神经网络


Dataset、DataLoader

① 创建一个 Dataset 对象
② 创建一个 DataLoader 对象
③ 循环这个 DataLoader 对象,将xx, xx加载到模型中进行训练

DataLoader详解_sereasuesue的博客-CSDN博客_dataloader

都会|可能会_深入浅出 Dataset 与 DataLoader

Pytorch加载自己的数据集(使用DataLoader读取Dataset)_l8947943的博客-CSDN博客_pytorch dataloader读取数据

可以直接调用的数据集

https://www.pianshen.com/article/9695297328/


nn.Sequential

pytorch教程之nn.Sequential类详解——使用Sequential类来自定义顺序连接模型_LoveMIss-Y的博客-CSDN博客_sequential类

nn.Module

torch.nn.Module是torch.nn.functional中方法的实例化

pytorch教程之nn.Module类详解——使用Module类来自定义模型_LoveMIss-Y的博客-CSDN博客_torch.nn.module

对应Sequential的三种包装方式,Module有三种写法


model.train() 和 model.eval() 

model.train()
for epoch in range(epoch):
    for train_batch in train_loader:
        ...
    zhibiao = test(epoch, test_loader, model)
        
def test(epoch, test_loader, model):
    model.eval()
    for test_batch in test_loader:
        ...
    return zhibiao

【Pytorch】model.train() 和 model.eval() 原理与用法_想变厉害的大白菜的博客-CSDN博客_pytorch train()

 

pytroch:model.train()、model.eval()的使用_像风一样自由的小周的博客-CSDN博客_model.train()放在程序的哪个位置

 

model = ...
dataset = ...
loss_fun = ...

# training
lr=0.001
model.train()
for x,y in dataset:
 model.zero_grad()
 p = model(x)
 l = loss_fun(p, y)
 l.backward()
 for p in model.parameters():
  p.data -= lr*p.grad
 
# evaluating
sum_loss = 0.0
model.eval()
with torch.no_grad():
 for x,y in dataset:
  p = model(x)
  l = loss_fun(p, y)
  sum_loss += l
print('total loss:', sum_loss)

https://www.jb51.net/article/211954.htm


损失

MAE:

import torch
from torch.autograd import Variable
x = Variable(torch.randn(100, 100))
y = Variable(torch.randn(100, 100))
loos_f = torch.nn.L1Loss()
loss = loos_f(x,y)

MSE:

import torch
from torch.autograd import Variable
x = Variable(torch.randn(100, 100))
y = Variable(torch.randn(100, 100))
loos_f = torch.nn.MSELoss()#定义
loss = loos_f(x, y)#调用

torch.nn中常用的损失函数及使用方法_加油上学人的博客-CSDN博客_nn损失函数


优化器

pytorch 优化器调参以及正确用法 - 简书


训练&测试

基于pytorch框架下的一个简单的train与test代码_黎明静悄悄啊的博客-CSDN博客


图神经网络

1. GCN、GAT

图神经网络及其Pytorch实现_jiangchao98的博客-CSDN博客_pytorch 图神经网络

2. 用DGL

PyTorch实现简单的图神经网络_梦家的博客-CSDN博客_pytorch图神经网络

一文看懂图神经网络GNN,及其在PyTorch框架下的实现(附原理+代码) - 知乎

图神经网络的不足

•扩展性差,因为训练时需要用到包含所有节点的邻接矩阵,是直推性的(transductive)

•局限于浅层,图神经网络只有两层

•不能作用于有向图

3. 用PyG

图神经网络框架-PyTorch Geometric(PyG)的使用__Old_Summer的博客-CSDN博客_pytorch-geometric

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/8450.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

由浅入深,一起来刷Java高级开发岗面试指南,明年面试必定无忧!

前言 我只想面个CV工程师&#xff0c;面试官偏偏让我挑战造火箭工程师&#xff0c;加上今年这个情况更是前后两男&#xff0c;但再难苟且的生活还要继续&#xff0c;饭碗还是要继续找的。在最近的面试中我一直在总结&#xff0c;每次面试回来也都会复盘&#xff0c;下面是我根…

为啥50岁以后,病就增多了?中老年人想要少生病,该做些什么?

人到中年&#xff0c;生活会有很多变化&#xff0c;很多男性朋友从以前别人口中的小伙子&#xff0c;变成现在家里的顶梁柱&#xff0c;很多以前别人口中的小姑娘&#xff0c;变成现在的贤妻良母&#xff0c;或者拥有自己的一番事业。角色在变化的同时&#xff0c;身体情况也发…

高压电气系统验证

纯电和混合动力汽车中的高压电气系统关乎整车的能耗和安全&#xff0c;需要在部件及整车开发阶段做全面的测试与验证。符合ISO 21498*标准的电压、电流一体式测量模块CSM HV BM系列产品&#xff0c;可以直接串联在整车级别的高压电气线缆中&#xff0c;安全可靠的完成高压电气系…

java面试强基(2)

字符型常量和字符串常量的区别? 形式 : 字符常量是单引号引起的一个字符&#xff0c;字符串常量是双引号引起的 0 个或若干个字符。 含义 : 字符常量相当于一个整型值( ASCII 值),可以参加表达式运算; 字符串常量代表一个地址值(该字符串在内存中存放位置)。 占内存大小 &…

SpringCloud 核心组件Feign【远程调用自定义配置】

目录 1&#xff0c;Feign远程调用 1.1&#xff1a;Feign概述 1.2&#xff1a;Feign替代RestTemplate 1&#xff09;&#xff1a;引入依赖 2&#xff09;&#xff1a;添加注解 3&#xff09;&#xff1a;编写Feign的消费服务&#xff0c;提供服务 4&#xff09;&#xff1a;测…

C. Discrete Acceleration(浮点二分)

Problem - 1408C - Codeforces 题意: 有一条长度为l米的道路。路的起点坐标为0&#xff0c;路的终点坐标为l。 有两辆汽车&#xff0c;第一辆站在路的起点&#xff0c;第二辆站在路的终点。它们将同时开始行驶。第一辆车将从起点开到终点&#xff0c;第二辆车将从终点开到起…

通俗易懂的React事件系统工作原理

前言 React 为我们提供了一套虚拟的事件系统&#xff0c;这套虚拟事件系统是如何工作的&#xff0c;笔者对源码做了一次梳理&#xff0c;整理了下面的文档供大家参考。 在 React事件介绍 中介绍了合成事件对象以及为什么提供合成事件对象&#xff0c;主要原因是因为 React 想…

【附源码】Python计算机毕业设计图书商城购物系统

项目运行 环境配置&#xff1a; Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术&#xff1a; django python Vue 等等组成&#xff0c;B/S模式 pychram管理等等。 环境需要 1.运行环境&#xff1a;最好是python3.7.7&#xff0c;我…

MongoDB学习一:相关概念和单机部署

目录一、MongoDB 应用场景&#xff1a;二、什么时候使用MongoDB&#xff1a;三、MongoDB简介&#xff1a;四、体系结构&#xff1a;五、数据模型&#xff1a;六、MongoDB的特点&#xff1a;七、MongoDB单机部署一、MongoDB 应用场景&#xff1a; 二、什么时候使用MongoDB&#…

对FD描述符(包括inode以及三张表)的一点理解

文件描述符&#xff0c;简单来说是一个从0开始递增的非负整数。 具体来说是linux/unix对文件系统的一种底层抽象&#xff0c;这种抽象是通过三张表来实现的。 这三张表分别是&#xff1a; 1.进程级的文件描述符表&#xff1b;(文件标志位/文件指针) 2.系统级的打开文件描述…

Android Studio App开发之下载管理器DownloadManager中显示、轮询下载进度、利用POST上传文件讲解及实战(附源码)

运行有问题或需要源码请点赞关注收藏后评论区留言~~~ 一、在通知栏显示下载进度 利用GET方式读取数据有很多缺点比如1&#xff1a;无法端点续传 一旦中途失败只能重新获取 2&#xff1a;不是真正意义上的下载操作 无法设置参数 3&#xff1a;下载过程中无法在界面上上展示下…

NAFNet(ECCV 2022)-图像修复论文解读

文章目录解决问题算法背景Simple BaselinePlain Block归一化激活函数Attention机制总结NAFNetSimpleGate替换GELUSCA替换CA总结实验应用RGB图像去噪图像去模糊RAW图像去噪结论论文: 《Simple Baselines for Image Restoration》github: https://github.com/megvii-research/NAF…

同事:这个页面的逻辑没什么能复用的,不抽组件也没什么影响吧?

前言 最近在维护同事的一个项目时&#xff0c;发现有不少单个vue文件一千余行&#xff0c;同一个文件上有倒计时、有输入信息的表单&#xff1b; 当时我就在想&#xff1a;是不是策划经常改需求或者排期紧急&#xff0c;所以没抽组件呢。 沟通过程 以下同事称为阿A 我&#…

【附源码】计算机毕业设计JAVA家庭理财管理系统

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat8.5 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; Springboot mybatis Maven Vue 等等组成&#xff0c;B/…

Java基础—Document类型的变化

Document类型的变化 Document类型的变化中唯一与命名空间无关的方法是importNode()。这个方法的用途是从一个文档中取得一个节点&#xff0c;然后将其导入到另一个文档&#xff0c;使其成为这个文档结构的一部分。需要注意的是&#xff0c;每个节点都有一个ownerDocument属性&…

G1D13-Apt论文阅读fraudgitKGbookrce33-36php环境搭建

一、APT论文 今天终于把6个模型论文和一篇综述读完了&#xff01;&#xff01;&#xff01; 今天主要读了一篇论文写了个总表。发现之前读的论文都忘了&#xff0c;所以 明天要复习一下模型&#xff0c;记录在文档中&#xff0c;并完善模型对比的总表&#xff0c;并且把代码下…

[附源码]java毕业设计基于web的建筑合同管理系统

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

嵌入式FreeRTOS学习九,任务链表的构成,TICK时间中断和任务状态切换调度

一. tskTaskControlBlock 函数结构体 在tskTaskControlBlock 任务控制块结构体中&#xff0c;其中有任务状态链表和事件链表两个链表成员&#xff0c;首先介绍任务状态链表这个结构&#xff0c;这个链表通常用于管理不同状态的任务&#xff1b;通常&#xff0c;操作系统任务有…

CPU、内存、磁盘性能监控

CPU监控 网络由设备、服务器、路由器、交换机和其他网络组件组成。CPU 是网络中所有硬件设备的组成部分。它负责设备的稳定性和性能。企业严重依赖网络&#xff0c;企业硬件的处理能力决定了网络的容量。随着 CPU 功能和硬件的快速发展&#xff0c;组织必须规划其容量并监控其…

成功上岸,刚转行自学Python的小姑娘,每个月入1W+......

我是一名2020年毕业的本科生&#xff0c;大学学的专业是机械设计制造及其自动化。 在大学期间&#xff0c;觉得机械专业实在枯燥无味&#xff0c;没有一点点成就感&#xff0c;每天就是画图纸&#xff0c;测量零件&#xff0c;计算数据&#xff0c;一切都是纸上谈兵。但凡有因…