【PyTorch】线性回归

news2025/6/7 23:31:09

文章目录

  • 1. 代码实现
    • 1.1 一元线性回归模型的训练
  • 2. 代码解读
    • 2.1. tensorboardX
      • 2.1.1. tensorboardX的安装
      • 2.1.2. tensorboardX的使用

1. 代码实现

波士顿房价数据集下载

1.1 一元线性回归模型的训练

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, random_split
from tensorboardX import SummaryWriter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 2
num_epochs = 200

writer = SummaryWriter()

model = nn.Linear(1, 1).to(device)
nn.init.normal_(model.weight, mean=0, std=0.01)
nn.init.constant_(model.bias, 0)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

data = np.load('dataset/boston_housing/boston_housing.npz')
X = torch.tensor(data['x'][:, 0].reshape(-1, len(model.weight)), dtype=torch.float, device=device)
y = torch.tensor(data['y'].reshape(-1, 1), dtype=torch.float, device=device)
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

for epoch in range(num_epochs):
    for _X, _y in dataloader:
        _X, _y = _X.to(device), _y.to(device)
        loss = criterion(model(_X), _y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    loss = criterion(model(X), y)
    torch.save(model.state_dict(), 'model/linearRegression.pt')
    model.load_state_dict(torch.load('model/linearRegression.pt'))
    writer.add_scalar('Loss/train', loss, epoch)
    writer.add_scalar('W/train', model.weight, epoch)
    writer.add_scalar('b/train', model.bias, epoch)
writer.close()

2. 代码解读

2.1. tensorboardX

tensorboardX是一种能将训练过程可视化的工具

2.1.1. tensorboardX的安装

安装命令:

pip install tensorboardX

VSCode集成了TensorBoard支持,不过事先要安装torch-tb-profiler,安装命令:

pip install torch-tb-profiler

安装完成后,在Python源文件中tensorboardX模块导入处,点击“启动TensorBoard会话”按钮,然后选择运行事件所在目录,默认选择当前目录即可,tensorboard会自动在当前目录查找运行事件,由此即可启动TensorBoard。
启动TensorBoard会话
logdir
此外,也可以通过以下命令在浏览器查看tensorboard可视化结果:

# logdir为运行事件所在目录
> tensorboard logdir=runs
TensorFlow installation not found - running with reduced feature set.
I1202 20:37:50.824767 15412 plugin.py:429] Monitor runs begin
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.14.0 at http://localhost:6006/ (Press CTRL+C to quit)
# 手动打开命令输出提供的本地服务器地址,如http://localhost:6006/

2.1.2. tensorboardX的使用

  • 直接创建对象
from tensorboardX import SummaryWriter
writer = SummaryWriter()
# writer.add_scalar():添加监控变量
writer.close()
  • 使用上下文管理器
from tensorboardX import SummaryWriter
with SummaryWriter() as writer:
	# writer.add_scalar():添加监控变量

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1278612.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[计算机网络] 高手常用的几个抓包工具(上)

文章目录 高手常用的抓包工具一览什么是抓包工具优秀抓包工具WiresharkFiddlerTcpdumpCharles 高手常用的抓包工具一览 什么是抓包工具 抓包工具是一种可以捕获、分析和修改网络流量的软件。它可以帮助您进行网络调试、性能测试、安全审计等任务。 抓包工具可以实时地显示网…

使用trigger-forward跨流水线传递参数

参考文档:https://docs.gitlab.com/ee/ci/yaml/#triggerforward 今天给大家介绍一个gitlab CI/CD的关键字 - forward,该关键字是一个比较偏的功能,但同时也是一个很实用的功能,我们通过在gitlab的ci文件中使用forward关键字&#…

Android HCI日志分析案例1

案例1--蓝牙扫描设备过程分析 应用层发起搜索蓝牙设备,Android 官方提供的蓝牙扫描方式有三种,分别如下: BluetoothAdapter.startDiscovery(); //可以扫描经典蓝牙和BLE两种。BluetoothAdapter.startLeScan();//扫描低功耗蓝牙,…

深入理解同源限制:网络安全的守护者(上)

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…

超简单的node脚本,将xlsx文件转化为json

开发场景,在一个官网中,官网的设计非常简单,就是一个纯静态的页面,全网站仅一个地方调一下接口,发一下用户填写的信息到运营同学的邮箱,这些数据不会记录在数据库,我需要做一个这样的下拉框。 但…

python使用记录

1、VSCode添加多个python解释器 只需要将对应的python.exe的目录,添加到系统环境变量的Path中即可,VSCode会自动识别及添加 2、pip 使用 pip常用命令和一些坑 查看已安装库的版本号 pip show 库名称 通过git 仓库安装第三方库 pip install git仓库地…

SQL Server 2016(分离和附加数据库)

1、实验环境。 基于上一个实验《SQL Server(创建数据库)》 2、需求描述。 class数据库的数据文件和事务日志文件都位于C:\db_class目录下。现在需要把class数据库的数据文件和事务日志文件分开存放,数据文件class.mdf存放于原位置&#xff0…

DQN原理及PyTorch实现【强化学习】

NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 欢迎来到我们的强化学习系列的第三部分。 在上两篇博客中,我们介绍了强化学习中…

【Linux】命令行参数

文章目录 前言一、C语言main函数的参数二、环境变量总结 前言 我们在Linux命令行输入命令的时候,一般都会跟上一些参数选项,比如l命令,ls -a -l。以前我总是觉得这是理所当然的,没深究其本质究竟是什么,今天才终于知道…

高压配电室智能运维

高压配电室智能运维是指通过运用先进的物联网、大数据、云计算等技术,对高压配电室进行智能化、远程化的运行维护,实现高压配电室的安全、高效、经济运行。以下是高压配电室智能运维的主要功能和优势: 实时监测:通过传感器和监测设…

分页助手入门以及小bug,报sql语法错误

导入坐标 5版本以上的分页助手 可以不用手动指定数据库语言&#xff0c;它会自动识别 <dependency> <groupId>com.github.pagehelper</groupId> <artifactId>pagehelper</artifactId> <version>5.3.2</version> </dependency&g…

vue中中的动画组件使用及如何在vue中使用animate.css

“< Transition >” 是一个内置组件&#xff0c;这意味着它在任意别的组件中都可以被使用&#xff0c;无需注册。它可以将进入和离开动画应用到通过默认插槽传递给它的元素或组件上。进入或离开可以由以下的条件之一触发&#xff1a; 由 v-if 所触发的切换由 v-show 所触…

OpenCV技术应用(6)— 暖色滤镜和冷色滤镜

前言&#xff1a;Hello大家好&#xff0c;我是小哥谈。本节课就手把手教大家如何将一幅图像转化成暖色滤镜和冷色滤镜&#xff0c;希望大家学习之后能够有所收获~&#xff01;&#x1f308; 目录 &#x1f680;1.技术介绍 &#x1f680;2.暖色滤镜 &#x1f680;3.冷色滤…

JVM运行时数据区域

文章目录 内存结构程序计数器&#xff08;寄存器&#xff09;虚拟机栈局部变量表两类异常状况 线程运行诊断 本地方法栈堆方法区运行时常量池串池&#xff08;StringTable&#xff09;字符串的拼接串池的位置StringTable垃圾回收StringTable性能调优 直接内存 内存结构 程序计…

关于标准库中的vector - (涉及迭代器失效,深浅拷贝,构造函数,内置类型构造函数,匿名对象)

目录 关于vector vector中的常见接口 vector常见接口的实现 迭代器失效 关于深浅拷贝 关于vector 关于vector的文档介绍 1. vector是表示可变大小数组的序列容器。 2. 就像数组一样&#xff0c;vector也采用的连续存储空间来存储元素。也就是意味着可以采用下标对vector的元…

【ArcGIS Pro微课1000例】0040:ArcGIS Pro创建北极点、南极点

文章目录 一、创建北极点图层二、创建北极点三、不同投影系下北极点的位置一、创建北极点图层 选择一个数据库,在上面右键→新建→要素类。 输入名称:北极点。 空间参考:WGS 1984 点击创建。 二、创建北极点 在编辑选项卡下,点击【创建】。 在创建要素窗口中,点击北极点…

docker容器中创建非root用户

简介 用 docker 也有一段时间了&#xff0c;一直在 docker 容器中使用 root 用户肆意操作。直到部署 stable diffusion webui 我才发现无法使用 root 用户运行它&#xff0c;于是才幡然醒悟&#xff1a;是时候搞个非 root 用户了。 我使用的 docker 镜像文件是 centos:centos…

FL Studio Producer Edition21.0.3中文版安装详解(附下载链接)

fl studio 21中文版是Image-Line公司继20版本之后更新的水果音乐制作软件&#xff0c;很多用户不太理解&#xff0c;为什么新版本不叫fl studio 21或fl studio2024&#xff0c;非得直接跳到21.2版本&#xff0c;其实该版本是为了纪念该公司22周年&#xff0c;所以该版本也是推出…

(C++)盛水最多的容器--双指针法

个人主页&#xff1a;Lei宝啊 愿所有美好如期而遇 力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台备战技术面试&#xff1f;力扣提供海量技术面试资源&#xff0c;帮助你高效提升编程技能&#xff0c;轻松拿下世界 IT 名企 Dream Offer。https://le…

pgsql分别获取日期中的年、月、日,并处理前台展示有小数点的情况

使用extract()函数 select extract(YEAR from 需要处理的日期字段) from tablename; --获取年份 select extract(MONTH from 需要处理的日期字段) from tablename; --获取月份 select extract(DAY from 需要处理的日期字段) from tablename; --获取日 实际应用&#xff1a;…