使用Keras 构建基于 LSTM 模型的故事生成器(二)

简介: 使用Keras 构建基于 LSTM 模型的故事生成器(二)

Step2:导入数据分析库并进行分析

接下来,我们导入必要的库并且查看数据集。使用的是运行在 TensorFlow 2.0 的 Keras 框架。

from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.layers import Embedding, LSTM, Dense, Dropout, Bidirectional
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import regularizers
import tensorflow.keras.utils as ku
import numpy as np
import tensorflow as tf
import pickle
data=open('stories.txt',encoding="utf8").read()

Step3:使用 NLP 库预处理数据

首先,我们将数据全部转换为小写,并将其按行拆分,以获得一个python语句列表。转换成小写的原因是,同一单词不同大小写,其意义是一样的。例如,“Doctor”和“doctor”都是医生,但模型会对其进行不同的处理。

然后我们将单词进行编码并转化为向量。为每一个单词生成索引属性,该属性返回一个包含键值对的字典,其中键是单词,值是该单词的记号。

# Converting the text to lowercase and splitting it
corpus = data.lower().split("\n")
# Tokenization
tokenizer = Tokenizer()
tokenizer.fit_on_texts(corpus)
total_words = len(tokenizer.word_index) + 1
print(total_words)

下一步将把句子转换成基于这些标记索引的值列表。这将把一行文本(如“frozen grass crunched beneath the steps”)转换成表示单词对应的标记列表。

image.png

然后我们将遍历标记列表,并且使每个句子的长度一致,否则,用它们训练神经网络可能会很困难。主要在于遍历所有序列并找到最长的一个。一旦我们有了最长的序列长度,接下来要做的是填充所有序列,使它们的长度相同。

image.png

同时,我们需要将划分输入数据(特征)以及输出数据(标签)。其中,输入数据就是除最后一个字符外的所有数据,而输出数据则是最后一个字符。

image.png

现在,我们将对标签进行 One-hot 编码,因为这实际上是一个分类问题,在给定一个单词序列的情况下,我们可以从语料库中对下一个单词进行分类预测。

# create input sequences using list of tokens
input_sequences = []
for line in corpus:
   token_list = tokenizer.texts_to_sequences([line])[0]
   for i in range(1, len(token_list)):
       n_gram_sequence = token_list[:i+1]
       input_sequences.append(n_gram_sequence)
# pad sequences
max_sequence_len = max([len(x) for x in input_sequences])
print(max_sequence_len)
input_sequences = np.array(pad_sequences(input_sequences, maxlen=max_sequence_len, padding='pre'))
# create predictors and label
predictors, label = input_sequences[:,:-1],input_sequences[:,-1]
label = ku.to_categorical(label, num_classes=total_words)

Step 4:搭建模型

有了训练数据集后,我们就可以搭建需要的模型了:

model = Sequential()
model.add(Embedding(total_words, 300, input_length=max_sequence_len-1))
model.add(Bidirectional(LSTM(200, return_sequences = True)))
model.add(Dropout(0.2))
model.add(LSTM(100))
model.add(Dense(total_words/2, activation='relu', kernel_regularizer=regularizers.l2(0.001)))
model.add(Dense(total_words, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
print(model.summary())


history = model.fit(predictors, label, epochs=200, verbose=0)

其中,第一层是 embedding 层。第一个参数反映模型处理的单词数量,这里我们希望能够处理所有单词,所以赋值 total_words;第二个参数反映用于绘制单词向量的维数,可以随意调整,会获得不同的预测结果;第三个参数反映输入的序列长度,因为输入序列是原始序列中除最后一个字符外的所有数据,所以这里需要减去一。随后是 bidirectional LSTM 层以及 Dense 层。对于损失函数,我们设置为分类交叉熵;优化函数,我们选择 adam 算法。

Step 5:结果分析

对于训练后的效果,我们主要查看准确度和损失大小。

import matplotlib.pyplot as plt
acc = history.history['accuracy']
loss = history.history['loss']
epochs = range(len(acc))
plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.title('Training accuracy')
plt.figure()
plt.plot(epochs, loss, 'b', label='Training Loss')
plt.title('Training loss')
plt.legend()
plt.show()

ddd.png


从曲线图可以看出,训练准确率不断提高,而损失则不断衰减。说明模型达到较好的性能。

Step 6:保存模型

通过以下代码可以对训练完成的模型进行保存,以方便进一步的部署。

# serialize model to JSON
model_json=model.to_json()
with open("model.json","w") as json_file:
json_file.write(model_json)
# serialize weights to HDF5
model.save_weights("model.h5")
print("Saved model to disk")

Step 7:进行预测

接下来,将应用训练好的模型进行单词预测以及生成故事。首先,用户输入初始语句,然后将该语句进行预处理,输入到 LSTM 模型中,得到对应的一个预测单词。重复这一过程,便能够生成对应的故事了。具体代码如下:

seed_text = "As i walked, my heart sank"
next_words = 100
for _ in range(next_words):
token_list = tokenizer.texts_to_sequences([seed_text])[0]
token_list = pad_sequences([token_list], maxlen=max_sequence_len-1, padding='pre')
predicted = model.predict_classes(token_list, verbose=0)
output_word = ""
for word, index in tokenizer.word_index.items():
if index == predicted:
output_word = word
break
seed_text += " " + output_word
print(seed_text)

生成故事如下:

As i walked, my heart sank until he was alarmed by the voice of the hunterand realised what could have happened with him he flew away the boy crunchedbefore it disguised herself as another effort to pull out the bush which he didthe next was a small tree which the child had to struggle a lot to pull outfinally the old man showed him a bigger tree and asked the child to pull itout the boy did so with ease and they walked on the morning she was askedhow she had slept as a while they came back with me

目录
相关文章
|
15天前
|
机器学习/深度学习 存储 人工智能
算法金 | LSTM 原作者带队,一个强大的算法模型杀回来了
**摘要:** 本文介绍了LSTM(长短期记忆网络)的发展背景和重要性,以及其创始人Sepp Hochreiter新推出的xLSTM。LSTM是为解决传统RNN长期依赖问题而设计的,广泛应用于NLP和时间序列预测。文章详细阐述了LSTM的基本概念、核心原理、实现方法和实际应用案例,包括文本生成和时间序列预测。此外,还讨论了LSTM与Transformer的竞争格局。最后,鼓励读者深入学习和探索AI领域。
25 7
算法金 | LSTM 原作者带队,一个强大的算法模型杀回来了
|
1天前
|
机器学习/深度学习 PyTorch 算法框架/工具
RNN、LSTM、GRU神经网络构建人名分类器(三)
这个文本描述了一个使用RNN(循环神经网络)、LSTM(长短期记忆网络)和GRU(门控循环单元)构建的人名分类器的案例。案例的主要目的是通过输入一个人名来预测它最可能属于哪个国家。这个任务在国际化的公司中很重要,因为可以自动为用户注册时提供相应的国家或地区选项。
|
1天前
|
机器学习/深度学习 数据采集
RNN、LSTM、GRU神经网络构建人名分类器(一)
这个文本描述了一个使用RNN(循环神经网络)、LSTM(长短期记忆网络)和GRU(门控循环单元)构建的人名分类器的案例。案例的主要目的是通过输入一个人名来预测它最可能属于哪个国家。这个任务在国际化的公司中很重要,因为可以自动为用户注册时提供相应的国家或地区选项。
|
5天前
|
机器学习/深度学习 自然语言处理 PyTorch
【自然语言处理NLP】Bert预训练模型、Bert上搭建CNN、LSTM模型的输入、输出详解
【自然语言处理NLP】Bert预训练模型、Bert上搭建CNN、LSTM模型的输入、输出详解
23 0
|
30天前
|
机器学习/深度学习 算法
【MATLAB】基于VMD-SSA-LSTM的回归预测模型
【MATLAB】基于VMD-SSA-LSTM的回归预测模型
41 4
|
1月前
|
机器学习/深度学习 算法
【MATLAB】基于EMD-PCA-LSTM的回归预测模型
【MATLAB】基于EMD-PCA-LSTM的回归预测模型
35 0
【MATLAB】基于EMD-PCA-LSTM的回归预测模型
|
1月前
|
机器学习/深度学习 数据可视化 TensorFlow
【视频】LSTM模型原理及其进行股票收盘价的时间序列预测讲解|附数据代码1
【视频】LSTM模型原理及其进行股票收盘价的时间序列预测讲解|附数据代码
|
1月前
|
机器学习/深度学习 存储 数据可视化
【视频】LSTM模型原理及其进行股票收盘价的时间序列预测讲解|附数据代码2
【视频】LSTM模型原理及其进行股票收盘价的时间序列预测讲解|附数据代码
|
1月前
|
机器学习/深度学习 自然语言处理 数据可视化
数据代码分享|PYTHON用NLP自然语言处理LSTM神经网络TWITTER推特灾难文本数据、词云可视化
数据代码分享|PYTHON用NLP自然语言处理LSTM神经网络TWITTER推特灾难文本数据、词云可视化
|
4天前
|
机器学习/深度学习 算法 数据可视化
m基于PSO-LSTM粒子群优化长短记忆网络的电力负荷数据预测算法matlab仿真
在MATLAB 2022a中,应用PSO优化的LSTM模型提升了电力负荷预测效果。优化前预测波动大,优化后预测更稳定。PSO借鉴群体智能,寻找LSTM超参数(如学习率、隐藏层大小)的最优组合,以最小化误差。LSTM通过门控机制处理序列数据。代码显示了模型训练、预测及误差可视化过程。经过优化,模型性能得到改善。
19 6

热门文章

最新文章