Auto-Coder使用GPT-4o完成:在用TabPFN这个模型构建一个预测未来3天涨跌的分类任务

news2025/6/12 18:38:02

通过akshare库,获取股票数据,并生成TabPFN这个模型 可以识别、处理的格式,写一个完整的预处理示例,并构建一个预测未来 3 天股价涨跌的分类任务

用TabPFN这个模型构建一个预测未来 3 天股价涨跌的分类任务,进行预测并输出准确率

代码

import os
import sys
import pandas as pd
import akshare as ak
# try:
#     import akshare as ak
# except ImportError:
#     print("请先运行: pip install akshare")
#     sys.exit(1)

def fetch_stock_data(stock_code="000001", start_date="20150101", end_date=None):
    """
    获取股票历史行情数据,兼容不同akshare接口
    """
    print(f"DEBUG: akshare version: {ak.__version__}")
    print(f"DEBUG: fetch_stock_data params: symbol={stock_code}, start_date={start_date}, end_date={end_date}")
    try:
        # 优先尝试 stock_zh_a_daily
        df = ak.stock_zh_a_daily(symbol=stock_code, adjust="qfq")
        print("DEBUG: 使用 ak.stock_zh_a_daily 成功")
        print("DEBUG: Columns before rename:", df.columns.tolist())
        print("DEBUG: Head before rename:\n", df.head())
        # 若有 start_date/end_date,筛选
        if "date" in df.columns:
            df["date"] = pd.to_datetime(df["date"])
            if start_date:
                df = df[df["date"] >= pd.to_datetime(start_date)]
            if end_date:
                df = df[df["date"] <= pd.to_datetime(end_date)]
            df = df.sort_values("date").reset_index(drop=True)
    except Exception as e:
        print("ERROR: ak.stock_zh_a_daily() 调用异常,尝试 fallback 原接口")
        print(f"Exception: {e}")
        try:
            df = ak.stock_zh_a_hist(symbol=stock_code, period="daily", start_date=start_date, end_date=end_date, adjust="qfq")
            print("DEBUG: 使用 ak.stock_zh_a_hist 成功")
            print("DEBUG: Columns before rename:", df.columns.tolist())
            print("DEBUG: Head before rename:\n", df.head())
        except Exception as e2:
            print("ERROR: Exception occurred while fetching stock data by both interfaces!")
            print(f"Exception: {e2}")
            return pd.DataFrame()
    if df.empty:
        print("ERROR: Fetched DataFrame is empty! Check stock code, date range, or network/API issues.")
        return df
    # 兼容列名
    rename_map = {
        "日期": "date",
        "开盘": "open",
        "收盘": "close",
        "最高": "high",
        "最低": "low",
        "成交量": "volume",
        "成交额": "amount",
        "振幅": "amplitude",
        "涨跌幅": "pct_chg",
        "涨跌额": "chg",
        "换手率": "turnover"
    }
    for k in list(rename_map.keys()):
        if k not in df.columns:
            rename_map.pop(k)
    df = df.rename(columns=rename_map)
    if "date" in df.columns:
        df["date"] = pd.to_datetime(df["date"])
        df = df.sort_values("date").reset_index(drop=True)
    # 自动补充pct_chg列(涨跌幅),百分比格式
    if "pct_chg" not in df.columns and "close" in df.columns:
        df["pct_chg"] = df["close"].pct_change() * 100
    return df

def create_features_and_labels(df, n_past=5, n_future=3):
    """
    构造特征和标签,标签为未来3日涨跌(1=涨,0=跌/平)
    """
    feats = []
    labels = []
    for idx in range(n_past, len(df) - n_future):
        past_slice = df.iloc[idx-n_past:idx]
        # 特征: 过去n_past日的收盘价、涨跌幅、成交量
        feature = []
        feature += list(past_slice["close"].values)
        feature += list(past_slice["pct_chg"].values)
        feature += list(past_slice["volume"].values)
        # 未来n_future日的收盘价均值
        future_close_mean = df.iloc[idx:idx+n_future]["close"].mean()
        curr_close = df.iloc[idx-1]["close"]
        # 涨跌标签: 未来3日均值 > 当前收盘价 => 1,否则0
        label = 1 if future_close_mean > curr_close else 0
        feats.append(feature)
        labels.append(label)
    feats_df = pd.DataFrame(feats, columns=[
        f"close_t-{i}" for i in range(n_past,0,-1)
    ] + [
        f"pct_chg_t-{i}" for i in range(n_past,0,-1)
    ] + [
        f"volume_t-{i}" for i in range(n_past,0,-1)
    ])
    feats_df["label"] = labels
    return feats_df

def save_for_tabpfn(df, out_csv):
    """
    保存为TabPFN模型可读取的csv格式
    """
    df.to_csv(out_csv, index=False)
    print(f"已保存至: {out_csv}")
    print(df.head())

def main():
    # 拉取A股代码表,打印前10条,辅助判断symbol格式
    print("尝试拉取A股代码表,辅助symbol格式判断...")
    try:
        code_df = ak.stock_info_a_code_name()
        print("A股代码表前10条:")
        print(code_df.head(10))
        print("平安银行相关行:")
        print(code_df[code_df["code"].str.contains("000001")])
    except Exception as e:
        print(f"拉取A股代码表失败: {e}")

    stock_code = "sz000001"  # 平安银行
    start_date = "20150101"
    print("正在获取股票数据...")
    df = fetch_stock_data(stock_code, start_date)
    print("正在生成特征与标签...")
    processed = create_features_and_labels(df, n_past=5, n_future=3)
    # out_csv = os.path.join(os.path.dirname(__file__), f"{stock_code}_tabpfn.csv")
    out_csv =  f"{stock_code}_tabpfn.csv"
    print("正在保存为TabPFN格式...")
    save_for_tabpfn(processed, out_csv)
    print("预处理完成。")

if __name__ == "__main__":
    main()

预测

import os
# 临时绕过 /dev/null 权限问题
os.devnull = "/tmp/null"
if not os.path.exists("/tmp/null"):
    with open("/tmp/null", "w") as f:
        pass

import sys
import numpy as np
import pandas as pd

# 自动安装tabpfn(如未安装)
try:
    from tabpfn import TabPFNClassifier
except ImportError:
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "tabpfn"])
    from tabpfn import TabPFNClassifier

# 读取数据
# DATA_PATH = os.path.join(os.path.dirname(__file__), "sz000001_tabpfn.csv")
DATA_PATH = "sz000001_tabpfn.csv"
data = pd.read_csv(DATA_PATH, header=None)
X = data.iloc[:, :-1].values
y = data.iloc[:, -1].values

# 按时间顺序划分(前80%训练,后20%测试)
split_idx = int(0.8 * len(X))
X_train, X_test = X[:split_idx], X[split_idx:]
y_train, y_test = y[:split_idx], y[split_idx:]

# TabPFN训练与预测
# clf = TabPFNClassifier(device='cpu')
clf = TabPFNClassifier(device='cuda')
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
accuracy = np.mean(y_pred == y_test)

print(f"Test Accuracy: {accuracy:.4f}")

最后生成的准确率:

Test Accuracy: 0.4812

这个准确率低于50%,反而可能证明程序是对的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2407751.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

dedecms 织梦自定义表单留言增加ajax验证码功能

增加ajax功能模块&#xff0c;用户不点击提交按钮&#xff0c;只要输入框失去焦点&#xff0c;就会提前提示验证码是否正确。 一&#xff0c;模板上增加验证码 <input name"vdcode"id"vdcode" placeholder"请输入验证码" type"text&quo…

抖音增长新引擎:品融电商,一站式全案代运营领跑者

抖音增长新引擎&#xff1a;品融电商&#xff0c;一站式全案代运营领跑者 在抖音这个日活超7亿的流量汪洋中&#xff0c;品牌如何破浪前行&#xff1f;自建团队成本高、效果难控&#xff1b;碎片化运营又难成合力——这正是许多企业面临的增长困局。品融电商以「抖音全案代运营…

2.Vue编写一个app

1.src中重要的组成 1.1main.ts // 引入createApp用于创建应用 import { createApp } from "vue"; // 引用App根组件 import App from ./App.vue;createApp(App).mount(#app)1.2 App.vue 其中要写三种标签 <template> <!--html--> </template>…

全球首个30米分辨率湿地数据集(2000—2022)

数据简介 今天我们分享的数据是全球30米分辨率湿地数据集&#xff0c;包含8种湿地亚类&#xff0c;该数据以0.5X0.5的瓦片存储&#xff0c;我们整理了所有属于中国的瓦片名称与其对应省份&#xff0c;方便大家研究使用。 该数据集作为全球首个30米分辨率、覆盖2000–2022年时间…

定时器任务——若依源码分析

分析util包下面的工具类schedule utils&#xff1a; ScheduleUtils 是若依中用于与 Quartz 框架交互的工具类&#xff0c;封装了定时任务的 创建、更新、暂停、删除等核心逻辑。 createScheduleJob createScheduleJob 用于将任务注册到 Quartz&#xff0c;先构建任务的 JobD…

STM32标准库-DMA直接存储器存取

文章目录 一、DMA1.1简介1.2存储器映像1.3DMA框图1.4DMA基本结构1.5DMA请求1.6数据宽度与对齐1.7数据转运DMA1.8ADC扫描模式DMA 二、数据转运DMA2.1接线图2.2代码2.3相关API 一、DMA 1.1简介 DMA&#xff08;Direct Memory Access&#xff09;直接存储器存取 DMA可以提供外设…

376. Wiggle Subsequence

376. Wiggle Subsequence 代码 class Solution { public:int wiggleMaxLength(vector<int>& nums) {int n nums.size();int res 1;int prediff 0;int curdiff 0;for(int i 0;i < n-1;i){curdiff nums[i1] - nums[i];if( (prediff > 0 && curdif…

《用户共鸣指数(E)驱动品牌大模型种草:如何抢占大模型搜索结果情感高地》

在注意力分散、内容高度同质化的时代&#xff0c;情感连接已成为品牌破圈的关键通道。我们在服务大量品牌客户的过程中发现&#xff0c;消费者对内容的“有感”程度&#xff0c;正日益成为影响品牌传播效率与转化率的核心变量。在生成式AI驱动的内容生成与推荐环境中&#xff0…

学校招生小程序源码介绍

基于ThinkPHPFastAdminUniApp开发的学校招生小程序源码&#xff0c;专为学校招生场景量身打造&#xff0c;功能实用且操作便捷。 从技术架构来看&#xff0c;ThinkPHP提供稳定可靠的后台服务&#xff0c;FastAdmin加速开发流程&#xff0c;UniApp则保障小程序在多端有良好的兼…

家政维修平台实战20:权限设计

目录 1 获取工人信息2 搭建工人入口3 权限判断总结 目前我们已经搭建好了基础的用户体系&#xff0c;主要是分成几个表&#xff0c;用户表我们是记录用户的基础信息&#xff0c;包括手机、昵称、头像。而工人和员工各有各的表。那么就有一个问题&#xff0c;不同的角色&#xf…

最新SpringBoot+SpringCloud+Nacos微服务框架分享

文章目录 前言一、服务规划二、架构核心1.cloud的pom2.gateway的异常handler3.gateway的filter4、admin的pom5、admin的登录核心 三、code-helper分享总结 前言 最近有个活蛮赶的&#xff0c;根据Excel列的需求预估的工时直接打骨折&#xff0c;不要问我为什么&#xff0c;主要…

基于当前项目通过npm包形式暴露公共组件

1.package.sjon文件配置 其中xh-flowable就是暴露出去的npm包名 2.创建tpyes文件夹&#xff0c;并新增内容 3.创建package文件夹

MMaDA: Multimodal Large Diffusion Language Models

CODE &#xff1a; https://github.com/Gen-Verse/MMaDA Abstract 我们介绍了一种新型的多模态扩散基础模型MMaDA&#xff0c;它被设计用于在文本推理、多模态理解和文本到图像生成等不同领域实现卓越的性能。该方法的特点是三个关键创新:(i) MMaDA采用统一的扩散架构&#xf…

智能在线客服平台:数字化时代企业连接用户的 AI 中枢

随着互联网技术的飞速发展&#xff0c;消费者期望能够随时随地与企业进行交流。在线客服平台作为连接企业与客户的重要桥梁&#xff0c;不仅优化了客户体验&#xff0c;还提升了企业的服务效率和市场竞争力。本文将探讨在线客服平台的重要性、技术进展、实际应用&#xff0c;并…

【CSS position 属性】static、relative、fixed、absolute 、sticky详细介绍,多层嵌套定位示例

文章目录 ★ position 的五种类型及基本用法 ★ 一、position 属性概述 二、position 的五种类型详解(初学者版) 1. static(默认值) 2. relative(相对定位) 3. absolute(绝对定位) 4. fixed(固定定位) 5. sticky(粘性定位) 三、定位元素的层级关系(z-i…

[ICLR 2022]How Much Can CLIP Benefit Vision-and-Language Tasks?

论文网址&#xff1a;pdf 英文是纯手打的&#xff01;论文原文的summarizing and paraphrasing。可能会出现难以避免的拼写错误和语法错误&#xff0c;若有发现欢迎评论指正&#xff01;文章偏向于笔记&#xff0c;谨慎食用 目录 1. 心得 2. 论文逐段精读 2.1. Abstract 2…

转转集团旗下首家二手多品类循环仓店“超级转转”开业

6月9日&#xff0c;国内领先的循环经济企业转转集团旗下首家二手多品类循环仓店“超级转转”正式开业。 转转集团创始人兼CEO黄炜、转转循环时尚发起人朱珠、转转集团COO兼红布林CEO胡伟琨、王府井集团副总裁祝捷等出席了开业剪彩仪式。 据「TMT星球」了解&#xff0c;“超级…

【快手拥抱开源】通过快手团队开源的 KwaiCoder-AutoThink-preview 解锁大语言模型的潜力

引言&#xff1a; 在人工智能快速发展的浪潮中&#xff0c;快手Kwaipilot团队推出的 KwaiCoder-AutoThink-preview 具有里程碑意义——这是首个公开的AutoThink大语言模型&#xff08;LLM&#xff09;。该模型代表着该领域的重大突破&#xff0c;通过独特方式融合思考与非思考…

什么是库存周转?如何用进销存系统提高库存周转率?

你可能听说过这样一句话&#xff1a; “利润不是赚出来的&#xff0c;是管出来的。” 尤其是在制造业、批发零售、电商这类“货堆成山”的行业&#xff0c;很多企业看着销售不错&#xff0c;账上却没钱、利润也不见了&#xff0c;一翻库存才发现&#xff1a; 一堆卖不动的旧货…

el-switch文字内置

el-switch文字内置 效果 vue <div style"color:#ffffff;font-size:14px;float:left;margin-bottom:5px;margin-right:5px;">自动加载</div> <el-switch v-model"value" active-color"#3E99FB" inactive-color"#DCDFE6"…