【Text2SQL】WikiSQL 数据集与 Seq2SQL 模型

news2025/5/29 8:13:42

论文:Seq2SQL: Generating Structured Queries from Natural Language using Reinforcement Learning

⭐⭐⭐⭐⭐

ICLR 2018

Dataset: github.com/salesforce/WikiSQL

Code:Seq2SQL 模型实现

一、论文速读

本文提出了 Text2SQL 方向的一个经典数据集 —— WikiSQL,同时提出了一个模型 Seq2SQL,用于把自然语言问句转为 SQL。

WikiSQL 数据集中的 SQL 形式较为简单,不包括排序(order by)、分组(group by)、子查询等其他复杂操作。根据这种简单的形式,本文的 Seq2SQL 模型针对一个 table 和一个 question,预测出 SELECT 部分、Aggregation 部分和 WHERE 部分,并将其构造成一个 SQL 语句。下图展示了一个示例:

在这里插入图片描述

Seq2SQL 基于 Augmented Pointer Network 来实现,下面先介绍一下这个网络结构,然后再介绍基于此来实现 Seq2SQL 模型。

二、Augmented Pointer Network(增广指针网络)

Augmented Pointer Network 能够从输入序列中选择 token 并逐个 token 生成输出序列。

对于一个 example,输入序列 x x x 是由"table 的列名"、“SQL 词汇表”、"question"三者用特殊分隔符拼接起来的序列:

在这里插入图片描述

比如在前面图片的示例中,列名 token 包括 “Pick”、“#”、“CFL” 等等组成,question token 包括 “How”、“many”、“CFL” 等等,SQL 词汇表包括 “SELECT”、“WHERE”、“COUNT”、“MIN” 等等。

这个网络首先对 input sequence x x x 做 word embedding,然后输入给两层的 Bi-LSTM 做编码得到 h e n c h^{enc} henc,其中 input 的第 i 个 token 的编码是 h t e n c h_t^{enc} htenc,这样每个 token 经过编码都变成了一个 vector。

解码器部分使用双层的单向 LSTM,每一步生成一个 token。具体生成方式是:使用上一步生成的 token y s − 1 y_{s-1} ys1 作为输入,输出一个 state g s g_s gs,然后拿 g s g_s gs 与 input sequence 的每个位置 t 的 h t h_t ht 做计算得到一个标量的注意力分数 α s , t p t r \alpha_{s,t}^{ptr} αs,tptr,选择分数最高的对应的输入 token 作为生成的下一个 token。其中注意力分数的计算公式如下:

20240518155338

三、Seq2SQL 模型

虽然可以直接训练 Augmented Pointer Network 让他生成 SQL 序列作为结果,但是这没有利用 SQL 本身固有的结构。本论文固定 SQL 的结构由三部分组成:SELECT、WHERE 和 Aggregation,并训练三个组件来分别生成这三部分:

在这里插入图片描述

3.1 Aggregation Classifier

他就是一个 classifier,最终输出一个 softmax 计算后的分布,从 NULLMAXMINCOUNTSUMAVG 中做分类,NULL 表示没有 aggregation 操作。其 loss L a g g L^{agg} Lagg 使用 cross entropy 来计算。

比如,“How many” 类型的 question 往往被分类为 COUNT

3.2 SELECT column prediction

SELECT column prediction 是一个匹配问题,这里使用指针网络的思想来解决:输入列名序列和 question 的拼接,输出与 question 最匹配的一个 column。

首先使用 LSTM 对每一列进行编码,column j j j 对应一个 vector e j c e_j^c ejc,然后对 input x x x 编码出一个 vector κ s e l \kappa^{sel} κsel,然后使用 MLP,计算 input representation κ s e l \kappa^{sel} κsel 与每一个 column j 的分数 α j s e l \alpha^{sel}_{j} αjsel,之后使用 softmax 对分数进行归一化:

  • 训练时,使用交叉熵损失 L s e l L^{sel} Lsel 来训练该模块
  • 预测时,选分数最大的 column 作为预测结果

对于输入 x x x 编码为 input representation 和计算分数的详细信息可以参考论文和代码实现

3.3 WHERE Clause

这里使用类似于 Augmented Pointer Network 的 pointer decoder 来训练这一模块。但是使用 cross entropy 有一个限制:两个 WHERE 条件可以被交换并产生相同结果。但两个顺序不同的 WHERE 会被 cross entropy 错误地惩罚,比如 year>18 and male=1male=1 and year>18 是等价的,但由于 cross entropy 是精确匹配 tokens,导致这个结果会被计算损失。

这里使用强化学习(RL)来训练, q ( y ) q(y) q(y) 是生成的查询, q g q_g qg 是真实查询,奖励函数的定义如下:

20240518171120

并根据此奖励函数计算出 loss L w h e L^{whe} Lwhe

3.4 Seq2SQL 的训练

设置一个混合损失函数 L = L a g g + L s e l + L w h e L = L^{agg} + L^{sel} + L^{whe} L=Lagg+Lsel+Lwhe,并使用梯度下降来最小化该 loss 从而训练模型。

四、WikiSQL 数据集

该文更重要的一个贡献是提供了一个 WikiSQL 数据集,包含 80654 条样本和 24241 个 schema。这些数据被随机划分为 train、dev 和 test 三个 split。

下面是一个 example:

20240518173309

解释如下:

  • phase: the phase in which the dataset was collected. We collected WikiSQL in two phases.
  • question: the natural language question written by the worker.
  • table_id: the ID of the table to which this question is addressed.
  • sql: the SQL query corresponding to the question. This has the following subfields:
    • sel: the numerical index of the column that is being selected. You can find the actual column from the table.
    • agg: the numerical index of the aggregation operator that is being used. You can find the actual operator from Query.agg_ops in lib/query.py.
    • conds: a list of triplets (column_index, operator_index, condition) where:
      • column_index: the numerical index of the condition column that is being used. You can find the actual column from the table.
      • operator_index: the numerical index of the condition operator that is being used. You can find the actual operator from Query.cond_ops in lib/query.py.
      • condition: the comparison value for the condition, in either string or float type.

同时还给出了每个 table 的 schema 和数据部分。

五、评估指标

  • N N N:数据集的样本总数
  • N e x N_{ex} Nex:运行生成的 SQL 后,得到正确结果的样本数
  • N l f N_{lf} Nlf:生成的 SQL 与 ground-truth SQL 字符串完全精确匹配的样本数

由此提出两个指标:

  • A C C e x = N e x / N ACC_{ex} = N_{ex} / N ACCex=Nex/N执行精度指标,如果生成的 SQL 与 ground-truth SQL 的执行结果相同,那就算作正确。存在一个缺点:如果构造一个错误的 SQL 但执行结果正确,依然被算作正确
  • A C C l f = N l f / N ACC_{lf} = N_{lf} / N ACClf=Nlf/N逻辑形式的精确指标,如果生成的 SQL 与 ground-truth SQL 完全匹配,才被算作正确。存在一个缺点:两个等价但写法不同的 SQL 会被算作错误

六、总结

这篇论文给出了一个 WikiSQL 数据集,并提出了 Text2SQL 的一个解决方案以及评价指标。

但是很明显,该方案存在不少缺点,之后的方案会继续改进。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1695097.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Amesim应用篇-电芯等效电路模型标定

前言 为了使计算模型更加准确,在有电芯实验测试数据的情况下,依据现有的实验数据对Amesim中的电池等效电路模型进行标定。标定的目的是为了获得更加符合项目实际情况的电芯等效电路模型,标定完的电芯可以用于搭建PACK模型,也可以用于其他虚拟实验。本文以充电标定为例,进…

ideal 启动 多个 相同 工程

spring相同项目在idea多次运行 点击IDEA右上角项目的隐藏下拉框,出现下拉列表,点击Edit Configurations 弹出Run/Debug Configuration对话框,勾选Allow parallel run

vue实战 ---- 社交媒体---黑马头条项目

vue基础 1.介绍 为什么会有Vuex ? ​ Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式。它采用集中式存储管理应用的所有组件的状态,并以相应的规则保证状态以一种可预测的方式发生变化。 vuex是采用集中式管理组件依赖的共享数据的一个工具,可以解…

21-信号集处理函数

屏蔽信号集 屏蔽某些信号 手动自动 未处理信号集 信号如果被屏蔽,则记录在未处理信号集中 非实时信号(1~31),不排队,只留一个实时信号(34~64),排队,保留全部 信号集…

7. Spring MVC面试题汇总

Java全栈面试题汇总目录-CSDN博客 1. 什么是Spring MVC,简单介绍下你对Spring MVC的理解? Spring MVC是一个基于Java的实现了MVC设计模式的请求驱动类型的轻量级Web框架,通过把Model,View,Controller分离,将web层进…

Jenkins安装部署--图文详细

Jenkins–从入门到入土 文章目录 Jenkins--从入门到入土一、Jenkins安装部署1、什么是Jenkins?2、Jenkins在开发过程中所属位置3、安装硬件环境和知识储备4、安装4.1、下载war启动4.2、Docker启动4.3、windows使用驱动安装 5、使用插件自定义 Jenkins6、创建第一个管理员用户 …

1+x(Java)中级题库易混淆理论题

<ALL表示小于最小 小于最高等同于小于ANY 使用USING子句&#xff0c;在使用连接字段时&#xff0c;都不能在前面加上表的前缀&#xff0c;因为此时这个字段已经是连接字段&#xff0c;不再属于某个单独的表。 数据库提供的自动将提供的数据类型数据转换为期望的数据类…

SpringBoot3笔记(一)SpringBoot3-核心特性

快速学习 SpringBoot 看官方文档&#xff1a; Spring Boot Reference Documentation 计划三天学完 笔记&#xff1a;https://www.yuque.com/leifengyang/springboot3 代码&#xff1a;https://gitee.com/leifengyang/spring-boot-3 一、SpringBoot3 - 快速入门 1.1 简介 …

深入解析:如何高效地更新Python字典

新书上架~&#x1f447;全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我&#x1f446;&#xff0c;收藏下次不迷路┗|&#xff40;O′|┛ 嗷~~ 目录 一、引言 二、修改字典中的值 三、向字典中添加键值对 四、更新字典的两种方法总结 五、…

Re72:读论文 XLM Cross-lingual Language Model Pretraining

诸神缄默不语-个人CSDN博文目录 诸神缄默不语的论文阅读笔记和分类 论文名&#xff1a;Cross-lingual Language Model Pretraining 模型简称&#xff1a;XLM ArXiv地址&#xff1a;https://arxiv.org/abs/1901.07291 这是2019年NeurIPS的论文&#xff0c;主要做到就是跨语言…

第十一届蓝桥杯物联网试题(国赛)

国赛题目看着简单其实还是挺复杂的&#xff0c;所以说不能掉以轻心&#xff0c;目前遇到的问日主要有以下几点&#xff1a; 本次题主要注重的是信息交互&#xff0c;与A板通信的有电脑主机和B板&#xff0c;所以处理好这里面的交互过程很重要 国赛中避免不了会收到其他选手的…

【Linux设备驱动】1.字符设备驱动程序框架及相关结构体

目录 程序总体框架模块加载函数模块卸载函数具体操作函数 相关结构体cdev结构体file_oparations结构体 设备号分配设备号注销设备号创建设备文件 程序总体框架 /* 包含相关头文件 */ #include <linux/module.h> #include <linux/fs.h> #include <linux/init.h&…

智慧校园的建设思路

智慧校园建设的一个主要目的就是要打破学校内的信息孤岛&#xff0c;其核心是在人、流程和信息三个层面的全面整合。智慧校园应该能够为全校师生员工及校外用户提供统一的、一站式的服务渠道&#xff1b;能够将学校各种业务流程连接起来&#xff0c;实现各种应用系统的互联互通…

设计新境界:大数据赋能UI的创新美学

设计新境界&#xff1a;大数据赋能UI的创新美学 引言 随着大数据技术的蓬勃发展&#xff0c;它已成为推动UI设计创新的重要力量。大数据不仅为界面设计提供了丰富的数据资源&#xff0c;还赋予了设计师以全新的视角和工具来探索美学的新境界。本文将探讨大数据如何赋能UI设计…

linux系统——终止进程命令

linux进程&#xff0c;有所谓进程树的概念&#xff0c;在此之上&#xff0c;有父进程与子进程 pgrep进程名可以查看进程信息 同时&#xff0c;此命令也可以使用参数进行调节 关于kill有一系列命令参数 echo $?可以输出上次命令执行的情况

【Linux】写时拷贝技术COW (copy-on-write)

文章目录 Linux写时拷贝技术(copy-on-write)进程的概念进程的定义进程和程序的区别PCB的内部构成 程序是如何被加载变成进程的&#xff1f;写时复制&#xff08;Copy-On-Write, COW&#xff09;写时复制机制的原理写时拷贝的场景 fork与COWvfork与fork Linux写时拷贝技术(copy-…

算法打卡 Day9(字符串KMP 算法)-实现 strStr+ 重复的子字符串

KMP 算法 KMP 算法解决的是字符串匹配的问题&#xff0c;其经典思想是&#xff1a;当出现的字符串不匹配时&#xff0c;可以记录一部分之前已经匹配的文本内容&#xff0c;利用这些信息避免从头再去做匹配。 前缀表 next 数组就是一个前缀表。前缀表是用来回退的&#xff0c…

秋招突击——算法——模板题——区间DP——合并石子

文章目录 题目内容思路分析实现代码分析与总结 题目内容 思路分析 基本思路&#xff0c;先是遍历区间长度&#xff0c;然后再是遍历左端点&#xff0c;最后是遍历中间的划分点&#xff0c;将阶乘问题变成n三次方的问题 实现代码 // 组合数问题 #include <iostream> #in…

如何在Windows 11上清除缓存,这里提供几种方法

序言 为了提高电脑的性能并保持整洁,你应该定期清除电脑上的各种缓存。我们将向你展示如何在Windows 11中做到这一点。 缓存文件是由各种应用程序和服务创建的临时文件。清除这些文件通常不会导致应用程序出现任何问题,因为应用程序会在需要时重新创建这些文件。你也可以将…

【树与图的bfs】

宽度优先遍历 queue<int> q; st[1] true; // 表示1号点已经被遍历过 q.push(1);while (q.size()) {int t q.front();q.pop();for (int i h[t]; i ! -1; i ne[i]){int j e[i];if (!st[j]){st[j] true; // 表示点j已经被遍历过q.push(j);}} } #include <cstdio…