Chinese News Text Classification in Practice: A Three-Tier Comparison, from TextCNN → BiLSTM → BERT (Complete Code Included)
**Task:** Chinese news text classification (e.g., THUCNews, 10 or 14 classes). **Goal:** implementations of three mainstream approaches that can be reproduced directly, plus a comparison.

## 1. Data preparation

Using THUCNews as an example, each line of the data file is `label \t text`:

```python
import torch
from torch.utils.data import Dataset

class NewsDataset(Dataset):
    def __init__(self, path, tokenizer=None, max_len=128):
        self.samples = []
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                # Split only on the first tab, in case the text itself contains tabs
                y, x = line.split("\t", 1)
                self.samples.append((int(y), x))
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        y, x = self.samples[idx]
        if self.tokenizer:  # BERT path: tokenize here
            enc = self.tokenizer(x, truncation=True, padding="max_length",
                                 max_length=self.max_len, return_tensors="pt")
            return enc["input_ids"].squeeze(0), enc["attention_mask"].squeeze(0), y
        else:  # TextCNN / BiLSTM path: return the raw string
            return x, y
```

(For TextCNN/BiLSTM, the raw-text path still needs a vocabulary and a collate function; see the appendix sketches after the summary.)

## Approach 1: TextCNN (fast baseline)

```python
import torch.nn as nn
import torch.nn.functional as F

class TextCNN(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_classes):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # One convolution per filter width (3/4/5), 100 feature maps each
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, 100, (k, embed_dim)) for k in [3, 4, 5]]
        )
        self.fc = nn.Linear(300, num_classes)

    def forward(self, x):
        x = self.embed(x)       # (B, L, D)
        x = x.unsqueeze(1)      # (B, 1, L, D)
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs]  # each (B, 100, L-k+1)
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]   # each (B, 100)
        x = torch.cat(x, 1)     # (B, 300)
        return self.fc(x)
```

Characteristics: short to implement and fast to train; a good baseline for a course project.

## Approach 2: BiLSTM (sequence modeling)

```python
class BiLSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim,
                            batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):
        x = self.embed(x)
        _, (h, _) = self.lstm(x)
        # Concatenate the final forward and backward hidden states
        h = torch.cat((h[-2], h[-1]), dim=1)
        return self.fc(h)
```

Characteristics: captures context better than the CNN, but trains somewhat slower.

## Approach 3: BERT (best accuracy)

```python
from transformers import BertTokenizer, BertModel

class BertClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-chinese")
        self.fc = nn.Linear(768, num_classes)

    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls = out.last_hidden_state[:, 0]  # [CLS] token representation
        return self.fc(cls)
```

## Training loop (shared by all three)

```python
def train(model, dataloader, optimizer, device):
    model.train()
    loss_fn = nn.CrossEntropyLoss()
    for batch in dataloader:
        optimizer.zero_grad()
        if len(batch) == 3:  # BERT: (input_ids, attention_mask, labels)
            x, mask, y = [b.to(device) for b in batch]
            logits = model(x, mask)
        else:                # TextCNN / BiLSTM: (input_ids, labels)
            x, y = [b.to(device) for b in batch]
            logits = model(x)
        loss = loss_fn(logits, y)
        loss.backward()
        optimizer.step()
```

## Comparing the three approaches

| Model | Implementation effort | Speed | Accuracy |
| --- | --- | --- | --- |
| TextCNN | ⭐ | ⭐⭐⭐ | ⭐⭐ |
| BiLSTM | ⭐⭐ | ⭐⭐ | ⭐⭐⭐ |
| BERT | ⭐⭐⭐ | ⭐ | ⭐⭐⭐⭐ |

## Summary

- Need to ship a course project quickly → TextCNN
- Want to showcase sequence modeling → BiLSTM
- Chasing accuracy or reproducing a paper → BERT
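## Appendix: sketches for the missing plumbing

One gap worth closing before training: in the non-BERT path, `NewsDataset` returns raw strings, so a `DataLoader` still needs a vocabulary and a `collate_fn` to turn text into padded id tensors. The post itself does not include this step; below is a minimal character-level sketch. The helper names (`build_vocab`, `make_collate_fn`) and the PAD/UNK index convention are assumptions of this sketch, not part of the original code.

```python
from collections import Counter
from torch.nn.utils.rnn import pad_sequence

PAD, UNK = 0, 1  # assumed index convention: 0 = padding, 1 = unknown

def build_vocab(dataset, min_freq=1):
    """Character-level vocabulary built from NewsDataset.samples."""
    counter = Counter()
    for _, text in dataset.samples:
        counter.update(text)
    chars = [ch for ch, c in counter.items() if c >= min_freq]
    return {ch: i + 2 for i, ch in enumerate(chars)}  # offset past PAD/UNK

def make_collate_fn(vocab, max_len=128):
    def collate(batch):
        texts, labels = zip(*batch)  # NewsDataset yields (text, label)
        ids = [
            torch.tensor([vocab.get(ch, UNK) for ch in t[:max_len]],
                         dtype=torch.long)
            for t in texts
        ]
        x = pad_sequence(ids, batch_first=True, padding_value=PAD)
        if x.size(1) < 5:  # TextCNN's widest filter needs length >= 5
            x = F.pad(x, (0, 5 - x.size(1)), value=PAD)
        return x, torch.tensor(labels)
    return collate
```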
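The shared `train` function also has no evaluation counterpart in the post. A sketch of an accuracy loop that mirrors the same batch-length convention (`len(batch) == 3` means a BERT batch) could look like this; `evaluate` is a hypothetical helper, not from the original:

```python
@torch.no_grad()
def evaluate(model, dataloader, device):
    """Plain accuracy over a dataloader; mirrors train()'s batch convention."""
    model.eval()
    correct = total = 0
    for batch in dataloader:
        if len(batch) == 3:  # BERT: (input_ids, attention_mask, labels)
            x, mask, y = [b.to(device) for b in batch]
            logits = model(x, mask)
        else:                # TextCNN / BiLSTM: (input_ids, labels)
            x, y = [b.to(device) for b in batch]
            logits = model(x)
        correct += (logits.argmax(dim=1) == y).sum().item()
        total += y.size(0)
    return correct / total
```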
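Finally, a sketch of how the pieces could be wired together end to end. The file names (`train.txt`, `val.txt`), batch sizes, learning rates, class count, and epoch count are placeholder assumptions to be adjusted for your setup:

```python
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_CLASSES = 10  # THUCNews subset; 14 for the full label set

# --- TextCNN / BiLSTM: raw text plus the vocab/collate helpers above ---
train_set = NewsDataset("train.txt")
val_set = NewsDataset("val.txt")
vocab = build_vocab(train_set)
collate = make_collate_fn(vocab)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, collate_fn=collate)
val_loader = DataLoader(val_set, batch_size=64, collate_fn=collate)
model = TextCNN(vocab_size=len(vocab) + 2, embed_dim=128,
                num_classes=NUM_CLASSES).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# --- BERT: tokenization happens inside NewsDataset instead ---
# tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
# train_set = NewsDataset("train.txt", tokenizer=tokenizer)
# train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
# model = BertClassifier(num_classes=NUM_CLASSES).to(device)
# optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

for epoch in range(3):
    train(model, train_loader, optimizer, device)
    print(f"epoch {epoch}: val acc = {evaluate(model, val_loader, device):.4f}")
```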