对话机器人

article/2025/9/20 3:40:56

【居然审核不通过……内容低俗,这么高大上的内容,哪里低俗了……】

前面写了一系列的 微信机器人,但还没涉及到自然语言处理(Natural Language Processing, NLP)。今天把这坑填上。本文将基于 Seq2Seq 模型和Little Yellow Chicks 数据集(估计就是这个数据集低俗了),搭建一个简单的对话机器人。

原理

网上介绍 Seq2Seq 的文章太多了,我也不太可能写出更好的讲解,索性就更通俗地介绍吧。

从线性回归说起

先回忆一下线性回归。线性回归是统计学的范畴,关键词是线性回归

线性

线性是说两个变量之间的关系是一次函数关系——图像是直线,每个自变量的最高次项为 1。至少在中学的时候,我们就学过一元一次函数,一般写作 y = k x + b y=kx+b y=kx+b,其中 k、b 是常数,且 k ≠ 0 k\neq0 k=0。通常我们称 y 为因变量,x 为自变量(也叫)。有一元自然有多元;有线性当然也有非线性,这里就不展开了。

回归

回归指的回归分析,是研究自变量与因变量之间数量变化关系的一种分析方法,它主要是通过因变量 y 与影响它的自变量 x 之间的回归模型,衡量自变量 x 对因变量 y 的影响能力的,进而可以用来预测因变量 y 的发展趋势。比如,人的身高和体重一般而已是有对应关系的(胖子别哭),我们便可以通过身高(x)来预测体重(y)。与回归对应还有分类,这里也不展开了。

把前面的线性和回归合起来,便是线性回归了。假设我们通过大量观测,估计出了 y = k x + b y=kx+b y=kx+b k k k b b b,那么,给定一个 x x x,我们将可以预测出一个 y y y;给多个 x x x,便可以预测出多个 y y y;亦即,一一对应的关系。

对于分类,我们也得到一一对应的关系,比如一张图片,我们识别出来要么是狗要么是猫,不存在即是狗又是猫。

这和今天的话题有什么关系呢?好像没啥关系,就是突然想到了……

Seq2Seq

前面提到,无论是回归还是分类,通常我们建立的都是一一对应关系(这里指的是一个输入一个输出)。但对于翻译,一一对应未必是满足需求的了。比如:

中文英文
江枫maple-trees near the river
寒山寺temple on the Cold Mountain

在翻译场景下,通常就是输入一串,然后输出一串,而且通常不是一一对应的。Seq2Seq 则很好地满足这种需求。Seq2Seq 指的是 Sequence to Sequence,序列到序列。

编码器解码器结构

Seq2Seq
上图很直观地描述了 序列序列(Seq2Seq):今晚打老虎 -> Fight tigers tonight

一位翻译,他必须先学习本语言(编码器),同时也要学习目标语言(解码器),通过大量学习,他会形成“语感”(Context)。当来了一个翻译任务,他便能翻译出来。

更多细节,可以参考:Sequence to Sequence Learning with Neural Networks。

由翻译到闲聊

上面的翻译原理,应该是比较好理解的。但这跟闲聊有什么关系呢?

我们想想,闲聊是不是问一句,回答一句?同样是一串到一串嘛!所以,我们可以用同样的模型,将翻译语料替换成闲聊语料,便可以训练出一只对话机器人了。为了方便,直接类比翻译,构造一个输入词表一个输出词表;同时与一般的 NLP 处理不同,闲聊机器人的词表制作不需要去掉停用词。

Talk is cheap. Show me the code.
ShowCode

代码

数据处理

data_processing.py

# -*- coding: utf-8 -*-import io
import json
import logging
import osimport jieba
import tensorflow as tf
import tensorflow_probability as tfp
from tqdm import tqdmdef add_flag(w):return "<bos> " + w + " <eos>"class Data(object):def __init__(self, config) -> None:self.config = configself.seq_path = config["data_path"] + config["dataset"] + ".data"self.conv_path = config["data_path"] + config["dataset"] + ".conv"self.conv_size = os.path.getsize(self.conv_path)self.vacab_path_in = config["data_path"] + config["dataset"] + ".vin"self.vacab_path_out = config["data_path"] + config["dataset"] + ".vout"self.max_length = config["max_length"]self.batch_size = config["batch_size"]self.LOG = logging.getLogger("Data")logging.basicConfig(level=logging.INFO)jieba.setLogLevel(logging.INFO)  # Disable debug infodef create_sequences(self):if os.path.exists(self.seq_path):  # Skip if processed data existsreturnif not os.path.exists(self.conv_path):self.LOG.info("找不到语料文件,请检查路径")exit()self.LOG.info("正在处理语料")with tqdm(total=self.conv_size) as pbar, open(self.conv_path, encoding="utf-8") as fin, open(self.seq_path, "w") as fout:one_conv = ""  # 存储一次完整对话for line in fin:pbar.update(len(line.encode("utf-8")))line = line.strip("\n")if not line:  # Skip empty linecontinue# Refer to dataset format: E M Mif line[0] == self.config["e"]:  # E, end of conversation, save itif one_conv:fout.write(one_conv[:-1] + "\n")one_conv = ""elif line[0] == self.config["m"]:  # M, question or answer, split them with \tone_conv = one_conv + str(" ".join(jieba.cut(line.split(" ")[1]))) + "\t"def create_vacab(self, lang, vocab_path, vocab_size):if os.path.exists(vocab_path):  # Skip if existsreturntokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=vocab_size, oov_token="<UNK>")tokenizer.fit_on_texts(lang)with open(vocab_path, "w", encoding="utf-8") as f:f.write(tokenizer.to_json(ensure_ascii=False))self.LOG.info(f"正在保存: {vocab_path}")def create_vacabularies(self):if os.path.exists(self.vacab_path_in) and os.path.exists(self.vacab_path_out):  # Skip if existsreturnself.LOG.info(f"正在创建字典")lines = io.open(self.seq_path, encoding="UTF-8").readlines()word_pairs = [[add_flag(w) for w in l.split("\t")] for l in lines]input, target = zip(*word_pairs)self.create_vacab(input, self.vacab_path_in, self.config["vacab_size_in"])self.create_vacab(target, self.vacab_path_out, self.config["vacab_size_out"])def tokenize(self, path):# Load tokenizer from filewith open(path, "r", encoding="utf-8") as f:tokenize_config = json.dumps(json.load(f), ensure_ascii=False)tokenizer = tf.keras.preprocessing.text.tokenizer_from_json(tokenize_config)return tokenizerdef process(self):self.create_sequences()self.create_vacabularies()def load(self):self.process()  # Process dataset if not did beforeself.LOG.info("正在加载数据")lines = io.open(self.seq_path, encoding="UTF-8").readlines()word_pairs = [[add_flag(w) for w in l.split("\t")] for l in lines]words_in, words_out = zip(*word_pairs)tokenizer_in = self.tokenize(self.vacab_path_in)tokenizer_out = self.tokenize(self.vacab_path_out)tensor_in = tokenizer_in.texts_to_sequences(words_in)tensor_out = tokenizer_out.texts_to_sequences(words_out)tensor_in = tf.keras.preprocessing.sequence.pad_sequences(tensor_in, maxlen=self.max_length, truncating="post", padding="post")tensor_out = tf.keras.preprocessing.sequence.pad_sequences(tensor_out, maxlen=self.max_length, truncating="post", padding="post")self.steps_per_epoch = len(tensor_in) // self.batch_sizeBUFFER_SIZE = len(tensor_in)dataset = tf.data.Dataset.from_tensor_slices((tensor_in, tensor_out)).shuffle(BUFFER_SIZE)dataset = dataset.batch(self.batch_size, drop_remainder=True)return dataset, tokenizer_in, tokenizer_out

Seq2Seq 模型

seq2seq.py

# -*- coding: utf-8 -*-import logging
import osos.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # Disable Tensorflow debug messageimport jieba
import tensorflow as tf
from tqdm import tqdmfrom data_processing import Data, add_flagclass Encoder(tf.keras.Model):def __init__(self, vocab_size, embedding_dim, enc_units, batch_size):super(Encoder, self).__init__()self.enc_units = enc_unitsself.batch_size = batch_sizeself.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)self.gru = tf.keras.layers.GRU(self.enc_units, return_sequences=True, return_state=True)def call(self, X):X = self.embedding(X)output, state = self.gru(X)return output, stateclass Decoder(tf.keras.Model):def __init__(self, vocab_size, embedding_dim, dec_units, batch_size):super(Decoder, self).__init__()self.batch_size = batch_sizeself.dec_units = dec_unitsself.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)self.gru = tf.keras.layers.GRU(self.dec_units, return_sequences=True, return_state=True)self.fc = tf.keras.layers.Dense(vocab_size)def call(self, X, state, **kwargs):X = self.embedding(X)context = tf.reshape(tf.repeat(state, repeats=X.shape[1], axis=0), (X.shape[0], X.shape[1], -1))X_and_context = tf.concat((X, context), axis=2)output, state = self.gru(X_and_context)output = tf.reshape(output, (-1, output.shape[2]))X = self.fc(output)return X, statedef initialize_hidden_state(self):return tf.zeros((self.batch_size, self.dec_units))class Seq2Seq(object):def __init__(self, config):self.config = configvacab_size_in = config["vacab_size_in"]vacab_size_out = config["vacab_size_out"]embedding_dim = config["embedding_dim"]self.units = config["layer_size"]self.batch_size = config["batch_size"]self.encoder = Encoder(vacab_size_in, embedding_dim, self.units, self.batch_size)self.decoder = Decoder(vacab_size_out, embedding_dim, self.units, self.batch_size)self.optimizer = tf.keras.optimizers.Adam()self.checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, encoder=self.encoder, decoder=self.decoder)self.ckpt_dir = self.config["model_data"]logging.basicConfig(level=logging.INFO)self.LOG = logging.getLogger("Seq2Seq")if tf.io.gfile.listdir(self.ckpt_dir):self.LOG.info("正在加载模型")self.checkpoint.restore(tf.train.latest_checkpoint(self.ckpt_dir))data = Data(config)self.dataset, self.tokenizer_in, self.tokenizer_out = data.load()def loss_function(self, real, pred):loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)mask = tf.math.logical_not(tf.math.equal(real, 0))loss_ = loss_object(real, pred)mask = tf.cast(mask, dtype=loss_.dtype)loss_ *= maskreturn tf.reduce_mean(loss_)@tf.functiondef training_step(self, src, tgt, tgt_lang):loss = 0with tf.GradientTape() as tape:enc_output, enc_hidden = self.encoder(src)dec_hidden = enc_hiddendec_input = tf.expand_dims([tgt_lang.word_index["bos"]] * self.batch_size, 1)for t in range(1, tgt.shape[1]):predictions, dec_hidden = self.decoder(dec_input, dec_hidden)loss += self.loss_function(tgt[:, t], predictions)dec_input = tf.expand_dims(tgt[:, t], 1)step_loss = (loss / int(tgt.shape[1]))variables = self.encoder.trainable_variables + self.decoder.trainable_variablesgradients = tape.gradient(loss, variables)self.optimizer.apply_gradients(zip(gradients, variables))return step_lossdef train(self):writer = tf.summary.create_file_writer(self.config["log_dir"])self.LOG.info(f"数据目录: {self.config['data_path']}")epoch = 0train_epoch = self.config["epochs"]while epoch < train_epoch:total_loss = 0iter_data = tqdm(self.dataset)for batch, (src, tgt) in enumerate(iter_data):batch_loss = self.training_step(src, tgt, self.tokenizer_out)total_loss += batch_lossiter_data.set_postfix_str(f"batch_loss: {batch_loss:.4f}")self.checkpoint.save(file_prefix=os.path.join(self.ckpt_dir, "ckpt"))epoch = epoch + 1self.LOG.info(f"Epoch: {epoch}/{train_epoch} Loss: {total_loss:.4f}")with writer.as_default():tf.summary.scalar("loss", total_loss, step=epoch)def predict(self, sentence):max_length = self.config["max_length"]sentence = " ".join(jieba.cut(sentence))sentence = add_flag(sentence)inputs = self.tokenizer_in.texts_to_sequences([sentence])inputs = [[x for x in inputs[0] if x if not None]]  # Remove None. TODO: Why there're None???inputs = tf.keras.preprocessing.sequence.pad_sequences(inputs, maxlen=max_length, padding="post")inputs = tf.convert_to_tensor(inputs)enc_out, enc_hidden = self.encoder(inputs)dec_hidden = enc_hiddendec_input = tf.expand_dims([self.tokenizer_out.word_index["bos"]], 0)result = ""for _ in range(max_length):predictions, dec_hidden = self.decoder(dec_input, dec_hidden)predicted_id = tf.argmax(predictions[0]).numpy()if self.tokenizer_out.index_word[predicted_id] == "eos":breakresult += str(self.tokenizer_out.index_word[predicted_id])dec_input = tf.expand_dims([predicted_id], 0)return result

程序入口

main.py

#! /usr/bin/env python3
# -*- coding: utf-8 -*-import argparsefrom seq2seq import Seq2Seqif __name__ == "__main__":import argparseparser = argparse.ArgumentParser()parser.add_argument("--mode", "-m", type=str, default="serve", help="Train or Serve")parser.add_argument("--data_path", "-p", type=str, default="dataset/", help="dataset path")parser.add_argument("--log_dir", type=str, default="logs/", help="dataset path")parser.add_argument("--model_data", type=str, default="model_data", help="mode output path")parser.add_argument("--dataset", "-n", type=str, default="xiaohuangji50w", help="Train or Serve")parser.add_argument("--e", type=str, default="E", help="start flag of conversation")parser.add_argument("--m", type=str, default="M", help="start flag of conversation")parser.add_argument("--vacab_size_in", "-i", type=int, default=20000, help="vacabulary input size")parser.add_argument("--vacab_size_out", "-o", type=int, default=20000, help="vacabulary output size")parser.add_argument("--layer_size", type=int, default=256, help="layer size")parser.add_argument("--batch_size", type=int, default=128, help="batch size")parser.add_argument("--layers", type=int, default=2, help="layers")parser.add_argument("--embedding_dim", type=int, default=64, help="embedding dimention")parser.add_argument("--epochs", type=int, default=10, help="epochs")parser.add_argument("--max_length", type=int, default=32, help="max length of input")args, _ = parser.parse_known_args()config = vars(args)seq2seq = Seq2Seq(config)if args.mode.lower() == "train":seq2seq.train()else:while True:msg = input(">>> ")rsp = seq2seq.predict(msg)print(rsp)

效果

训练

python main.py -m train

闲聊

python main.py

训练了 20 个 epochs,就这水平:
在这里插入图片描述

提升方向

上面的结果比较一般,一方面因为语料不多;另一方面,现在这模型也比较粗糙,后续可以从以下几方面进行提升:

  • 使用更多的语料
  • 使用 Bi-LSTM 构造 Encoder
  • 引入 Attention

http://chatgpt.dhexx.cn/article/hkQZ8NGf.shtml

相关文章

关于对话机器人,你需要了解这些技术

对话系统(对话机器人)本质上是通过机器学习和人工智能等技术让机器理解人的语言。它包含了诸多学科方法的融合使用,是人工智能领域的一个技术集中演练营。图1给出了对话系统开发中涉及到的主要技术。 对话系统技能进阶之路 图1给出的诸多对话系统相关技术,从哪些渠道可以…

对话机器人(一)——对话机器人基础知识

对话机器人基础 一、对话机器人分类 1. 知识领域 a. 面向限定领域 只能聊设定好的固定主题。若用户用无关领域挑战机器人&#xff0c;机器人用安全话术回复或结束对话。 b. 面向开放领域 用户不需要有明确的目的或意图。 c. 面向常用问题集 通过检索知识库来回答问题&a…

Typora无法在applist里找到

添加一个desktop文件即可&#xff0c;记得加上%U才能在应用列表里看见

php安装失败,phpcms安装失败怎么办

phpcms安装失败怎么办&#xff1f; 最新版的phpcmsV9安装报错解决 具体报错信息如下&#xff1a;Web-server: Apache PHP版本: PHP/5.2.14 Mysql版本: MySQL 客户端版本: 5.0.90 适用版本: v9 更新日期: phpcms_v9.2.2_UTF8 编码版本: UTF-8 浏览器: maxthon 复现步骤: 正在准备…

android应用程序列表,List列表应用程序-小知识 #103

文章摘要&#xff1a; 1、从设计模式的角度浅谈List列表应用程序开发。 2、列表应用程序开发三要素。控件、数据、适配器。 一、综述&#xff1a; 1、Android中&#xff0c;使用ListView配合Adapter来展示数据列表的例子随处可见。但在实际应用场景中&#xff0c;数据源类型、V…

推荐系统中的Embedding应用

文章目录 1. Word2Vec1.1 Skip-gram 2. Airbnb中的Embedding2.1 用在相似推荐中的List Embedding2.1.1 优化一&#xff1a;Booked Listing as Global Context2.1.2 优化二&#xff1a;Adapting Training for Congregated Search2.1.3 冷启动问题2.1.4 效果评估 2.2 用在搜索推荐…

APP设备数据的特征衍生与模型应用

在信贷风控领域众多维度的数据源中&#xff0c;APP设备数据对于策略规则的开发、模型变量的筛选有着重要的贡献&#xff0c;理由是在当今电子信息化时代&#xff0c;APP数据可以较全面地反映出用户的个人习惯、日常行为等综合信息。因此&#xff0c;金融机构在开展个人信贷产品…

风控建模十二:数据淘金——如何从APP数据中挖掘出有效变量

风控建模十二&#xff1a;数据淘金——如何从APP数据中挖掘出有效变量 1、常识知识2、个例分析3、分布排查 智能手机的诞生改变了人类的生活方式&#xff0c;智能手机所承载的功能日臻完善、强大&#xff0c;人们在衣、食、住、行、工作、生活中面临的方方面面问题&#xff0c;…

2021-03-07 大数据课程笔记 day46

R星校长 机器学习06【机器学习】 主要内容 理解推荐系统处理数据流程。python 文件预处理 Hive 数据。dubbo 服务使用。 学习目标 第一节 推荐系统-数据处理流程 推荐系统数据处理首先是将 Hive 中的用户 app 历史下载表与 app 浏览信息表按照设备 id 进行关联&#xff0c…

java手机应用安装目录_如何获得Android手机的软件安装列表

Android的PackageManager类用于检索目前安装在设备上的应用软件包的信息。你可以通过调用getpackagemanager()得到PackageManager类的一个实例。对查询和操作安装包和相关的权限提供了方法&#xff0c;在下面这个Android的例子中&#xff0c;我们得到了在Android安装的应用程序…

这些信贷数据埋点中不得不知的埋点知识

国庆七天假&#xff0c;就这样飞快结束&#xff0c;似乎感觉还没休息够&#xff0c;再来一个七天都不觉得多多。 经过多年来移动互联网的普及&#xff0c;众多APP已采集到亿级乃至数十亿级别用户在设备端、通话、短信、地址等强变量的数据&#xff0c;伴随着近年来信贷行业高速…

新浪微博开发(五)AppList界面

这是客户端开发部分很重要的一个类&#xff0c;但是在开发之前需要用到有关GridView的知识。 若要临时充充电&#xff0c;请移步&#xff1a;GridView(九宫图)的使用介绍。 下面是AppList类的代码&#xff1a; /* * 用来显示、管理自己的微博账号&#xff0c;包括新浪微博账号…

使用react+redux+react-redux+react-router+axios+scss技术栈从0到1开发一个applist应用

先看效果图 github地址 github仓库 在线访问 初始化项目 #创建项目 create-react-app applist #如果没有安装create-react-app的话&#xff0c;先安装 npm install -g create-react-app 目录结构改造 |--config |--node_modules |--public |--scripts |--src|-----api //…

【无标题】https://e-cloudstore.com/ec/api/applist/index.html#/

这里写自定义目录标题 欢迎使用Markdown编辑器新的改变功能快捷键合理的创建标题&#xff0c;有助于目录的生成如何改变文本的样式插入链接与图片如何插入一段漂亮的代码片生成一个适合你的列表创建一个表格设定内容居中、居左、居右SmartyPants 创建一个自定义列表如何创建一个…

AppList数据处理

本文参考: 风控数据—手机App数据挖掘实践思路 引言 作为移动互联网时代的主要载体,智能手机逐渐成为人们日常生活中不可或缺的一部分,改变着人们的生活习惯。比如,可以用“饿了么”点外卖,“支付宝”可以用来种树,“抖音”可以用来上厕所......强大的App给我们的生活带来…

Faststone capture注册码

转载整理 很好用的图片编辑软件&#xff01; 企业版序列号&#xff1a; name&#xff1a;bluman serial/序列号/注册码&#xff1a;VPISCJULXUFGDDXYAUYF FastStone Capture 注册码 序列号&#xff1a; name/用户名&#xff1a;TEAM JiOO key/注册码&#xff1a;CPCWXRVCZW30H…

FastStone Capture(超级强大的截图、屏幕录制软件)

FastStone Capture是一款体积极其小、功能强悍的屏幕捕捉软件&#xff0c;还有强大的图片编辑、视频录制编辑功能&#xff0c;能够完全满足你截屏、处理图片的要求。FastStone Capture &#xff08;FSCapture&#xff09; 是经典好用的屏幕截图软件&#xff0c;还具有图像编辑和…

截图工具FastStone Capture

文章目录 1 下载安装2 工具使用介绍2.1 截图2.1.1 截图2.1.2 滚动截图 2.2 录屏 FastStone Capture也常被简称为FS Capture&#xff0c;是一款小巧但强悍的软件&#xff0c;集截图、录屏、标尺、取色器等工具于一体。 1 下载安装 官方下载&#xff1a; https://faststone-capt…

FastStone Capture—视频绘制

博客概要 已经写了好些篇博文来介绍FSP了&#xff0c;广告费付一下&#xff1f; 在博主的某篇博文中&#xff0c;粗略介绍了FSP的“屏幕录像”功能&#xff0c;其中“录像编辑”一栏其实没有过多赘述&#xff0c;那么本篇就稍微&#xff0c;真的是稍微&#xff0c;再多介绍一些…

巨好用的截图录屏工具-FastStone Capture

哈喽&#xff0c;大家好呀。想问问小伙伴们平时若是不打开微信、QQ的时候一般是用什么工具截图呢&#xff0c;是微软自带的截图工具&#xff0c;还是其它的呢&#xff1f;今天给大家测试一款非常好用的截图录屏一体轻量级工具哦&#xff0c;也是小编用了好多年的一款工具了&…