【医药AI实战系列⑤】分布漂移、稀疏标签、冷启动——工业级DDI系统的三重死亡陷阱(附:如何设计让药剂师真正信任的可解释性输出)
先说结论,再说过程我们的DDI(Drug-Drug Interaction,药物相互作用)预测系统,在内部测试集上AUC 0.91,上线三个月后真实场景的AUC只有0.79。差了0.12。这不是小差距。AUC从0.91掉到0.79,意味着模型对真实临床场景的判断能力,比我们以为的差了将近一个数量级。有几次,系统对某个药物组合给出了"低风险"的预测,而临床药剂师的判断是"必须干预"。没有出事,是因为我们的系统设计里有人工兜底。但这件事让我们停下来,彻底重新审视了整个系统的设计逻辑。这篇文章记录我们找到的三个根本原因,以及怎么修。背景:DDI预测是个什么问题药物相互作用(DDI)是临床用药安全的核心议题。两种或多种药物同时使用时,可能产生以下后果:药效增强 → 毒性风险上升(如:华法林 + 阿司匹林 → 出血风险增加) 药效降低 → 治疗失败(如:利福平 + 口服避孕药 → 避孕失效) 产生新毒性 → 器官损伤(如:他汀类 + 环孢素 → 横纹肌溶解风险)全球上市药物约2万种,理论上两两组合的DDI对数约为2亿个。人类无法穷举实验,预测模型因此有真实价值。传统方法依赖药物手册和规则库(如Micromedex、DrugBank),覆盖有限且更新滞后。GNN(图神经网络)的引入,是因为药物-靶点-通路-不良反应天然构成一个图结构,适合用图学习来捕捉复杂的关联关系。系统架构:我们当初是怎么设计的数据层训练数据来自三个来源:DrugBank 5.1 → 已知DDI对(约300,000条有标注的相互作用) SIDER 4.1 → 药物副作用数据库(约1,430个药物) STITCH 5.0 → 药物-蛋白质相互作用(约500,000条)把这三个数据源整合成一个异构图:节点类型: - 药物节点(Drug):~8,000个 - 靶点节点(Target/Protein):~4,500个 - 通路节点(Pathway):~1,200个 - 副作用节点(Side Effect):~5,700个 边类型: - Drug → Drug:DDI关系(有标注,正负样本) - Drug → Target:药物-靶点结合 - Drug → Side Effect:药物-副作用关联 - Target → Pathway:靶点-通路归属模型层我们用的是R-GCN(Relational Graph Convolutional Network),能处理多种边类型:importtorchimporttorch.nnasnnimporttorch.nn.functionalasFfromtorch_geometric.nnimportRGCNConvclassDDI_RGCN(nn.Module):""" 基于R-GCN的药物相互作用预测模型 处理异构图中的多种关系类型 """def__init__(self,num_nodes:int,num_relations:int,embedding_dim:int=128,hidden_dim:int=256,num_classes:int=86,# DDI类型数量(基于DeepDDI分类)dropout:float=0.3):super(DDI_RGCN,self).__init__()# 节点嵌入(所有节点类型共享嵌入空间)self.node_embedding=nn.Embedding(num_nodes,embedding_dim)# R-GCN层(两层)self.conv1=RGCNConv(in_channels=embedding_dim,out_channels=hidden_dim,num_relations=num_relations,num_bases=30# basis decomposition,减少参数量)self.conv2=RGCNConv(in_channels=hidden_dim,out_channels=hidden_dim,num_relations=num_relations,num_bases=30)# DDI预测头(基于两个药物节点的嵌入)self.predictor=nn.Sequential(nn.Linear(hidden_dim*2,hidden_dim),nn.ReLU(),nn.Dropout(dropout),nn.Linear(hidden_dim,hidden_dim//2),nn.ReLU(),nn.Dropout(dropout),nn.Linear(hidden_dim//2,num_classes))self.dropout=nn.Dropout(dropout)defencode(self,x,edge_index,edge_type):"""生成所有节点的图嵌入"""h=self.node_embedding(x)h=F.relu(self.conv1(h,edge_index,edge_type))h=self.dropout(h)h=F.relu(self.conv2(h,edge_index,edge_type))returnhdefdecode(self,z,drug_i_idx,drug_j_idx):""" 给定两个药物节点的嵌入,预测DDI类型 drug_i_idx, drug_j_idx: 药物节点在图中的索引 """z_i=z[drug_i_idx]z_j=z[drug_j_idx]# 拼接两个药物的嵌入pair_embedding=torch.cat([z_i,z_j],dim=-1)# 预测DDI类型的概率分布logits=self.predictor(pair_embedding)returnlogitsdefforward(self,x,edge_index,edge_type,drug_pairs):z=self.encode(x,edge_index,edge_type)drug_i_idx=drug_pairs[:,0]drug_j_idx=drug_pairs[:,1]logits=self.decode(z,drug_i_idx,drug_j_idx)returnlogits# 训练循环(简化版)deftrain_epoch(model,optimizer,data,drug_pairs_train,labels_train):model.train()optimizer.zero_grad()logits=model(data.node_ids,data.edge_index,data.edge_type,drug_pairs_train)loss=F.cross_entropy(logits,labels_train)loss.backward()optimizer.step()returnloss.item()评估结果(离线,当时很满意)数据集分割:随机8:1:1(训练/验证/测试) 测试集结果: AUC-ROC: 0.912 AUC-PR: 0.887 Accuracy: 84.3% Macro-F1: 0.831 按DDI严重程度分类: 严重DDI: AUC 0.934 中度DDI: AUC 0.901 轻度DDI: AUC 0.889看起来非常好。然后我们上线了。死亡陷阱一:分布漂移——随机分割是个谎言上线后AUC掉到0.79,第一个怀疑对象是分布漂移。我们花了两周时间做根因分析,结论很残忍:我们的测试集和训练集共享了大量的药物节点。随机8:1:1分割,分割的是"药物对",不是"药物"。这意味着,如果药物A在训练集里出现过(和药物B、C、D等的相互作用都被训练过),那么在测试集里预测药物A和药物E的相互作用时,模型已经充分学习了A的图嵌入。这在学术上叫做"数据泄露(Data Leakage)",但更准确的描述是:我们测试的根本不是模型的泛化能力,而是它的记忆能力。真实上线场景中,新上市的药物(新节点)完全不在训练图里。模型对它们的预测,本质上是
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2524915.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!