基于PyTorch神经网络进行温度预测——基于jupyter实现

news2025/5/31 17:58:43

导入环境

import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import warnings
warnings.filterwarnings("ignore")
%matplotlib inline

读取文件

### 读取数据文件
features = pd.read_csv('temps.csv')
#看看数据长什么样子
features.head(5)

在这里插入图片描述
其中
数据表中

  • year,moth,day,week分别表示的具体的时间
  • temp_2:前天的最高温度值
  • temp_1:昨天的最高温度值
  • average:在历史中,每年这一天的平均最高温度值
  • actual:这就是我们的标签值了,当天的真实最高温度
  • friend:据说凑热闹

查阅数据维度

print('数据维度:', features.shape)

在这里插入图片描述

时间维度数据进行处理


# 处理时间数据
import datetime

# 分别得到年,月,日
years = features['year']
months = features['month']
days = features['day']

# datetime格式
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]
查阅数据
data[:,5]

在这里插入图片描述

图像绘制

# 准备画图
# 指定默认风格
plt.style.use('fivethirtyeight')

# 设置布局
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize = (10,10))
fig.autofmt_xdate(rotation = 45)

# 标签值
ax1.plot(dates, features['actual'])
ax1.set_xlabel(''); ax1.set_ylabel('Temperature'); ax1.set_title('Max Temp')

# 昨天
ax2.plot(dates, features['temp_1'])
ax2.set_xlabel(''); ax2.set_ylabel('Temperature'); ax2.set_title('Previous Max Temp')

# 前天
ax3.plot(dates, features['temp_2'])
ax3.set_xlabel('Date'); ax3.set_ylabel('Temperature'); ax3.set_title('Two Days Prior Max Temp')

# 朋友
ax4.plot(dates, features['friend'])
ax4.set_xlabel('Date'); ax4.set_ylabel('Temperature'); ax4.set_title('Friend Estimate')

plt.tight_layout(pad=2)

在这里插入图片描述

独热编码

数据需要独热编码(One-Hot Encoding),许多机器学习算法预期输入是数值型的,并且它们在处理数值型数据时表现更好。
独热编码是一种处理分类数据的方法,特别是在分类数据的各个类别之间没有顺序或等级的情况下。以下是使用独热编码的几个原因:

  1. 避免数值偏见:在很多模型中,如线性模型和神经网络,使用普通的数值标签(如1, 2, 3…)可能导致模型误认为类别之间存在数值上的关系,比如2是1的两倍,这可能会引入模型误解。
  2. 改善模型性能:通过独热编码,模型可以更明确地捕捉到每个类别的独特性,因为每个类别都由一个独立的特征表示,这有助于提高模型的准确性和学习效率。
  3. 扩展特征空间:独热编码可以将分类变量转化为一个更大的固定长度的数值型特征向量,这使得算法能够更容易地在这些扩展的特征空间上进行操作和优化。
# 独热编码
features = pd.get_dummies(features)
features.head(5)

在这里插入图片描述

处理标签

# 标签
labels = np.array(features['actual'])
# 在特征中去掉标签
features= features.drop('actual', axis = 1)
# 名字单独保存一下
feature_list = list(features.columns)
# 转换成合适的格式
features = np.array(features)
features.shape

在这里插入图片描述

机器学习建模

数据标准化

标准化的作用:

  1. 消除量纲影响:在很多数据集中,不同的特征可能具有完全不同的量纲和单位(如公里、千克、百分比等)。未经标准化的数据如果直接用于模型训练,可能会因为量纲的差异而影响模型的性能,使得某些特征的权重过大或过小。

  2. 提高算法表现:很多机器学习算法(尤其是基于距离的算法如K-最近邻、支持向量机等)在处理数据时,会受到特征尺度的影响。通过标准化处理,可以确保每个特征对模型的影响是均衡的,从而提高算法的精确度和效率。

  3. 加速模型收敛:在使用梯度下降等优化算法时,如果数据集的特征尺度差异较大,可能会导致优化过程中步长的不均匀,使得收敛速度变慢。标准化后,由于所有特征都处在相同的尺度上,有助于加快学习算法的收敛速度。

  4. 应对异常值:标准化过程通常包括消除异常值的影响,比如通过将数据缩放到一个固定的范围(如0到1之间),或者通过z-score方法(即减去平均值,除以标准差)来减少某些极端值对整体数据分布的影响。

from sklearn import preprocessing
input_features = preprocessing.StandardScaler().fit_transform(features)
input_features[0]

在这里插入图片描述

torch搭建MLP模型

x = torch.tensor(input_features, dtype = float)

y = torch.tensor(labels, dtype = float)

# 权重参数初始化
weights = torch.randn((14, 128), dtype = float, requires_grad = True) 
biases = torch.randn(128, dtype = float, requires_grad = True) 
weights2 = torch.randn((128, 1), dtype = float, requires_grad = True) 
biases2 = torch.randn(1, dtype = float, requires_grad = True) 

learning_rate = 0.001 
losses = []

for i in range(1000):
    # 计算隐层
    hidden = x.mm(weights) + biases
    # 加入激活函数
    hidden = torch.relu(hidden)
    # 预测结果
    predictions = hidden.mm(weights2) + biases2
    # 通计算损失
    loss = torch.mean((predictions - y) ** 2) 
    losses.append(loss.data.numpy())
    
    # 打印损失值
    if i % 100 == 0:
        print('loss:', loss)
    #返向传播计算
    loss.backward()
    
    #更新参数
    weights.data.add_(- learning_rate * weights.grad.data)  
    biases.data.add_(- learning_rate * biases.grad.data)
    weights2.data.add_(- learning_rate * weights2.grad.data)
    biases2.data.add_(- learning_rate * biases2.grad.data)
    
    # 每次迭代都得记得清空
    weights.grad.data.zero_()
    biases.grad.data.zero_()
    weights2.grad.data.zero_()
    biases2.grad.data.zero_()

在这里插入图片描述

预测结果

predictions.shape

在这里插入图片描述

整体模型

input_size = input_features.shape[1]
hidden_size = 128
output_size = 1
batch_size = 16
my_nn = torch.nn.Sequential(
    torch.nn.Linear(input_size, hidden_size),
    torch.nn.Sigmoid(),
    torch.nn.Linear(hidden_size, output_size),
)
cost = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(my_nn.parameters(), lr = 0.001)

# 训练网络
losses = []
for i in range(1000):
    batch_loss = []
    # MINI-Batch方法来进行训练
    for start in range(0, len(input_features), batch_size):
        end = start + batch_size if start + batch_size < len(input_features) else len(input_features)
        xx = torch.tensor(input_features[start:end], dtype = torch.float, requires_grad = True)
        yy = torch.tensor(labels[start:end], dtype = torch.float, requires_grad = True)
        prediction = my_nn(xx)
        loss = cost(prediction, yy)
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()
        batch_loss.append(loss.data.numpy())
    
    # 打印损失
    if i % 100==0:
        losses.append(np.mean(batch_loss))
        print(i, np.mean(batch_loss))

在这里插入图片描述

预测结果

x = torch.tensor(input_features, dtype = torch.float)
predict = my_nn(x).data.numpy()

日期转换

# 转换日期格式
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]

# 创建一个表格来存日期和其对应的标签数值
true_data = pd.DataFrame(data = {'date': dates, 'actual': labels})

# 同理,再创建一个来存日期和其对应的模型预测值
months = features[:, feature_list.index('month')]
days = features[:, feature_list.index('day')]
years = features[:, feature_list.index('year')]

test_dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]

test_dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in test_dates]

predictions_data = pd.DataFrame(data = {'date': test_dates, 'prediction': predict.reshape(-1)}) 
# 真实值
plt.plot(true_data['date'], true_data['actual'], 'b-', label = 'actual')

# 预测值
plt.plot(predictions_data['date'], predictions_data['prediction'], 'ro', label = 'prediction')
plt.xticks(rotation = '60'); 
plt.legend()

# 图名
plt.xlabel('Date'); plt.ylabel('Maximum Temperature (F)'); plt.title('Actual and Predicted Values');

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1592929.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

蓝桥杯-数组分割

问题描述 小蓝有一个长度为 N 的数组 A 「Ao,A1,…,A~-1]。现在小蓝想要从 A 对应的数组下标所构成的集合I 0,1,2,… N-1 中找出一个子集 民1&#xff0c;那么 民」在I中的补集为Rz。记S∑reR 4&#xff0c;S2∑rERA,&#xff0c;我们要求S、和 S,均为偶数&#xff0c;请问在这…

如何访问远程服务器?

在现代技术时代&#xff0c;随着信息化的快速发展&#xff0c;远程访问服务器已经成为了不可或缺的一种需求。无论是企业还是个人用户&#xff0c;都需要通过远程访问来管理、传输和获取数据。本文将介绍一种名为【天联】的工具&#xff0c;它能够通过私有通道进行远程服务器访…

iptables/ebtables学习笔记

目录 一、前言 二、Netfilter 构成 三、Netfilter 转发框架 四、Netfilter 与 iptables 五、Netfilter 与 ebtables 一、前言 Netfilter 是 Linux 内核的数据包处理框架&#xff0c;由 Rusty Russell 于 1998 年开发&#xff0c; 旨在改进以前的 ipchains&#xff08;Lin…

【排序 贪心】3107. 使数组中位数等于 K 的最少操作数

算法可以发掘本质&#xff0c;如&#xff1a; 一&#xff0c;若干师傅和徒弟互有好感&#xff0c;有好感的师徒可以结对学习。师傅和徒弟都只能参加一个对子。如何让对子最多。 二&#xff0c;有无限多1X2和2X1的骨牌&#xff0c;某个棋盘若干格子坏了&#xff0c;如何在没有坏…

centos编译安装nginx1.24

nginx编译1.24&#xff0c;先下载安装包 机器通外网的话配置nginx的yum源直接yum安装 vim /etc/yum.repos.d/nginx.repo [nginx-stable] namenginx stable repo baseurlhttp://nginx.org/packages/centos/$releasever/$basearch/ gpgcheck1 enabled1 gpgkeyhttps://nginx.org…

前端实现自动获取农历日期:探索JavaScript的跨文化编程

&#x1f31f; 前言 欢迎来到我的技术小宇宙&#xff01;&#x1f30c; 这里不仅是我记录技术点滴的后花园&#xff0c;也是我分享学习心得和项目经验的乐园。&#x1f4da; 无论你是技术小白还是资深大牛&#xff0c;这里总有一些内容能触动你的好奇心。&#x1f50d; &#x…

嵌入式学习52-ARM1

知识零散&#xff1a; 1.flash&#xff1a; nor flash 可被寻地址 …

Matlab与ROS(1/2)---Simulink(二)

0. 简介 在上一章中我们详细介绍了ROS与Matlab链接的基础用法。这一章我们将来学习如何使用Matlab当中的Simulink来完成。Simulink对机器人操作系统(ROS)的支持使我们能够创建与ROS网络一起工作的Simulink模型。ROS是一个通信层&#xff0c;允许机器人系统的不同组件以消息的形…

Linux 使用 ifconfig 报错:Failed to start LSB: Bring up/down networking

一、报错信息 在运行项目时报错数据库连接失败&#xff0c;我就想着检查一下虚拟机是不是 Mysql 服务忘了开&#xff0c;结果远程连接都连接不上虚拟机上的 Linux 了&#xff0c;想着查一下 IP 地址看看&#xff0c;一查就报错了&#xff0c;报错信息&#xff1a; Restarting…

【2024年5月备考新增】《软考真题分章练习(含答案解析) - 18 管理科学-运筹学基础 (2)》

21、某种商品价格 P 变动与某指标 A 的变化具有很强的相关性,指标 A 的增长会导致 P 的降低,反之亦然。指标 A 和价格 P 的相关性系数是()。 A.0.18 B.0 C.0.98 D.-0.83 【答案】D 【解析】A 的增长会导致 B 的降低,反比关系,系数必然是一个负数。正比函数 y=kx,当 k>…

Java中创建多线程的方法

继承Thread类&#xff0c;对该类进行new一个实例&#xff0c;对实例调用start方法&#xff0c;重写run方法。 缺点&#xff1a;单继承&#xff0c;无法继承 public class myThread extends Thread {public static void main(String[] args) {myThread myThread new myThread()…

阐述嵌入式系统的基本组成:硬件层、驱动层、操作系统层和应用层

大家好&#xff0c;今天给大家介绍阐述嵌入式系统的基本组成&#xff1a;硬件层、驱动层、操作系统层和应用层&#xff0c;文章末尾附有分享大家一个资料包&#xff0c;差不多150多G。里面学习内容、面经、项目都比较新也比较全&#xff01;可进群免费领取。 嵌入式系统是一种能…

Redis入门到通关之Hash命令

文章目录 ⛄介绍⛄命令⛄RedisTemplate API❄️❄️添加缓存❄️❄️设置过期时间(单独设置)❄️❄️添加一个Map集合❄️❄️提取所有的小key❄️❄️提取所有的value值❄️❄️根据key提取value值❄️❄️获取所有的键值对集合❄️❄️删除❄️❄️判断Hash中是否含有该值 ⛄…

Linux的内存管理子系统

大家好&#xff0c;今天给大家介绍Linux的内存管理子系统&#xff0c;文章末尾附有分享大家一个资料包&#xff0c;差不多150多G。里面学习内容、面经、项目都比较新也比较全&#xff01;可进群免费领取。 Linux的内存管理子系统是Linux内核中一个非常重要且复杂的子系统&#…

深度解析 Spark(进阶):架构、集群运行机理与核心组件详解

关联阅读博客文章&#xff1a;深度解析SPARK的基本概念 引言&#xff1a; Apache Spark作为一种快速、通用、可扩展的大数据处理引擎&#xff0c;在大数据领域中备受关注和应用。本文将深入探讨Spark的集群运行原理、核心组件、工作原理以及分布式计算模型&#xff0c;带领读者…

b站江科大stm32笔记(持续更新)

b站江科大stm32笔记&#xff08;持续更新&#xff09; 片上资源/外设引脚定义表启动配置推挽开漏oc/od 门漏极/集电极 电阻的上拉下拉输入捕获输入捕获通道主从触发模式输入捕获基本结构PWMI基本结构PWMPSC ARR CRR输入捕获模式测频率TIM_PrescalerConfig()初始化输入捕获测频法…

[C++][算法基础]Dijkstra求最短路径I(稠密图)

给定一个 n 个点 m 条边的有向图&#xff0c;图中可能存在重边和自环&#xff0c;所有边权均为正值。 请你求出 1 号点到 n 号点的最短距离&#xff0c;如果无法从 1 号点走到 n 号点&#xff0c;则输出 −1。 输入格式 第一行包含整数 n 和 m。 接下来 m 行每行包含三个整…

AI图书推荐:如何在课堂上使用ChatGPT 进行教育

ChatGPT是一款强大的新型人工智能&#xff0c;已向公众免费开放。现在&#xff0c;各级别的教师、教授和指导员都能利用这款革命性新技术的力量来提升教育体验。 本书提供了一个易于理解的ChatGPT解释&#xff0c;并且更重要的是&#xff0c;详述了如何在课堂上以多种不同方式…

TQ15EG开发板教程:在MPSOC上运行ADRV9009(vivado2018.3)

首先需要在github上下载两个文件&#xff0c;本例程用到的文件以及最终文件我都会放在网盘里面&#xff0c; 地址放在最后面。在github搜索hdl选择第一个&#xff0c;如下图所示 GitHub网址&#xff1a;https://github.com/analogdevicesinc/hdl/releases 点击releases选择版…

关于ARM的一些问题

一&#xff0c;arm的工作模式有哪些&#xff1f; User&#xff1a;非特权模式 FIQ&#xff1a;高优先级中断进入 IRQ&#xff1a;低优先级中断进入 Supervisor:当复位或软中断指令进入 Abort: 当存取异常时 Undef:当执行未定义指令时会进入这种模式 System:使用和User模式相同…