当LSTM遇上注意力:手把手教你玩转时序预测
Attention-LSTM时序预测单输入单输出 基于注意力机制attention结合长短期记忆网络LSTM时间序列预测 单输入单输出模型 MATLAB版本为2020b及其以上 中文注释清晰非常适合科研小白 评价指标包括:R2、MAE、MSE、RMSE等时序预测总让人头疼试试这个能让模型自动划重点的Attention-LSTM组合拳。咱们今天用MATLAB实战演练就算刚入门的同学也能轻松上手。Attention-LSTM时序预测单输入单输出 基于注意力机制attention结合长短期记忆网络LSTM时间序列预测 单输入单输出模型 MATLAB版本为2020b及其以上 中文注释清晰非常适合科研小白 评价指标包括:R2、MAE、MSE、RMSE等先看数据准备。假设我们有个温度变化序列每30分钟采样一次% 生成示例数据替换成真实数据即可 time 1:720; % 假设是30天的分钟级数据 temperature 20 5*sin(time/24) randn(size(time))*0.5; data num2cell(temperature); % 转换为cell格式方便处理 % 划分训练测试集7:3比例 trainData data(1:500); testData data(501:end);接下来是模型的重头戏——自定义注意力层。这里有个小技巧通过全连接层计算注意力权重classdef attentionLayer nnet.layer.Layer properties Units end methods function layer attentionLayer(units) layer.Units units; layer.Name attention; end function Z predict(layer, X) % X的维度 [features, timesteps, batch] [~, timesteps, batch] size(X); % 注意力权重计算 weights fullyconnect(X, rand(1, size(X,1)), DataFormat,CBT); weights softmax(weights, DataFormat,CB); % 上下文向量生成 context sum(X .* reshape(weights,1,timesteps,batch),2); Z context; % 输出维度 [features, 1, batch] end end end模型搭建就像搭积木。注意这个设计先让LSTM记住长期模式再用注意力找出关键时间点inputSize 1; numHiddenUnits 64; layers [ sequenceInputLayer(inputSize) lstmLayer(numHiddenUnits,OutputMode,sequence) attentionLayer(32) % 自定义注意力层 fullyConnectedLayer(1) regressionLayer]; options trainingOptions(adam, ... MaxEpochs,200, ... MiniBatchSize,32, ... Plots,training-progress);训练过程可能会有点过山车别慌——这是模型在努力学习数据的节奏net trainNetwork(trainData(1:end-1), trainData(2:end), layers, options);预测阶段要注意维度对齐。这里有个避坑指南记得用predictAndUpdateState处理连续预测% 单步预测 net predictAndUpdateState(net, trainData); [net, pred] predict(net, testData(1:end-1), SequenceLength,1); % 可视化对比 plot([cell2mat(testData(2:end)); pred], LineWidth,1.5) legend(真实值,预测值)最后上硬指标用这些公式给模型打分% 评价指标计算 y_true cell2mat(testData(2:end)); y_pred cell2mat(pred); mae mean(abs(y_true - y_pred)); mse mean((y_true - y_pred).^2); rmse sqrt(mse); r2 1 - sum((y_true - y_pred).^2)/sum((y_true - mean(y_true)).^2); disp([MAE: num2str(mae) MSE: num2str(mse)]) disp([RMSE: num2str(rmse) R²: num2str(r2)])跑完代码如果发现预测曲线像心电图试试这三板斧加大LSTM层数、调整注意力维度、增加训练轮次。注意机制有个隐藏福利——可视化注意力权重能清楚看到模型在哪些时间点走神了。这种组合模型特别适合有周期性波动但又存在异常波动的数据比如电力负荷预测、股票价格波动分析。下次遇到时序难题不妨让Attention-LSTM帮你擦亮眼睛找关键。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2418824.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!