【LSTM实战】股票走势预测全流程实战(stock predict)

news2025/7/22 5:20:55
  • 任务:利用LSTM模型预测2017年的股票中High的走势,并与真实的数据进行比对。
  • 数据:https://www.kaggle.com/datasets/princeteng/stock-predict

一、import packages|导入第三方库

import pandas as pd
import matplotlib.pyplot as plt
import datetime
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
# read dataset and check it
# 读入数据并且查看
df = pd.read_csv('../input/stock-predict/IBM_2006-01-01_to_2018-01-01.csv', index_col=0)
df.index = list(map(lambda x:datetime.datetime.strptime(x, '%Y-%m-%d'), df.index))
df.head(20)
OpenHighLowCloseVolumeName
2006-01-0382.4582.5580.8182.0611715200IBM
2006-01-0482.2082.5081.3381.959840600IBM
2006-01-0581.4082.9081.0082.507213500IBM
2006-01-0683.9585.0383.4184.958197400IBM
2006-01-0984.1084.2583.3883.736858200IBM
2006-01-1083.1584.1283.1284.075701000IBM
2006-01-1184.3784.8183.4084.175776500IBM
2006-01-1283.8283.9683.4083.574926500IBM
2006-01-1383.0083.4582.5083.176921700IBM
2006-01-1782.8083.1682.5483.008761700IBM
2006-01-1884.0084.7083.5284.4611032800IBM
2006-01-1984.1484.3983.0283.096484000IBM
2006-01-2083.0483.0581.2581.368614500IBM
2006-01-2381.3381.9280.9281.416114100IBM
2006-01-2481.3982.1580.8080.856069000IBM
2006-01-2581.0581.6280.6180.916374300IBM
2006-01-2681.5081.6580.5980.727810200IBM
2006-01-2780.7581.7780.7581.026103400IBM
2006-01-3080.2181.8180.2181.635325100IBM
2006-01-3181.5082.0081.1781.306771600IBM

根据日期的数据列可以大致总结,周六周日有两天不进行股价交易

# the amount of datasets 
# 数据集数量
len(df)
3020

二、data processing|数据处理

def getData(df, column, train_end=-250, days_before=7, return_all=True, generate_index=False):
    series = df[column].copy()
    # split data
    # 划分数据
    train_series, test_series = series[:train_end], series[train_end - days_before:]
    train_data = pd.DataFrame()
        
    # 以七天为一个周期构建数据集和标签
    for i in range(days_before):
        train_data['c%d' % i] = train_series.tolist()[i: -days_before + i]
    # get train labels
    # 获取对应的 label
    train_data['y'] = train_series.tolist()[days_before:]
    # gen index
    # 是否生成 index
    if generate_index:
        train_data.index = train_series.index[n:]
                
    if return_all:
        return train_data, series, df.index.tolist()
    
    return train_data
# build dataloader
# 构建用于模型训练的dataloader
class TrainSet(Dataset):
    def __init__(self, data):
        self.data, self.label = data[:, :-1].float(), data[:, -1].float()

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return len(self.data)

三、build model|构建模型

# build LSTM model
class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()
        
        self.lstm = nn.LSTM(
            input_size=1,
            hidden_size=64,
            num_layers=1, 
            batch_first=True)
        
        self.out = nn.Sequential(
            nn.Linear(64,1))
        
    def forward(self, x):
        r_out, (h_n, h_c) = self.lstm(x, None)
        out = self.out(r_out[:, -1, :])
        
        return out

四、train model|模型训练

# 数据集建立
train_data, all_series, df_index = getData(df, 'High')

# 获取所有原始数据
all_series = np.array(all_series.tolist())
# 绘制原始数据的图
plt.figure(figsize=(12,8))
plt.plot(df_index, all_series, label='real-data')

# 归一化
train_data_numpy = np.array(train_data)
train_mean = np.mean(train_data_numpy)
train_std  = np.std(train_data_numpy)
train_data_numpy = (train_data_numpy - train_mean) / train_std
train_data_tensor = torch.Tensor(train_data_numpy)

# 创建 dataloader
train_set = TrainSet(train_data_tensor)
train_loader = DataLoader(train_set, batch_size=10, shuffle=True)

在这里插入图片描述

4.1 train model from zero|从头开始训练模型

rnn = LSTM()

if torch.cuda.is_available():
    rnn = rnn.cuda()

# 设置优化器和损失函数
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.0001)
loss_func = nn.MSELoss()

for step in range(100):
    for tx, ty in train_loader:
        
        if torch.cuda.is_available():
            tx = tx.cuda()
            ty = ty.cuda()       
        
        output = rnn(torch.unsqueeze(tx, dim=2))
        loss = loss_func(torch.squeeze(output), ty)
        optimizer.zero_grad() 
        loss.backward()
        optimizer.step()
    if step % 10==0:
        print(step, loss.cpu())
torch.save(rnn, 'model.pkl')
0 tensor(0.0756, grad_fn=<ToCopyBackward0>)
10 tensor(0.0087, grad_fn=<ToCopyBackward0>)
20 tensor(0.0024, grad_fn=<ToCopyBackward0>)
30 tensor(0.0042, grad_fn=<ToCopyBackward0>)
40 tensor(0.0078, grad_fn=<ToCopyBackward0>)
50 tensor(0.0057, grad_fn=<ToCopyBackward0>)
60 tensor(0.0001, grad_fn=<ToCopyBackward0>)
70 tensor(0.0077, grad_fn=<ToCopyBackward0>)
80 tensor(0.0027, grad_fn=<ToCopyBackward0>)
90 tensor(0.0015, grad_fn=<ToCopyBackward0>)

4.2 load model|加载训练好的模型

rnn = LSTM()

rnn = torch.load('model.pkl')
generate_data_train = []
generate_data_test = []

# 测试数据开始的索引
test_start = len(all_series)-250

# 对所有的数据进行相同的归一化
all_series = (all_series - train_mean) / train_std
all_series = torch.Tensor(all_series)

for i in range(7, len(all_series)):
    x = all_series[i - 7:i]
    # 将 x 填充到 (bs, ts, is) 中的 timesteps
    x = torch.unsqueeze(torch.unsqueeze(x, dim=0), dim=2)
    
    if torch.cuda.is_available():
        x = x.cuda()

    y = rnn(x)
    
    if i < test_start:
        generate_data_train.append(torch.squeeze(y.cpu()).detach().numpy() * train_std + train_mean)
    else:
        generate_data_test.append(torch.squeeze(y.cpu()).detach().numpy() * train_std + train_mean)
        
plt.figure(figsize=(12,8))
plt.plot(df_index[7: -250], generate_data_train, 'b', label='generate_train', )
plt.plot(df_index[-250:], generate_data_test, 'k', label='generate_test')
plt.plot(df_index, all_series.clone().numpy()* train_std + train_mean, 'r', label='real_data')
plt.legend()
plt.show()

在这里插入图片描述

五、test model|测试模型

DAYS_BEFORE=7
TRAIN_END=-250

plt.figure(figsize=(10,16))

plt.subplot(2,1,1)
plt.plot(df_index[100 + DAYS_BEFORE: 130 + DAYS_BEFORE], generate_data_train[100: 130], 'b', label='generate_train')
plt.plot(df_index[100 + DAYS_BEFORE: 130 + DAYS_BEFORE], (all_series.clone().numpy()* train_std + train_mean)[100 + DAYS_BEFORE: 130 + DAYS_BEFORE], 'r', label='real_data')
plt.legend()

plt.subplot(2,1,2)
plt.plot(df_index[TRAIN_END + 5: TRAIN_END + 230], generate_data_test[5:230], 'k', label='generate_test')
plt.plot(df_index[TRAIN_END + 5: TRAIN_END + 230], (all_series.clone().numpy()* train_std + train_mean)[TRAIN_END + 5: TRAIN_END + 230], 'r', label='real_data')
plt.legend()
plt.show()

在这里插入图片描述

第一张图表示训练的模型在train集上的表现,第二张图表示在test上预测的表现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/17043.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

利用ESP32实现蓝牙通信的方法

​大家好&#xff0c;我是ST! 上次给大家分享了如何使用ESP32实现UDP通信&#xff0c;今天跟大家聊聊如何使用ESP32实现蓝牙通信。 目录 一、蓝牙简介 二、miropython有关蓝牙的实现方法 三、我的实验代码 四、手机调试APP 一、蓝牙简介 蓝牙是一种无线通讯技术&#xff…

Linux篇【5】:Linux 进程概念(三)

目录 四、进程状态 4.1、各个操作系统下的进程状态&#xff1a; 4.1.1、进程的运行态&#xff1a; 4.1.2、进程的终止态(退出态)&#xff1a; 4.1.3、进程的阻塞态&#xff1a; 4.1.4、进程的挂起态&#xff1a; 4.2、Linux 操作系统下的进程状态&#xff1a; 四、进…

30、Java高级特性——Java API、枚举、包装类、装箱和拆箱

目录 课前先导&#xff1a; 一、Java API 1、API 2、Java API 3、Java API常用包 二、枚举类型 1、枚举 2、枚举类 3、代码演示 3.1 创建枚举类 3.2 创建测试类 4、MyEclipse创建枚举类的快捷方式 三、包装类 1、八大基本数据类型包装类 2、包装类中的构造方…

Java并发编程之可见性分析 volatile

可见性 对于什么是可见性&#xff0c;比较官方的解释就是&#xff1a;一个线程对共享变量的修改&#xff0c;另一个线程能够立刻看到。 说的直白些&#xff0c;就是两个线程共享一个变量&#xff0c;无论哪一个线程修改了这个变量&#xff0c;则另外的一个线程都能够看到上一…

电脑可以通过蓝牙发送文件吗?电脑蓝牙怎么发送文件

蓝牙&#xff08;bluetooth&#xff09;是一种支持设备短距离通信的无线电技术。能在包括移动电话、PDA、无线耳机、笔记本电脑、相关外设等众多设备之间进行无线信息交换。蓝牙技术让数据传输变得更加迅速高效&#xff0c;为无线通信拓宽道路。随着蓝牙技术的发展&#xff0c;…

甘露糖-聚乙二醇-羧酸|mannose-PEG-COOH|羧酸-PEG-甘露糖

甘露糖-聚乙二醇-羧酸|mannose-PEG-COOH|羧酸-PEG-甘露糖 首先合成了二,三分支的甘露糖簇分子.甘露糖经烯丙 苷化,乙酰基保护后,将其烯丙基的双键氧化得到带有羧基连接臂的甘露糖衍生物,然后再分别与1,6-己二胺和三(2-氨乙基)胺进行缩合反应,后脱掉保 护基,得到二分枝甘露糖簇…

Azide-PEG-Thiol,N3-PEG-SH,叠氮-聚乙二醇-巯基可用来制备金纳米颗粒

1、名称 英文&#xff1a;Azide-PEG-Thiol&#xff0c;N3-PEG-SH 中文&#xff1a;叠氮-聚乙二醇-巯基 2、CAS编号&#xff1a;N/A 3、所属分类&#xff1a;Azide PEG Thiol PEG 4、分子量&#xff1a;可定制&#xff0c;5k N3-PEG-SH、20k 叠氮-聚乙二醇-巯基、10k N3-PE…

嵌入式分享合集105

一、智能灯光控制系统&#xff08;基于stm32&#xff09; 带你走进物联网的世界说一个整天方案哦 这次是基于stm32的 当然你可以用esp “智能光照灯”使用STM32作为系统的MCU&#xff0c;由于单片机IO口驱动电流过小&#xff0c;搭配三极管放大电流&#xff0c;从而满足光照强…

全网监控 nginx 部署 zabbix6.0

Zabbix监控 文章目录Zabbix监控一、zabbix6.0部署1、部署zabbix 6.0版本&#xff08;nginxphpzabbix&#xff09;1、nginx配置2、php配置3、mariadb配置二、zabbix配置1、zabbix配置 &#xff08;6.0&#xff09;1、源码安装2、zabbix rpm2、zabbix(5.0安装) -- 补充3、故障汇总…

【Linux】翻山越岭——进程地址空间

文章目录一、是什么写时拷贝二、为什么三、怎么做区域划分和调整一、是什么 回顾我们学习C/C时的地址空间&#xff1a; 有了这个基本框架&#xff0c;我们对于语言的学习更加易于理解&#xff0c;但是地址空间究竟是什么❓我们对其并不了解&#xff0c;是不是内存呢&#xff1…

【创建微服务】创建微服务并使用人人开源代码生成器生成基本代码

创建项目微服务 —— 添加模块 添加依赖 使用 人人开源代码生成器 快速生成 crud 代码 —— https://gitee.com/renrenio 下载导入人人开源项目后&#xff0c;修改 application.yml 文件下的数据库连接配置&#xff1a; 2. 修改 generator.properties 配置文件下的 主路径、包…

CC1101RGPR射频收发器 Low-Power Sub-1GHz 射频收发器

CC1101RGPR射频收发器 Low-Power Sub-1GHz 射频收发器 CC1101RGPR是一种低成本的 sub-1 GHz 收发器专为超低功耗无线应用而设计。该电路主要用于ISM&#xff08;工业、科学和医疗&#xff09;和SRD&#xff08;短程设备&#xff09;频段315、433、868 和 915 MHz&#xff0c;但…

【891. 子序列宽度之和】

来源&#xff1a;力扣&#xff08;LeetCode&#xff09; 描述&#xff1a; 一个序列的 宽度 定义为该序列中最大元素和最小元素的差值。 给你一个整数数组 nums &#xff0c;返回 nums 的所有非空 子序列 的 宽度之和 。由于答案可能非常大&#xff0c;请返回对 109 7 取余 …

UNIAPP实战项目笔记40 设置和地址的页面布局

UNIAPP实战项目笔记40 设置和地址的页面布局 my-config.vue 设置页面布局 具体图片自己替换哈&#xff0c;随便找了个图片的做示例 代码 my-config.vue 页面部分 <template><view class"my-config"><view class"config-item" tap"go…

精益项目管理的流程

我们生活在一个企业家的世界&#xff0c;您可能有许多自己的想法等待实现&#xff0c;但想法在现实中实现是昂贵的。问题是您如何才能获得最大的收益&#xff1f;CEO和管理者如何在追逐梦想和实现目标的同时节省资金&#xff1f;了解初创公司如何进行精益项目管理&#xff0c;它…

第一个汇编程序

第一个汇编程序 文章目录第一个汇编程序1.汇编模拟程序&#xff1a;DOSBox使用2.汇编程序从写出到执行的过程3.程序执行过程跟踪1.汇编模拟程序&#xff1a;DOSBox使用 BOSBox软件常用基本语法&#xff1a; mount c: d:\masn ;挂载磁盘,挂载后用c:切换为C盘才能用debug等工具…

【Java面试八股文宝典之基础篇】备战2023 查缺补漏 你越早准备 越早成功!!!——Day09

大家好&#xff0c;我是陶然同学&#xff0c;软件工程大三明年实习。认识我的朋友们知道&#xff0c;我是科班出身&#xff0c;学的还行&#xff0c;但是对面试掌握不够&#xff0c;所以我将用这100多天更新Java面试题&#x1f643;&#x1f643;。 不敢苟同&#xff0c;相信大…

uni-app入门:WXML列表渲染与条件渲染

1.列表渲染 1.1wx:for 1.2wx:key 2.条件渲染 2.1wx:if 2.2 hidden 正文 WXML全称&#xff1a;wexin markup language,微信标签语言&#xff0c;可以理解为web中的html&#xff0c;今天来讲一下列表渲染&#xff0c;通过几个小案例掌…

艾美捷高纯度 Cholesterol胆固醇相关介绍

胆固醇在体内有着广泛的生理作用&#xff0c;但当其过量时便会导致高胆固醇血症&#xff0c;对机体产生不利的影响。现代研究已发现&#xff0c;动脉粥样硬化、静脉血栓形成与胆石症与高胆固醇血症有密切的相关性。 如果是单纯的胆固醇高则饮食调节是最好的办法&#xff0c;如果…

机器人虚拟仿真工作站考试

总共三个步骤&#xff1a; 创建工作台、加工零件、机器人臂 &#xff01;&#xff01;&#xff01;&#xff01;&#xff01; 一、加工零件的创建 1、先打开sw软件&#xff0c;然后点击零件、创建进入到该软件内&#xff1a; 2、点击前视基础面&#xff08;点击后按esc&#x…