Day09【基于jieba分词和RNN实现的简单中文分词】

news2025/5/25 16:43:44

基于jieba分词和RNN实现的中文分词

      • 目标
      • 数据准备
      • 主程序
      • 预测效果

在这里插入图片描述
在这里插入图片描述

目标

本文基于给定的中文词表,将输入的文本基于jieba分词分割为若干个词,词的末尾对应的标签为1,中间部分对应的标签为0,同时将分词后的单词基于中文词表做初步序列化,之后经过embeddingRNN循环神经网络等网络结构层,最后输出在两类别(词内部和词边界)标签上的概率分布,从而实现一个简单中文分词任务。

数据准备

词表文件chars.txt

中文语料文件corpus.txt中文语料文件

主程序

#coding:utf8

import torch
import torch.nn as nn
import jieba
import numpy as np
import random
import json
from torch.utils.data import DataLoader

"""
基于pytorch的网络编写一个分词模型
我们使用jieba分词的结果作为训练数据
看看是否可以得到一个效果接近的神经网络模型
"""

class TorchModel(nn.Module):
    def __init__(self, input_dim, hidden_size, num_rnn_layers, vocab):
        super(TorchModel, self).__init__()
        self.embedding = nn.Embedding(len(vocab) + 1, input_dim) #shape=(vocab_size, dim)
        self.rnn_layer = nn.RNN(input_size=input_dim,
                            hidden_size=hidden_size,
                            batch_first=True,
                            bidirectional=False,
                            num_layers=num_rnn_layers,
                            nonlinearity="relu",
                            dropout=0.1)
        self.classify = nn.Linear(hidden_size, 2)
        self.loss_func = nn.CrossEntropyLoss(ignore_index=-100)


    #当输入真实标签,返回loss值;无真实标签,返回预测值
    def forward(self, x, y=None):
        x = self.embedding(x)  #output shape:(batch_size, sen_len, input_dim)
        x, _ = self.rnn_layer(x)  #output shape:(batch_size, sen_len, hidden_size)
        y_pred = self.classify(x)   #input shape:(batch_size, sen_len, class_num)
        if y is not None:
            #(batch_size * sen_len, class_num),   (batch_size * sen_len, 1)
            return self.loss_func(y_pred.view(-1, 2), y.view(-1))
        else:
            return y_pred

class Dataset:
    def __init__(self, corpus_path, vocab, max_length):
        self.vocab = vocab
        self.corpus_path = corpus_path
        self.max_length = max_length
        self.load()

    def load(self):
        self.data = []
        with open(self.corpus_path, encoding="utf8") as f:
            for line in f:
                sequence = sentence_to_sequence(line, self.vocab)
                label = sequence_to_label(line)
                sequence, label = self.padding(sequence, label)
                sequence = torch.LongTensor(sequence)
                label = torch.LongTensor(label)
                self.data.append([sequence, label])
                if len(self.data) > 10000:
                    break

    def padding(self, sequence, label):
        sequence = sequence[:self.max_length]
        sequence += [0] * (self.max_length - len(sequence))
        label = label[:self.max_length]
        label += [-100] * (self.max_length - len(label))
        return sequence, label

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]

#文本转化为数字序列,为embedding做准备
def sentence_to_sequence(sentence, vocab):
    sequence = [vocab.get(char, vocab['unk']) for char in sentence]
    return sequence

#基于结巴生成分级结果的标注
def sequence_to_label(sentence):
    words = jieba.lcut(sentence)
    label = [0] * len(sentence)
    pointer = 0
    for word in words:
        pointer += len(word)
        label[pointer - 1] = 1
    return label

#加载字表
def build_vocab(vocab_path):
    vocab = {}
    with open(vocab_path, "r", encoding="utf8") as f:
        for index, line in enumerate(f):
            char = line.strip()
            vocab[char] = index + 1   #每个字对应一个序号
    vocab['unk'] = len(vocab) + 1
    return vocab

#建立数据集
def build_dataset(corpus_path, vocab, max_length, batch_size):
    dataset = Dataset(corpus_path, vocab, max_length) #diy __len__ __getitem__
    data_loader = DataLoader(dataset, shuffle=True, batch_size=batch_size) #torch
    return data_loader


def main():
    epoch_num = 10        #训练轮数
    batch_size = 20       #每次训练样本个数
    char_dim = 50         #每个字的维度
    hidden_size = 100     #隐含层维度
    num_rnn_layers = 3    #rnn层数
    max_length = 20       #样本最大长度
    learning_rate = 1e-3  #学习率
    vocab_path = "chars.txt"  #字表文件路径
    corpus_path = "corpus.txt"  #语料文件路径
    vocab = build_vocab(vocab_path)       #建立字表
    data_loader = build_dataset(corpus_path, vocab, max_length, batch_size)  #建立数据集
    model = TorchModel(char_dim, hidden_size, num_rnn_layers, vocab)   #建立模型
    optim = torch.optim.Adam(model.parameters(), lr=learning_rate)     #建立优化器
    #训练开始
    for epoch in range(epoch_num):
        model.train()
        watch_loss = []
        for x, y in data_loader:
            optim.zero_grad()    #梯度归零
            loss = model(x, y)   #计算loss
            loss.backward()      #计算梯度
            optim.step()         #更新权重
            watch_loss.append(loss.item())
        print("=========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(watch_loss)))
    #保存模型
    torch.save(model.state_dict(), "model.pth")
    #保存词表
    writer = open("vocab.json", "w", encoding="utf8")
    writer.write(json.dumps(vocab, ensure_ascii=False, indent=2))
    writer.close()
    return

#最终预测
def predict(model_path, vocab_path, input_strings):
    #配置保持和训练时一致
    char_dim = 50  # 每个字的维度
    hidden_size = 100  # 隐含层维度
    num_rnn_layers = 3  # rnn层数
    vocab = build_vocab(vocab_path)       #建立字表
    model = TorchModel(char_dim, hidden_size, num_rnn_layers, vocab)   #建立模型
    model.load_state_dict(torch.load(model_path))   #加载训练好的模型权重
    model.eval()
    for input_string in input_strings:
        #逐条预测
        x = sentence_to_sequence(input_string, vocab)
        # print(x)
        with torch.no_grad():
            result = model.forward(torch.LongTensor([x]))[0]
            result = torch.argmax(result, dim=-1)  #预测出的01序列
            print(result)
            #在预测为1的地方切分,将切分后文本打印出来
            for index, p in enumerate(result):
                if p == 1:
                    print(input_string[index], end=" ")
                else:
                    print(input_string[index], end="")
            print()



if __name__ == "__main__":
    print(torch.backends.mps.is_available())
    main()

    input_strings = ["同时国内有望出台新汽车刺激方案",
                     "沪胶后市有望延续强势",
                     "经过两个交易日的强势调整后",
                     "昨日上海天然橡胶期货价格再度大幅上扬"]
    predict("model.pth", "chars.txt", input_strings)

主要实现了一个基于jieba分词的中文分词模型,模型采用 RNN(循环神经网络)来处理中文文本,通过对句子进行分词,预测每个字是否为词的结尾。具体内容如下:

  1. 模型结构(TorchModel

    • 使用 nn.Embedding 层将每个字符映射到一个高维空间。
    • 通过 nn.RNN 层处理字符序列,提取上下文信息,使用单向 RNNbidirectional=False)。
    • 最后通过 nn.Linear 层将 RNN 输出转化为每个字符的分类结果,分类为 0(非词结尾)或 1(词结尾)。
    • 损失函数为 CrossEntropyLoss,计算预测与真实标签的差异。
  2. 数据处理(Dataset

    • 使用 jieba 分词工具将文本切分为词,并为每个字符标注一个标签。标签为 1 表示该字符是词的结尾,0 表示不是词结尾。
    • 将文本转换为数字序列,并根据最大句子长度进行填充,使得输入数据的形状一致。
  3. 训练过程

    • 数据通过 DataLoader 按批加载,使用 Adam 优化器进行训练。
    • 在每一轮训练中,计算损失并通过反向传播优化模型权重,训练 10 轮。
  4. 预测功能(predict

    • 加载训练好的模型,使用 torch.no_grad() 禁用梯度计算,提高推理速度。
    • 对每个输入字符串进行分词预测,输出每个字是否为词的结尾。若为词结尾,则切分该词并打印。
  5. 核心流程

    • sentence_to_sequence 将文本转换为字符序列,sequence_to_label 生成对应的标签序列。
    • 训练完成后,保存模型和词表,以便后续加载和预测。

代码实现了一个简单的中文分词模型,通过标注每个字符是否为词的结尾,结合 RNN 提取上下文信息,从而实现文本分词功能。

预测效果

输入语句:
“同时国内有望出台新汽车刺激方案”,
“沪胶后市有望延续强势”,
“经过两个交易日的强势调整后”,
“昨日上海天然橡胶期货价格再度大幅上扬”

中文分词后结果:

tensor([0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1])
同时 国内 有 望 出台 新 汽车 刺激 方案 
tensor([1, 1, 0, 1, 1, 1, 0, 1, 0, 1])
沪 胶 后市 有 望 延续 强势 
tensor([0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1])
经过 两个 交易 日 的 强势 调整 后 
tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
昨日 上海 天然 橡胶 期货 价格 再度 大幅 上扬 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2336541.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

自动化测试——selenium

简介 Selenium 是一个广泛使用的自动化测试工具,主要用于 Web 应用程序的自动化测试。它能实现的功能是网页的自动化操作,例如自动抢票刷课等。同时你应该也见到过有些网站在打开之后并没有直接加载出网站的所有内容,比如一些图片等等&#x…

Linux——进程通信

我们知道,进程具有独立性,各进程之间互不干扰,但我们为什么还要让其联系,建立通信呢?比如:数据传输,资源共享,通知某个事件,或控制某个进程。因此,让进程间建…

【免费参会合集】2025年生物制药行业展会会议表格整理

全文精心整理, 建议今年参会前都好好收藏着,记得点赞! 医药人非常吃资源,资源从何而来?作为一名从事医药行业的工作者,可以很负责任的告诉诸位,其中非常重要的一个渠道就是会议会展! 建议所有医…

腾讯云开发+MCP:旅游规划攻略

1.登录注册好之后进入腾讯云开发 2.创建环境 4.创建好环境之后点击去开发 5.进入控制台后,选择AI,找到MCP 6.点击创建MCP Server 使用腾讯云开发创建MCP目前需要云开发入门版99/月,我没开通,所以没办法往下进行。

Sklearn入门之数据预处理preprocessing

、 Sklearn全称:Scipy-toolkit Learn是 一个基于scipy实现的的开源机器学习库。它提供了大量的算法和工具,用于数据挖掘和数据分析,包括分类、回归、聚类等多种任务。本文我将带你了解并入门Sklearn下的preprocessing在机器学习中的基本用法。 获取方式…

家用打印机性价比排名及推荐

文章目录 品牌性价比一、核心参数对比与场景适配二、技术类型深度解析三、不同场景选择 相关文章 品牌 性价比 一、核心参数对比与场景适配 兄弟T436W 优势: 微压电技术,打印头寿命长,堵头率低。 支持A4无边距和5G WiFi,适合照片…

数字电子技术基础(四十七)——使用Mutlisim软件来模拟74LS85芯片

目录 1 使用74LS85N芯片完成四位二进制数的比较 1.1原理介绍 1.2 器件选择 1.3 运行电路 2 使用74LS85N完成更多位的二进制比较 1 使用74LS85N芯片完成四位二进制数的比较 1.1原理介绍 对于74LS85 是一款 4 位数值比较器集成电路,用于比较两个 4 位二进制数&…

关于STM32创建工程文件启动文件选择

注意启动文件只要选择这几个 而不是要把所有都选上

LLC电路工作在容性区的风险

在t0时刻之前,Q6Q7导通,回路如下所示,此时A点电压是低压,B点电压是高压 在t0时刻时,谐振电流相位发生变换,在t1时刻,Q5,Q8导通,对于Q8MOS管来说,B点电压在Q6Q…

Linux Kernel 6

clone 系统调用(The clone system call) 在 Linux 中,使用 clone() 系统调用来创建新的线程或进程。fork() 系统调用和 pthread_create() 函数都基于 clone() 的实现。 clone() 系统调用允许调用者决定哪些资源应该与父进程共享&#xff0c…

【开源项目】Excel手撕AI算法深入理解(四):AlphaFold、Autoencoder

项目源码地址:https://github.com/ImagineAILab/ai-by-hand-excel.git 一、AlphaFold AlphaFold 是 DeepMind 开发的突破性 AI 算法,用于预测蛋白质的三维结构。它的出现解决了生物学领域长达 50 年的“蛋白质折叠问题”,被《科学》杂志评为…

第IV部分有效应用程序的设计模式

第IV部分有效应用程序的设计模式 第IV部分有效应用程序的设计模式第23章:应用程序用户界面的架构设计23.1设计考量23.2示例1:用于非分布式有界上下文的一个基于HTMLAF的、服务器端的UI23.3示例2:用于分布式有界上下文的一个基于数据API的客户端UI23.4要点第24章:CQRS:一种…

如何编制实施项目管理章程

本文档概述了一个项目管理系统的实施计划,旨在通过统一的业务规范和技术架构,加强集团公司的业务管控,并规范业务管理。系统建设将遵循集团统一模板,确保各单位项目系统建设的标准化和一致性。 实施范围涵盖投资管理、立项管理、设计管理、进度管理等多个方面,支持项目全生…

排序(java)

一.概念 排序:对一组数据进行从小到大/从大到小的排序 稳定性:即使进行排序相对位置也不受影响如: 如果再排序后 L 在 i 的前面则稳定性差,像图中这样就是稳定性好。 二.常见的排序 三.常见算法的实现 1.插入排序 1.1 直…

【HDFS入门】HDFS副本策略:深入浅出副本机制

目录 1 HDFS副本机制概述 2 HDFS副本放置策略 3 副本策略的优势 4 副本因子配置 5 副本管理流程 6 最佳实践与调优 7 总结 1 HDFS副本机制概述 Hadoop分布式文件系统(HDFS)的核心设计原则之一就是通过数据冗余来保证可靠性,而这一功能正是通过副本策略实现的…

智能 GitHub Copilot 副驾驶® 更新升级!

智能 GitHub Copilot 副驾驶 迎来重大升级!现在,所有 VS Code 用户都能体验支持 Multi-Context Protocol(MCP)的全新 Agent Mode。此外,微软还推出了智能 GitHub Copilot 副驾驶 Pro 订阅计划,提供更强大的…

【今日三题】添加字符(暴力枚举) / 数组变换(位运算) / 装箱问题(01背包)

⭐️个人主页:小羊 ⭐️所属专栏:每日两三题 很荣幸您能阅读我的文章,诚请评论指点,欢迎欢迎 ~ 目录 添加字符(暴力枚举)数组变换(位运算)装箱问题(01背包) 添加字符(暴力枚举) 添加字符 当在A的开头或结尾添加字符直到和B长度…

Linux——消息队列

目录 一、消息队列的定义 二、相关函数 2.1 msgget 函数 2.2 msgsnd 函数 2.3 msgrcv 函数 2.4 msgctl 函数 三、消息队列的操作 3.1 创建消息队列 3.2 获取消息队列并发送消息 3.3 从消息队列接收消息recv 四、 删除消息队列 4.1 ipcrm 4.2 msgctl函数 一、消息…

领慧立芯LHE7909可兼容替代TI的ADS1299

LHE7909是一款由领慧立芯(Legendsemi)推出的24位高精度Δ-Σ模数转换器(ADC),主要面向医疗电子和生物电势测量应用,如脑电图(EEG)、心电图(ECG)等设备。以下是…

MongoDB简单用法

图片中 MongoDB Compass 中显示了默认的三个数据库: adminconfiglocal 如果在 .env 文件中配置的是: MONGODB_URImongodb://admin:passwordlocalhost:27017/ MONGODB_NAMERAGSAAS💡 一、为什么 Compass 里没有 RAGSAAS 数据库?…