# 构建和训练一个简单的CBOW词嵌入模型

news2025/5/10 22:02:08

构建和训练一个简单的CBOW词嵌入模型

在自然语言处理(NLP)领域,词嵌入是一种将词汇映射到连续向量空间的技术,这些向量能够捕捉词汇之间的语义关系。在这篇文章中,我们将构建和训练一个简单的Continuous Bag of Words(CBOW)词嵌入模型,使用PyTorch框架。
在这里插入图片描述

简介

CBOW模型是一种预测给定上下文中目标词的模型。它通过学习上下文词的向量表示来预测目标词。这种方法在处理大量文本数据时非常有效,因为它可以捕捉词汇之间的语义和语法关系。

环境准备

在开始之前,确保你的环境中安装了以下库:

  • PyTorch
  • NumPy
  • tqdm(用于显示进度条)

如果未安装,可以通过以下命令安装:

pip install torch numpy tqdm

数据准备

我们将使用一个简单的文本语料库来训练我们的模型。这个语料库包含一些句子,我们将从中提取词汇和构建训练数据集。

raw_text = """We are about to study the idea of a computational process.
Computational processes are abstract beings that inhabit computers.
As they evolve, processes manipulate other abstract things called data.
The evolution of a process is directed by a pattern of rules
called a program. People create programs to direct processes. In effect,
we conjure the spirits of the computer with our spells."""

数据预处理

首先,我们需要将文本转换为模型可以理解的格式。这包括创建词汇表、单词到索引的映射以及构建训练数据集。

vocab= set(raw_text)  # 集合。词库,里面内容就是无人
vocab_size = len(vocab) # 计算词汇集合的大小

word_to_idx = {word: i for i, word in enumerate(vocab)}  # for 循环的复合写法。第1次循环,i得到的索引号, word 取1个单词
idx_to_word = {i: word for i, word in enumerate(vocab)}

data = []  # 获取上下文词, 将上下文词作为输入, 目标词作为输出。构建训练数据集。
# 遍历文本,提取上下文和目标词对,构建训练数据集
for i in range(CONTEXT_SIZE, len(raw_text) - CONTEXT_SIZE):  # (2, 60)。
    context = (
        [raw_text[i - (2-j)] for j in range(CONTEXT_SIZE)] + # 获取左边的上下文词
        [raw_text[i + j + 1] for j in range(CONTEXT_SIZE)]   # 获取右边的上下文词
    )  # 获取上下文词 ['we','are','to','study']
    target = raw_text[i]  # 获取目标词 'about'
    data.append((context, target))  # 将上下文词和目标词保存到data中[((['we','are','to','study']),'about')]

模型定义

接下来,我们定义CBOW模型。这个模型包括嵌入层、投影层和输出层。

# 定义CBOW模型类
class CBOW(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(CBOW, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim) # 定义嵌入层
        self.proj = nn.Linear(embedding_dim, 128) # 定义投影层
        self.output = nn.Linear(128, vocab_size) # 定义输出层

    def forward(self, inputs):
        embeds = sum(self.embeddings(inputs)).view(1, -1) # 计算上下文词的嵌入向量的平均值
        out = F.relu(self.proj(embeds)) # 通过ReLU激活函数
        out = self.output(out) # 通过输出层
        nll_prob = F.log_softmax(out, dim=1) # 计算负对数似然损失
        return nll_prob

训练模型

现在,我们训练模型。我们将使用Adam优化器和负对数似然损失函数。

# 确定设备(GPU或CPU)
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(device)
# 初始化模型并将其移动到指定设备
model = CBOW(vocab_size, 10).to(device)

# 初始化优化器
optimizer = optim.Adam(model.parameters(), lr=0.001) # 使用Adam优化器

# 初始化损失列表
losses = []
# 定义损失函数
loss_function = nn.NLLLoss()
# 设置模型为训练模式
model.train()
# 训练模型100个epoch
for epoch in tqdm(range(100)): # 开始训练
    total_loss = 0
    # 遍历所有数据
    for context, target in data:
        context_vector = make_context_vector(context, word_to_idx).to(device)
        target = torch.tensor([word_to_idx[target]]).to(device)
        # 开始前向传播
        train_predict = model(context_vector)
        loss = loss_function(train_predict, target)
        # 反向传播
        optimizer.zero_grad()  # 梯度值清零
        loss.backward() # 反向传播计算得到每个参数的梯度值
        optimizer.step() #根据梯度更新网络参数

        total_loss += loss.item()
    # 记录每个epoch的损失
    losses.append(total_loss)
    print(losses)

测试模型

最后,我们测试模型,看看它如何预测给定上下文的下一个词。

# 定义一个上下文列表,包含四个词
context = ['People', 'create', 'to', 'direct']

# 将上下文词转换为模型可以理解的向量形式
context_vector = make_context_vector(context, word_to_idx).to(device)

# 将模型设置为评估模式,在评估模式下,模型的行为会有所改变,例如不会应用Dropout等
model.eval()

# 使用模型进行预测,传入上下文向量
predict = model(context_vector)

# 获取预测结果中概率最高的索引值,这个索引值对应预测的下一个词
max_idx = predict.argmax(1)

# 打印出输入的上下文
print('context', context)

# 使用索引到单词的映射字典,将预测的索引值转换为对应的单词,并打印出来
print("Predicted next word:", idx_to_word[max_idx.item()])

运行结果

4cf4867c49b51598c4ac.png)

结论

通过这篇文章,我们构建了一个简单的CBOW词嵌入模型,并使用PyTorch框架进行了训练和测试。这个模型能够学习词汇的向量表示,并预测给定上下文的下一个词。这对于许多NLP任务,如文本分类、情感分析等,都是非常有用的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2341373.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Collection集合,List集合,set集合,Map集合

文章目录 集合框架认识集合集合体系结构Collection的功能常用功能三种遍历方式三种遍历方式的区别 List集合List的特点、特有功能ArrayList底层原理LinkedList底层原理LinkedList的应用场list:电影信息管理模块案例 Set集合set集合使用哈希值红黑树HashSet底层原理HashSet集合元…

使用DDR4控制器实现多通道数据读写(九)

一、本章概括 在上一节中,我们概括了工程的整体思路,并提供了工程框架,给出了读写DDR4寄存器的接口列表和重点时序图。当然,对于将DDR4内存封装成FIFO接口,其中的重点在于对于读写DDR4内存地址的控制,相对于…

深度解析n8n全自动AI视频生成与发布工作流

工作流模版地址:Fully Automated AI Video Generation & Multi-Platform Publishing | n8n workflow template 本文将全面剖析基于n8n平台的这个"全自动AI视频生成与多平台发布"工作流的技术架构、实现原理和关键节点,帮助开发者深入理解…

pycharm调试typescript

前言 搜索引擎搜索调试typescript,都是vscode,但是没看懂。 vscode界面简洁,但是适配起来用不习惯,还是喜欢用pycharm。 安装软件 安装Node.js https://nodejs.org/zh-cn 判断是否安装成功 node -v npm install -g typescrip…

spring-ai之Advisors API

1、 Spring AI Advisors API 提供了一种灵活而强大的方法来拦截、 修改和增强 Spring 应用程序中的 AI 驱动的交互。 通过利用 Advisors API,开发人员可以创建更复杂、可重用和可维护的 AI 组件。主要优势包括封装重复的生成式 AI 模式、转换发送到大型语言模型 &…

JVM 系列:JVM 内存结构深度解析

你点赞了吗?你关注了吗?每天分享干货好文。 高并发解决方案与架构设计。 海量数据存储和性能优化。 通用框架/组件设计与封装。 如何设计合适的技术架构? 如何成功转型架构设计与技术管理? 在竞争激烈的大环境下&#xff0c…

【回眸】Tessy集成测试软件使用指南(一)新手使用篇

前言 这个专栏的文章前4篇都在记录如何使用Tessy进行单元测试,集成测试需要有一定单元测试基础,且做集成测试之前,需要做好单元测试,否则将会大幅增加软件单元代码纠错的代价。集成测试所花费的时间通常远远超过单元测试。如果直…

ROS 快速入门教程02

5. Node 节点 以智能手机为例,当我们使用智能手机的某个功能时,大多时候在使用手机的某个APP。同样当我们使用ROS的某个功能时,使用的是ROS的某一个或者某一些节点。 虽然每次我们只使用ROS的某一个或者某一些节点,但我们无法下…

vue+django+LSTM微博舆情分析系统 | 深度学习 | 食品安全分析

文章结尾部分有CSDN官方提供的学长 联系方式名片 文章结尾部分有CSDN官方提供的学长 联系方式名片 关注B站,有好处! 编号: D031 LSTM 架构:vuedjangoLSTMMySQL 功能: 微博信息爬取、情感分析、基于负面消极内容舆情分析…

HCIP实验二(OSPF网络配置与优化)

一.拓扑图与题目 1.R5为ISP,其上只能配置IP地址; R5与其他所有直连设备间均使用公有IP;环回地址为100.1.1.1/3 2.R4设备为企业出口路由器 3.整个0SPF环境IP基于172.16.0.0/16划分 4.所有设备均可访问R5的环回; 5.减少LSA的更新里,加快收敛&#xff0…

K8S的service详解

一。service的介绍 在K8S中,pod是访问应用程序的载体,我们可以通过pod的ip来访问应用程序,但是pod的ip地址不是固定的,这也意味着不方便直接采用pod的ip对服务进行访问,为了解决这个问题,K8S提供了service…

数据结构初阶:二叉树(四)

概述:本篇博客主要介绍链式结构二叉树的实现。 目录 1.实现链式结构二叉树 1.1 二叉树的头文件(tree.h) 1.2 创建二叉树 1.3 前中后序遍历 1.3.1 遍历规则 1.3.1.1 前序遍历代码实现 1.3.1.2 中序遍历代码实现 1.3.1.3 后序遍历代…

配置Intel Realsense D405驱动与ROS包

配置sdk使用 Ubuntu20.04LTS下安装Intel Realsense D435i驱动与ROS包_realsense的驱动包-CSDN博客 中的方法一 之后不通过apt安装包,使用官方的安装步骤直接clone https://github.com/IntelRealSense/realsense-ros/tree/ros1-legacy 从这一步开始 执行完 这一步…

【最新版】沃德代驾源码全开源+前端uniapp

一.系统介绍 基于ThinkPHPUniapp开发的代驾软件。系统源码全开源,代驾软件的主要功能包括预约代驾、在线抢单、一键定位、在线支付、车主登记和代驾司机实名登记等‌。用户可以通过小程序预约代驾服务,系统会估算代驾价格并推送附近代驾司机供用户选择&…

Linux:权限相关问题

文章目录 shell命令以及运行的原理Linux权限执行权限更改目录权限缺省权限粘滞位 shell命令以及运行的原理 操作系统分为内核和外壳程序,xshell是外壳程序,外壳程序包括我们windows桌面上的图形化界面,本质都是翻译给核心处理,再显…

AI数字人:元宇宙舞台上的闪耀新星(7/10)

摘要:AI数字人作为元宇宙核心角色,提升交互体验,推动内容生产变革,助力产业数字化转型。其应用场景涵盖虚拟社交、智能客服、教育、商业营销等,面临技术瓶颈与行业规范缺失等挑战,未来有望突破技术限制&…

【Linux】冯诺依曼体系结构及操作系统架构图的具体剖析

目录 一、冯诺依曼体系结构 1、结构图 2、结构图介绍: 3、冯诺依曼体系的数据流动介绍 4、为什么在该体系结构中要存在内存? 二、操作系统架构图介绍 1、操作系统架构图 2、解析操作系统架构图 3、为什么要有操作系统? 前些天发现了一…

算法训练营第一天|704.二分查找、27.移除元素、977.有序数组的平方

数组理论基础 1.数组是存放在连续内存空间上的相同类型数据的集合。 2.数组的元素是不能删除的,只能覆盖。 3.不同语言不一样,在C中,二维数组是连续分布的 704.二分查找 题目 思路与解法 第一想法: 简单的二分查找&#xff0c…

c++ 互斥锁

为练习c 线程同步,做了LeeCode 1114题. 按序打印: 给你一个类: public class Foo {public void first() { print("first"); }public void second() { print("second"); }public void third() { print("third"…

中波红外相机的应用领域及介绍

科技日新月异,无人机技术在众多领域已显露其卓越性能。当中波红外相机与无人机携手合作,安防视频监控和精细巡检便迎来了颠覆性的变革。本文旨在深入剖析无人机搭载中波红外相机的技术优势、广阔应用前景及实际案例,以此彰显其不可估量的潜力…