Getting Started with NLP - A Classical Chinese Poem Generator Based on Word Embedding + LSTM


Three features are implemented in total:

1. Continuing a five-character poem from a given opening line

2. Continuing a seven-character poem from a given opening line

3. Writing a five-character acrostic poem
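All three tasks above go through a single dispatcher, the Writer.task method defined in the "Final model" section below, which picks the task by input length. A minimal usage sketch, assuming the trained model files model_54.pkl and model_74.pkl exist and a GPU is available:

writer = Writer()
print(writer.task("为尔心悦"))       # 4 characters -> five-character acrostic poem
print(writer.task("空山新雨后"))     # 5 characters -> continue a five-character poem
print(writer.task("锦瑟无端五十弦"))  # 7 characters -> continue a seven-character poem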

I already spent far too long on this as my final project for Intro to Computer Science and don't want to rehash it all here; for an overview of the content and the implementation approach, see the slides from the final presentation:

https://docs.google.com/presentation/d/1DFy3VwAETeqK0QFsokeBpDwyVkMavOjpckQKpc6XPTI/edit#slide=id.gb037c6e317_2_312

Source of the training data:

https://github.com/BraveY/AI-with-code/tree/master/Automatic-poem-writing

 

Data preprocessing for five-character poems (seven-character is analogous; code omitted):

import os
import numpy as np
from copy import deepcopy

all_data = np.load("tang.npz", allow_pickle=True)
dataset = all_data["data"]
word2ix = all_data["word2ix"].item()
ix2word = all_data["ix2word"].item()
print(len(word2ix))
print(dataset.shape)

# decode every index back to its character
l = dataset.tolist()
poems = [[None for i in range(125)] for j in range(57580)]
for i in range(57580):
    for j in range(125):
        poems[i][j] = ix2word[l[i][j]]

# keep only well-formed poems: exactly one <START>/<EOP> pair, 48 characters of
# content, and commas/periods in the expected positions for five-character lines
data = list()
for i in range(57580):
    s = 0
    e = 0
    for ele in poems[i]:
        if ele == '<START>':
            s += 1
        if ele == '<EOP>':
            e += 1
    if s == 1 and e == 1:
        st = poems[i].index('<START>')
        ed = poems[i].index('<EOP>')
        if (ed - st - 1) % 6 == 0 and poems[i][st + 6] == ',' and (ed - st - 1) == 48:
            # a five-character poem; split into quatrains of 4 lines (24 characters) each
            line = poems[i][st + 1:ed]
            for j in range(0, len(line), 24):
                cur = line[j:j + 24]
                if cur[5] == ',' and cur[11] == '。' and cur[17] == ',' and cur[23] == '。':
                    data.append(cur)
# for ele in data:
#     print(ele)
print(len(data))

# strip the punctuation, keeping the 20 characters of each quatrain as token ids
t = list()
for i, line in enumerate(data):
    print(i, line)
    words = line[0:5] + line[6:11] + line[12:17] + line[18:23]
    nums = [word2ix[words[i]] for i in range(len(words))]
    t.append(nums)
t = np.array(t)
print(t.shape)

# drop the first 77 samples so the remaining 29696 rows divide evenly
# into batches of 128 (29696 = 232 * 128)
t = t[77:]
labels = deepcopy(t)
# labels are the inputs shifted left by one (next-character prediction),
# with 0 padding the final position
for i in range(29696):
    for j in range(20):
        if j < 19:
            labels[i][j] = labels[i][j + 1]
        else:
            labels[i][j] = 0
np.save("train_x.npy", t)
np.save("train_y.npy", labels)


Training for five-character poems (seven-character is analogous; code omitted):

# -*- coding: utf-8 -*-
"""54.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1ZdPq71-40K5tGK__OPcUe84mfopWwJP-
"""

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from copy import deepcopy

VOCAB_SIZE = 8293

all_data = np.load("/content/drive/My Drive/Deep Learning/LSTM - Chinese Poem Writer/tang.npz", allow_pickle=True)
word2ix = all_data["word2ix"].item()
ix2word = all_data["ix2word"].item()


class Model(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(Model, self).__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, num_layers=3)
        self.out = nn.Linear(self.hidden_dim, vocab_size)

    def forward(self, input, hidden=None):
        seq_len, batch_size = input.shape
        if hidden is None:
            h_0 = torch.zeros(3, batch_size, self.hidden_dim).cuda()
            c_0 = torch.zeros(3, batch_size, self.hidden_dim).cuda()
        else:
            h_0, c_0 = hidden
        embed = self.embedding(input)
        # embed_size = (seq_len, batch_size, embedding_dim)
        output, hidden = self.lstm(embed, (h_0, c_0))
        # output_size = (seq_len, batch_size, hidden_dim)
        output = output.reshape(seq_len * batch_size, -1)
        # output_size = (seq_len * batch_size, hidden_dim)
        output = self.out(output)
        # output_size = (seq_len * batch_size, vocab_size)
        return output, hidden


train_x = np.load("/content/drive/My Drive/Deep Learning/LSTM - Chinese Poem Writer/train_x_54.npy")
train_y = np.load("/content/drive/My Drive/Deep Learning/LSTM - Chinese Poem Writer/train_y_54.npy")

# sanity check: decode one training pair back to characters
id = 253
print(train_x.shape)
print(train_y.shape)
print(train_x[id])
print(train_y[id])
a = [None for i in range(20)]
b = [None for j in range(20)]
for j in range(20):
    a[j] = ix2word[train_x[id][j]]
    b[j] = ix2word[train_y[id][j]]
b[19] = 'Null'
print(a)
print(b)


class PoemWriter():
    def __init__(self):
        self.model = Model(VOCAB_SIZE, 64, 128)
        self.lr = 1e-3
        self.epochs = 0
        self.seq_len = 20
        self.batch_size = 128
        self.opt = optim.Adam(self.model.parameters(), lr=self.lr)

    def train(self, epochs):
        self.epochs = epochs
        self.model = self.model.cuda()
        criterion = nn.CrossEntropyLoss()
        all_losses = []
        for epoch in range(self.epochs):
            print("Epoch:", epoch + 1)
            total_loss = 0
            for i in range(0, train_x.shape[0], self.batch_size):
                # print(i, i + self.batch_size)
                cur_x = torch.from_numpy(train_x[i:i + self.batch_size])
                cur_y = torch.from_numpy(train_y[i:i + self.batch_size])
                cur_x = torch.transpose(cur_x, 0, 1).long()
                cur_y = torch.transpose(cur_y, 0, 1).long()
                # flatten targets to (seq_len * batch_size,) to match the logits
                cur_y = cur_y.reshape(self.seq_len * self.batch_size, -1).squeeze(1)
                cur_x, cur_y = cur_x.cuda(), cur_y.cuda()
                pred, _ = self.model.forward(cur_x)
                loss = criterion(pred, cur_y)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                total_loss += loss.item()
            print("Loss:", total_loss)
            all_losses.append(total_loss)
        self.model = self.model.cpu()
        plt.plot(all_losses, 'r')

    def write(self, string="空山新雨后"):
        inp = []
        for c in string:
            inp.append(word2ix[c])
        inp = torch.from_numpy(np.array(inp)).unsqueeze(1).long()
        # print(inp.shape, inp)
        tmp = torch.zeros(15, 1).long()
        inp = torch.cat([inp, tmp], dim=0)
        inp = inp.cuda()
        self.model = self.model.cuda()
        # greedy decoding: the prediction at position tim + 4 becomes the
        # character at position tim + 5
        for tim in range(15):
            pred, _ = self.model.forward(inp)
            pred = torch.argmax(pred, dim=1)
            inp[tim + 5] = pred[tim + 4]
        ans = list()
        for i in range(20):
            ans.append(ix2word[inp[i].item()])
        out = ""
        for i in range(20):
            out += ans[i]
            if i == 4 or i == 14:
                out += ','
            if i == 9 or i == 19:
                out += '。'
        print(out)


torch.cuda.get_device_name()

Waner = PoemWriter()
Waner.train(1000)
Waner.write()
torch.save(Waner.model, '/content/drive/My Drive/Deep Learning/LSTM - Chinese Poem Writer/model_54.pkl')

# interactive loop: type an opening line, or -1 to quit
while True:
    st = input()
    if st == "-1":
        break
    Waner.write(st)
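The easiest part to get wrong here is the reshape bookkeeping: the model returns logits of shape (seq_len * batch_size, vocab_size) and the targets are flattened to (seq_len * batch_size,), exactly the pair nn.CrossEntropyLoss expects. A minimal CPU-only shape check, reusing the Model class and VOCAB_SIZE from the script above and passing the hidden state explicitly to avoid the hard-coded .cuda() calls in forward:

m = Model(VOCAB_SIZE, 64, 128)
x = torch.randint(0, VOCAB_SIZE, (20, 128))   # (seq_len, batch_size) token ids
h0 = torch.zeros(3, 128, 128)                 # (num_layers, batch_size, hidden_dim)
c0 = torch.zeros(3, 128, 128)
logits, _ = m(x, (h0, c0))
print(logits.shape)                           # torch.Size([2560, 8293]) = (seq_len * batch_size, vocab_size)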

Final model:

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from copy import deepcopy

VOCAB_SIZE = 8293

all_data = np.load("tang.npz", allow_pickle=True)
word2ix = all_data["word2ix"].item()
ix2word = all_data["ix2word"].item()


class Model(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(Model, self).__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, num_layers=3)
        self.out = nn.Linear(self.hidden_dim, vocab_size)

    def forward(self, input, hidden=None):
        seq_len, batch_size = input.shape
        if hidden is None:
            h_0 = torch.zeros(3, batch_size, self.hidden_dim).cuda()
            c_0 = torch.zeros(3, batch_size, self.hidden_dim).cuda()
        else:
            h_0, c_0 = hidden
        embed = self.embedding(input)
        # embed_size = (seq_len, batch_size, embedding_dim)
        output, hidden = self.lstm(embed, (h_0, c_0))
        # output_size = (seq_len, batch_size, hidden_dim)
        output = output.reshape(seq_len * batch_size, -1)
        # output_size = (seq_len * batch_size, hidden_dim)
        output = self.out(output)
        # output_size = (seq_len * batch_size, vocab_size)
        return output, hidden


class Writer():
    def __init__(self):
        self.model_74 = torch.load('model_74.pkl')
        self.model_54 = torch.load('model_54.pkl')
        # self.lr = 1e-3
        self.epochs = 0
        self.seq_len = 28
        self.batch_size = 128
        # self.opt = optim.Adam(self.model.parameters(), lr=self.lr)

    def write_74(self, string="锦瑟无端五十弦"):
        inp = []
        for c in string:
            inp.append(word2ix[c])
        inp = torch.from_numpy(np.array(inp)).unsqueeze(1).long()
        # print(inp.shape, inp)
        tmp = torch.zeros(21, 1).long()
        inp = torch.cat([inp, tmp], dim=0)
        inp = inp.cuda()
        self.model_74 = self.model_74.cuda()
        # greedy decoding, as in training: fill in one character per step
        for tim in range(21):
            pred, _ = self.model_74.forward(inp)
            pred = torch.argmax(pred, dim=1)
            inp[tim + 7] = pred[tim + 6]
        ans = list()
        for i in range(28):
            ans.append(ix2word[inp[i].item()])
        out = ""
        for i in range(28):
            out += ans[i]
            if i == 6 or i == 20:
                out += ','
            if i == 13 or i == 27:
                out += '。'
        return out

    def write_54(self, string="空山新雨后"):
        inp = []
        for c in string:
            inp.append(word2ix[c])
        inp = torch.from_numpy(np.array(inp)).unsqueeze(1).long()
        # print(inp.shape, inp)
        tmp = torch.zeros(15, 1).long()
        inp = torch.cat([inp, tmp], dim=0)
        inp = inp.cuda()
        self.model_54 = self.model_54.cuda()
        for tim in range(15):
            pred, _ = self.model_54.forward(inp)
            pred = torch.argmax(pred, dim=1)
            inp[tim + 5] = pred[tim + 4]
        ans = list()
        for i in range(20):
            ans.append(ix2word[inp[i].item()])
        out = ""
        for i in range(20):
            out += ans[i]
            if i == 4 or i == 14:
                out += ','
            if i == 9 or i == 19:
                out += '。'
        return out

    def acrostic(self, string="为尔心悦"):
        inp = torch.zeros(20, 1).long()
        inp = inp.cuda()
        self.model_54 = self.model_54.cuda()
        for i in range(20):
            if i == 0 or i == 5 or i == 10 or i == 15:
                # positions 0, 5, 10, 15 are the line heads, fixed to the given characters
                inp[i] = word2ix[string[i // 5]]
            else:
                inp[i] = pred[i - 1]
            pred, _ = self.model_54.forward(inp)
            pred = torch.argmax(pred, dim=1)
        ans = list()
        for i in range(20):
            ans.append(ix2word[inp[i].item()])
        out = ""
        for i in range(20):
            out += ans[i]
            if i == 4 or i == 14:
                out += ','
            if i == 9 or i == 19:
                out += '。'
        return out

    def task(self, string):
        # dispatch on input length: 4 -> acrostic, 5 -> five-character, 7 -> seven-character
        l = len(string)
        try:
            if l == 4:
                return self.acrostic(string)
            elif l == 5:
                return self.write_54(string)
            elif l == 7:
                return self.write_74(string)
        except:
            return "I don't know how to write one...QAQ"


'''
a = torch.ones(5)
a[1] = 4
print(a)
'''
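One caveat with saving via torch.save(Waner.model, ...): it pickles the entire module, so loading only works if the Model class definition is in scope (which is why the class is repeated above). A commonly recommended alternative, sketched here with the same file name, is to save and restore just the weights:

# in the training script, save only the parameters:
torch.save(Waner.model.state_dict(), 'model_54.pkl')

# here, rebuild the architecture and load the weights into it:
model_54 = Model(VOCAB_SIZE, 64, 128)
model_54.load_state_dict(torch.load('model_54.pkl'))
model_54.eval()   # disable training-mode behavior before generation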

Test examples:

 

Rhyme is not handled explicitly, so the end rhymes are mostly off, but the model does pick up some classical imagery and can evoke a certain mood. A better word embedding might yield better performance; the current embedding is randomly initialized by PyTorch and learned from scratch during training.
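For reference, swapping in a pretrained embedding would only require copying the vectors into the existing nn.Embedding layer before training. A hedged sketch, where pretrained_embeddings.npy is a hypothetical (VOCAB_SIZE, 64) matrix whose rows are aligned with word2ix:

pretrained = np.load("pretrained_embeddings.npy")   # hypothetical file, shape (VOCAB_SIZE, 64)
model = Model(VOCAB_SIZE, 64, 128)
model.embedding.weight.data.copy_(torch.from_numpy(pretrained).float())
# optionally freeze the embeddings so training only updates the LSTM and output layer:
# model.embedding.weight.requires_grad = False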

