最近帮实验室刚入门的师弟复现了西储大学轴承故障的迁移学习代码,本来以为是手到擒来的活,结果还是踩了好几个坑,刚好整理出来给同样摸鱼入门的小伙伴参考

news2026/3/26 18:54:40
一区top轴承故障诊断迁移学习代码复现 故障诊断代码 复现 首先使用一维的cnn对源域和目标域进行特征提取域适应阶段将源域和目标域作为cnn的输入得到特征然后进行边缘概率分布对齐和条件概率分布对齐也就是进行JDA联合对齐。此域适应方法特别适合初学者了解迁移学习的基础知识。 数据预处理1维数据 网络模型1D-CNN-MMD-Coral 数据集西储大学CWRU 准确率99% 网络框架pytorch 结果输出损失曲线图、准确率曲线图、混淆矩阵、tsne图 使用对象初学者 注意此代码是一个在 GPU 上跑的代码有宝子的电脑只支持 cpu只需要将代码修改成只在 cpu 上跑的就行这个项目真的太适合新手练手了用最简单的1D-CNN提取振动信号特征再用JDA做联合域对齐不光能快速跑通还能实打实看懂迁移学习到底在对齐什么东西比那些堆了一堆Transformer的花活代码友好太多。先唠唠整体思路说白了就是三步把西储大学的振动数据分成源域比如负载0的故障样本和目标域比如负载1的故障样本用1D-CNN从两类数据里提取特征用JDA把源域和目标域的特征分布拉到一起让模型在源域学的故障知识能直接用到目标域上全程用PyTorch写的GPU跑起来超快没有GPU也能改CPU版本完全符合大家的需求。第一步数据预处理西储大学的数据集是1维的振动信号我一般会存成npy格式方便加载不用每次都解matlab文件。这里写个自定义的Dataset类新手直接抄就能用import os import torch import numpy as np from torch.utils.data import Dataset, DataLoader class CWRUBearingDataset(Dataset): def __init__(self, data_path, label_path, normalizeTrue): # 加载预处理好的振动数据和标签要是你手里是mat文件用scipy.io.loadmat转一下就行 self.data np.load(data_path) self.label np.load(label_path) # 归一化到0-1区间防止训练的时候loss直接炸上天 if normalize: self.data (self.data - self.data.min()) / (self.data.max() - self.data.min()) # 1D-CNN的输入要求是 [batch, 通道数, 序列长度]我们的信号是单通道所以加个1维度 self.data torch.tensor(self.data, dtypetorch.float32).unsqueeze(1) self.label torch.tensor(self.label, dtypetorch.long) def __len__(self): return len(self.label) def __getitem__(self, idx): return self.data[idx], self.label[idx] # 举个例子加载源域和目标域源域用负载0的数据目标域用负载1的数据 source_dataset CWRUBearingDataset(./data/source_0_load_data.npy, ./data/source_0_load_label.npy) target_dataset CWRUBearingDataset(./data/target_1_load_data.npy, ./data/target_1_load_label.npy) # 重点源域和目标域的batch size必须一致不然特征拼接的时候会报错 source_loader DataLoader(source_dataset, batch_size32, shuffleTrue, drop_lastTrue) target_loader DataLoader(target_dataset, batch_size32, shuffleTrue, drop_lastTrue)碎碎念我一开始就是没设drop_last导致最后一个batch的数据量不一样训练直接崩了血的教训。还有归一化真的很重要没做之前我的loss直接跑到了几十万训不动一点。第二步1D-CNN特征提取网络我写的是超级简单的两层卷积没有搞什么残差或者复杂的结构新手完全能看懂每一层在干嘛import torch.nn as nn class Simple1DCNN(nn.Module): def __init__(self, num_classes10): super().__init__() # 第一层卷积抓小的振动波动特征比如轴承的冲击脉冲 self.conv1 nn.Conv1d(in_channels1, out_channels16, kernel_size3, stride1, padding1) self.relu1 nn.ReLU() self.pool1 nn.MaxPool1d(kernel_size2, stride2) # 池化降维把序列长度砍半 # 第二层卷积抓更复杂的组合特征 self.conv2 nn.Conv1d(in_channels16, out_channels32, kernel_size3, stride1, padding1) self.relu2 nn.ReLU() self.pool2 nn.MaxPool1d(kernel_size2, stride2) # 全连接层把特征压缩到128维再输出64维的特征用来做域对齐 self.fc nn.Linear(32 * 256, 128) self.feature_layer nn.Linear(128, 64) # 最后加个分类头用来算源域的分类损失 self.classifier nn.Linear(64, num_classes) def forward(self, x): x self.pool1(self.relu1(self.conv1(x))) x self.pool2(self.relu2(self.conv2(x))) # 把二维特征展平成一维方便全连接层处理 x x.view(-1, 32 * 256) x self.fc(x) features self.feature_layer(x) logits self.classifier(features) return features, logits # 自动选择GPU/CPU没有GPU就直接用CPU跑 device cuda if torch.cuda.is_available() else cpu model Simple1DCNN(num_classes10).to(device)碎碎念这里的32*256是我假设原始信号长度是1024经过两次池化后变成了1024/2/2256要是你的信号长度不一样记得改这个数值不然会报形状错误。第三步JDA联合域对齐损失这个是迁移学习的核心我简化了原版JDA的代码新手不用纠结复杂的矩阵运算知道它是用来把源域和目标域的特征拉到一起就行一区top轴承故障诊断迁移学习代码复现 故障诊断代码 复现 首先使用一维的cnn对源域和目标域进行特征提取域适应阶段将源域和目标域作为cnn的输入得到特征然后进行边缘概率分布对齐和条件概率分布对齐也就是进行JDA联合对齐。此域适应方法特别适合初学者了解迁移学习的基础知识。 数据预处理1维数据 网络模型1D-CNN-MMD-Coral 数据集西储大学CWRU 准确率99% 网络框架pytorch 结果输出损失曲线图、准确率曲线图、混淆矩阵、tsne图 使用对象初学者 注意此代码是一个在 GPU 上跑的代码有宝子的电脑只支持 cpu只需要将代码修改成只在 cpu 上跑的就行不光对齐整体的特征分布边缘分布还对齐同一个故障类别的特征分布条件分布比单纯的MMD效果好太多。import torch import torch.nn.functional as F def jda_domain_alignment_loss(source_features, target_features, source_labels, num_classes10): # 把源域和目标域的特征拼在一起 all_features torch.cat([source_features, target_features], dim0) # 用高斯核计算样本之间的相似度 gamma 1.0 pairwise_distance torch.cdist(all_features, all_features, p2) ** 2 kernel_matrix torch.exp(-gamma * pairwise_distance) # 构建联合分布的权重矩阵核心就是让同类样本靠近不同域样本拉远 source_batch source_features.size(0) target_batch target_features.size(0) weight_matrix torch.zeros(source_batch target_batch, source_batch target_batch, devicedevice) # 初始化源域和目标域的整体权重 weight_matrix[:source_batch, :source_batch] 1 / (source_batch ** 2) weight_matrix[source_batch:, source_batch:] 1 / (target_batch ** 2) weight_matrix[:source_batch, source_batch:] -1 / (source_batch * target_batch) weight_matrix[source_batch:, :source_batch] -1 / (source_batch * target_batch) # 加上类内对齐的权重让同一个故障类别的源域和目标域特征更靠近 for class_idx in range(num_classes): source_class_idx (source_labels class_idx).nonzero(as_tupleTrue)[0] target_class_idx (source_labels class_idx).nonzero(as_tupleTrue)[0] source_batch if len(source_class_idx) 0 or len(target_class_idx) 0: continue weight_matrix[source_class_idx[:, None], source_class_idx] 1 / (len(source_class_idx) ** 2) weight_matrix[target_class_idx[:, None], target_class_idx] 1 / (len(target_class_idx) ** 2) weight_matrix[source_class_idx[:, None], target_class_idx] - 1 / (len(source_class_idx) * len(target_class_idx)) weight_matrix[target_class_idx[:, None], source_class_idx] - 1 / (len(source_class_idx) * len(target_class_idx)) # 计算最终的JDA损失 loss torch.trace(torch.matmul(torch.matmul(kernel_matrix, weight_matrix), kernel_matrix.T)) return loss第四步完整训练流程把上面的东西拼在一起就是完整的训练循环了import torch.optim as optim import matplotlib.pyplot as plt # 初始化损失函数和优化器 ce_loss_fn nn.CrossEntropyLoss() optimizer optim.Adam(model.parameters(), lr1e-4) # 超参数大家可以自己调调参就是玄学 lambda_jda 0.5 epochs 50 loss_list [] acc_list [] for epoch in range(epochs): model.train() total_train_loss 0.0 total_train_acc 0.0 # 同时遍历源域和目标域的dataloader for (source_x, source_y), (target_x, target_y) in zip(source_loader, target_loader): source_x, source_y source_x.to(device), source_y.to(device) target_x, target_y target_x.to(device), target_y.to(device) # 提取特征和分类结果 source_feat, source_logits model(source_x) target_feat, target_logits model(target_x) # 计算分类损失和域对齐损失 cls_loss ce_loss_fn(source_logits, source_y) align_loss jda_domain_alignment_loss(source_feat, target_feat, source_y, num_classes10) total_loss cls_loss lambda_jda * align_loss # 反向传播更新参数 optimizer.zero_grad() total_loss.backward() optimizer.step() # 统计指标 total_train_loss total_loss.item() pred torch.argmax(source_logits, dim1) total_train_acc (pred source_y).sum().item() / len(source_y) # 打印每个epoch的结果 avg_loss total_train_loss / len(source_loader) avg_acc total_train_acc / len(source_loader) loss_list.append(avg_loss) acc_list.append(avg_acc) print(fEpoch [{epoch1}/{epochs}] | 平均损失: {avg_loss:.4f} | 源域准确率: {avg_acc:.4f})碎碎念要是你只有CPU就把所有的.to(device)删掉或者改成.cpu()跑起来会慢一点我用1050Ti跑50个epoch大概10分钟CPU的话大概半小时左右。第五步结果可视化训练完之后一定要画图看效果不然你都不知道自己训了个啥from sklearn.manifold import TSNE from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix # 1. 画损失和准确率曲线 plt.figure(figsize(12, 4)) plt.subplot(1, 2, 1) plt.plot(loss_list, label训练损失) plt.xlabel(Epoch) plt.ylabel(Loss) plt.legend() plt.subplot(1, 2, 2) plt.plot(acc_list, label源域准确率) plt.xlabel(Epoch) plt.ylabel(Accuracy) plt.legend() plt.savefig(./train_curve.png) # 2. 画混淆矩阵用目标域的数据测试 model.eval() all_pred [] all_true [] with torch.no_grad(): for x, y in target_loader: x, y x.to(device), y.to(device) _, logits model(x) pred torch.argmax(logits, dim1) all_pred.extend(pred.cpu().numpy()) all_true.extend(y.cpu().numpy()) cm confusion_matrix(all_true, all_pred) disp ConfusionMatrixDisplay(confusion_matrixcm, display_labels[正常, 内圈故障0.007, 外圈故障0.007, 滚动体故障]) disp.plot(cmapplt.cm.Blues) plt.savefig(./confusion_matrix.png) # 3. t-SNE可视化特征对齐效果 all_source_feat [] all_source_label [] all_target_feat [] all_target_label [] with torch.no_grad(): for x, y in source_loader: x, y x.to(device), y.to(device) feat, _ model(x) all_source_feat.extend(feat.cpu().numpy()) all_source_label.extend(y.cpu().numpy()) for x, y in target_loader: x, y x.to(device), y.to(device) feat, _ model(x) all_target_feat.extend(feat.cpu().numpy()) all_target_label.extend(y.cpu().numpy()) all_feat np.concatenate([all_source_feat, all_target_feat], axis0) # 把目标域的标签加10和源域区分开 all_label all_source_label [l10 for l in all_target_label] tsne TSNE(n_components2, random_state42) feat_2d tsne.fit_transform(all_feat) plt.figure(figsize(8, 8)) plt.scatter(feat_2d[:len(all_source_feat), 0], feat_2d[:len(all_source_feat), 1], cblue, label源域, alpha0.5) plt.scatter(feat_2d[len(all_source_feat):, 0], feat_2d[len(all_source_feat):, 1], cred, label目标域, alpha0.5) plt.legend() plt.savefig(./tsne_visualization.png)碎碎念t-SNE图真的太直观了对齐好的话源域和目标域的点会混在一起要是没对齐就是两堆分开的颜色比看准确率爽多了。我这次跑出来的目标域准确率能到99%左右和博主说的差不多。最后唠点踩坑经验源域和目标域的batch size必须一致不然特征拼接会报错1D-CNN的输入一定要加通道维度不然会报形状错误标签一定要搞对我一开始把内圈和外圈的标签搞反了准确率直接跌到50%要是训练的时候loss不下降试试把学习率调小一点或者把lambda_jda改大一点完整的代码我已经打包上传到GitHub了需要的小伙伴直接搜1d-cnn-jda-cwru就能找到有啥问题也可以留言问我看到都会回的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2451865.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringBoot-17-MyBatis动态SQL标签之常用标签

文章目录 1 代码1.1 实体User.java1.2 接口UserMapper.java1.3 映射UserMapper.xml1.3.1 标签if1.3.2 标签if和where1.3.3 标签choose和when和otherwise1.4 UserController.java2 常用动态SQL标签2.1 标签set2.1.1 UserMapper.java2.1.2 UserMapper.xml2.1.3 UserController.ja…

wordpress后台更新后 前端没变化的解决方法

使用siteground主机的wordpress网站,会出现更新了网站内容和修改了php模板文件、js文件、css文件、图片文件后,网站没有变化的情况。 不熟悉siteground主机的新手,遇到这个问题,就很抓狂,明明是哪都没操作错误&#x…

网络编程(Modbus进阶)

思维导图 Modbus RTU(先学一点理论) 概念 Modbus RTU 是工业自动化领域 最广泛应用的串行通信协议,由 Modicon 公司(现施耐德电气)于 1979 年推出。它以 高效率、强健性、易实现的特点成为工业控制系统的通信标准。 包…

UE5 学习系列(二)用户操作界面及介绍

这篇博客是 UE5 学习系列博客的第二篇,在第一篇的基础上展开这篇内容。博客参考的 B 站视频资料和第一篇的链接如下: 【Note】:如果你已经完成安装等操作,可以只执行第一篇博客中 2. 新建一个空白游戏项目 章节操作,重…

IDEA运行Tomcat出现乱码问题解决汇总

最近正值期末周,有很多同学在写期末Java web作业时,运行tomcat出现乱码问题,经过多次解决与研究,我做了如下整理: 原因: IDEA本身编码与tomcat的编码与Windows编码不同导致,Windows 系统控制台…

利用最小二乘法找圆心和半径

#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …

使用docker在3台服务器上搭建基于redis 6.x的一主两从三台均是哨兵模式

一、环境及版本说明 如果服务器已经安装了docker,则忽略此步骤,如果没有安装,则可以按照一下方式安装: 1. 在线安装(有互联网环境): 请看我这篇文章 传送阵>> 点我查看 2. 离线安装(内网环境):请看我这篇文章 传送阵>> 点我查看 说明&#xff1a;假设每台服务器已…

XML Group端口详解

在XML数据映射过程中&#xff0c;经常需要对数据进行分组聚合操作。例如&#xff0c;当处理包含多个物料明细的XML文件时&#xff0c;可能需要将相同物料号的明细归为一组&#xff0c;或对相同物料号的数量进行求和计算。传统实现方式通常需要编写脚本代码&#xff0c;增加了开…

LBE-LEX系列工业语音播放器|预警播报器|喇叭蜂鸣器的上位机配置操作说明

LBE-LEX系列工业语音播放器|预警播报器|喇叭蜂鸣器专为工业环境精心打造&#xff0c;完美适配AGV和无人叉车。同时&#xff0c;集成以太网与语音合成技术&#xff0c;为各类高级系统&#xff08;如MES、调度系统、库位管理、立库等&#xff09;提供高效便捷的语音交互体验。 L…

(LeetCode 每日一题) 3442. 奇偶频次间的最大差值 I (哈希、字符串)

题目&#xff1a;3442. 奇偶频次间的最大差值 I 思路 &#xff1a;哈希&#xff0c;时间复杂度0(n)。 用哈希表来记录每个字符串中字符的分布情况&#xff0c;哈希表这里用数组即可实现。 C版本&#xff1a; class Solution { public:int maxDifference(string s) {int a[26]…

【大模型RAG】拍照搜题技术架构速览:三层管道、两级检索、兜底大模型

摘要 拍照搜题系统采用“三层管道&#xff08;多模态 OCR → 语义检索 → 答案渲染&#xff09;、两级检索&#xff08;倒排 BM25 向量 HNSW&#xff09;并以大语言模型兜底”的整体框架&#xff1a; 多模态 OCR 层 将题目图片经过超分、去噪、倾斜校正后&#xff0c;分别用…

【Axure高保真原型】引导弹窗

今天和大家中分享引导弹窗的原型模板&#xff0c;载入页面后&#xff0c;会显示引导弹窗&#xff0c;适用于引导用户使用页面&#xff0c;点击完成后&#xff0c;会显示下一个引导弹窗&#xff0c;直至最后一个引导弹窗完成后进入首页。具体效果可以点击下方视频观看或打开下方…

接口测试中缓存处理策略

在接口测试中&#xff0c;缓存处理策略是一个关键环节&#xff0c;直接影响测试结果的准确性和可靠性。合理的缓存处理策略能够确保测试环境的一致性&#xff0c;避免因缓存数据导致的测试偏差。以下是接口测试中常见的缓存处理策略及其详细说明&#xff1a; 一、缓存处理的核…

龙虎榜——20250610

上证指数放量收阴线&#xff0c;个股多数下跌&#xff0c;盘中受消息影响大幅波动。 深证指数放量收阴线形成顶分型&#xff0c;指数短线有调整的需求&#xff0c;大概需要一两天。 2025年6月10日龙虎榜行业方向分析 1. 金融科技 代表标的&#xff1a;御银股份、雄帝科技 驱动…

观成科技:隐蔽隧道工具Ligolo-ng加密流量分析

1.工具介绍 Ligolo-ng是一款由go编写的高效隧道工具&#xff0c;该工具基于TUN接口实现其功能&#xff0c;利用反向TCP/TLS连接建立一条隐蔽的通信信道&#xff0c;支持使用Let’s Encrypt自动生成证书。Ligolo-ng的通信隐蔽性体现在其支持多种连接方式&#xff0c;适应复杂网…

铭豹扩展坞 USB转网口 突然无法识别解决方法

当 USB 转网口扩展坞在一台笔记本上无法识别,但在其他电脑上正常工作时,问题通常出在笔记本自身或其与扩展坞的兼容性上。以下是系统化的定位思路和排查步骤,帮助你快速找到故障原因: 背景: 一个M-pard(铭豹)扩展坞的网卡突然无法识别了,扩展出来的三个USB接口正常。…

未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?

编辑&#xff1a;陈萍萍的公主一点人工一点智能 未来机器人的大脑&#xff1a;如何用神经网络模拟器实现更智能的决策&#xff1f;RWM通过双自回归机制有效解决了复合误差、部分可观测性和随机动力学等关键挑战&#xff0c;在不依赖领域特定归纳偏见的条件下实现了卓越的预测准…

Linux应用开发之网络套接字编程(实例篇)

服务端与客户端单连接 服务端代码 #include <sys/socket.h> #include <sys/types.h> #include <netinet/in.h> #include <stdio.h> #include <stdlib.h> #include <string.h> #include <arpa/inet.h> #include <pthread.h> …

华为云AI开发平台ModelArts

华为云ModelArts&#xff1a;重塑AI开发流程的“智能引擎”与“创新加速器”&#xff01; 在人工智能浪潮席卷全球的2025年&#xff0c;企业拥抱AI的意愿空前高涨&#xff0c;但技术门槛高、流程复杂、资源投入巨大的现实&#xff0c;却让许多创新构想止步于实验室。数据科学家…

深度学习在微纳光子学中的应用

深度学习在微纳光子学中的主要应用方向 深度学习与微纳光子学的结合主要集中在以下几个方向&#xff1a; 逆向设计 通过神经网络快速预测微纳结构的光学响应&#xff0c;替代传统耗时的数值模拟方法。例如设计超表面、光子晶体等结构。 特征提取与优化 从复杂的光学数据中自…