TensorFlow教程之完整教程 2.8 递归神经网络

简介:

本文档为TensorFlow参考文档,本转载已得到TensorFlow中文社区授权。


循环神经网络

介绍

可以在 this great article 查看循环神经网络(RNN)以及 LSTM 的介绍。

语言模型

此教程将展示如何在高难度的语言模型中训练循环神经网络。该问题的目标是获得一个能确定语句概率的概率模型。为了做到这一点,通过之前已经给出的词语来预测后面的词语。我们将使用 PTB(Penn Tree Bank) 数据集,这是一种常用来衡量模型的基准,同时它比较小而且训练起来相对快速。

语言模型是很多有趣难题的关键所在,比如语音识别,机器翻译,图像字幕等。它很有意思--可以参看 here

本教程的目的是重现 Zaremba et al., 2014 的成果,他们在 PTB 数据集上得到了很棒的结果。

教程文件

本教程使用的下面文件的目录是 models/rnn/ptb:

文件 作用
ptb_word_lm.py 在 PTB 数据集上训练一个语言模型.
reader.py 读取数据集.

下载及准备数据

本教程需要的数据在 data/ 路径下,来源于 Tomas Mikolov 网站上的 PTB 数据集http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz

该数据集已经预先处理过并且包含了全部的 10000 个不同的词语,其中包括语句结束标记符,以及标记稀有词语的特殊符号 (<unk>) 。我们在 reader.py 中转换所有的词语,让他们各自有唯一的整型标识符,便于神经网络处理。

模型

LSTM

模型的核心由一个 LSTM 单元组成,其可以在某时刻处理一个词语,以及计算语句可能的延续性的概率。网络的存储状态由一个零矢量初始化并在读取每一个词语后更新。而且,由于计算上的原因,我们将以 batch_size 为最小批量来处理数据。

基础的伪代码就像下面这样:

lstm = rnn_cell.BasicLSTMCell(lstm_size)
# 初始化 LSTM 存储状态.
state = tf.zeros([batch_size, lstm.state_size])

loss = 0.0
for current_batch_of_words in words_in_dataset:
    # 每次处理一批词语后更新状态值.
    output, state = lstm(current_batch_of_words, state)

    # LSTM 输出可用于产生下一个词语的预测
    logits = tf.matmul(output, softmax_w) + softmax_b
    probabilities = tf.nn.softmax(logits)
    loss += loss_function(probabilities, target_words)

截断反向传播

为使学习过程易于处理,通常的做法是将反向传播的梯度在(按时间)展开的步骤上照一个固定长度(num_steps)截断。 通过在一次迭代中的每个时刻上提供长度为 num_steps 的输入和每次迭代完成之后反向传导,这会很容易实现。

一个简化版的用于计算图创建的截断反向传播代码:

# 一次给定的迭代中的输入占位符.
words = tf.placeholder(tf.int32, [batch_size, num_steps])

lstm = rnn_cell.BasicLSTMCell(lstm_size)
# 初始化 LSTM 存储状态.
initial_state = state = tf.zeros([batch_size, lstm.state_size])

for i in range(len(num_steps)):
    # 每处理一批词语后更新状态值.
    output, state = lstm(words[:, i], state)

    # 其余的代码.
    # ...

final_state = state

下面展现如何实现迭代整个数据集:

# 一个 numpy 数组,保存每一批词语之后的 LSTM 状态.
numpy_state = initial_state.eval()
total_loss = 0.0
for current_batch_of_words in words_in_dataset:
    numpy_state, current_loss = session.run([final_state, loss],
        # 通过上一次迭代结果初始化 LSTM 状态.
        feed_dict={initial_state: numpy_state, words: current_batch_of_words})
    total_loss += current_loss

输入

在输入 LSTM 前,词语 ID 被嵌入到了一个密集的表示中(查看 矢量表示教程)。这种方式允许模型高效地表示词语,也便于写代码:

# embedding_matrix 张量的形状是: [vocabulary_size, embedding_size]
word_embeddings = tf.nn.embedding_lookup(embedding_matrix, word_ids)

嵌入的矩阵会被随机地初始化,模型会学会通过数据分辨不同词语的意思。

损失函数

我们想使目标词语的平均负对数概率最小

实现起来并非很难,而且函数 sequence_loss_by_example 已经有了,可以直接使用。

论文中的典型衡量标准是每个词语的平均困惑度(perplexity),计算式为

同时我们会观察训练过程中的困惑度值(perplexity)。

多个 LSTM 层堆叠

要想给模型更强的表达能力,可以添加多层 LSTM 来处理数据。第一层的输出作为第二层的输入,以此类推。

类 MultiRNNCell 可以无缝的将其实现:

lstm = rnn_cell.BasicLSTMCell(lstm_size)
stacked_lstm = rnn_cell.MultiRNNCell([lstm] * number_of_layers)

initial_state = state = stacked_lstm.zero_state(batch_size, tf.float32)
for i in range(len(num_steps)):
    # 每次处理一批词语后更新状态值.
    output, state = stacked_lstm(words[:, i], state)

    # 其余的代码.
    # ...

final_state = state

编译并运行代码

首先需要构建库,在 CPU 上编译:

bazel build -c opt tensorflow/models/rnn/ptb:ptb_word_lm

如果你有一个强大的 GPU,可以运行:

bazel build -c opt --config=cuda tensorflow/models/rnn/ptb:ptb_word_lm

运行模型:

bazel-bin/tensorflow/models/rnn/ptb/ptb_word_lm \
  --data_path=/tmp/simple-examples/data/ --alsologtostderr --model small

教程代码中有 3 个支持的模型配置参数:"small", "medium" 和 "large"。它们指的是 LSTM 的大小,以及用于训练的超参数集。

模型越大,得到的结果应该更好。在测试集中 small 模型应该可以达到低于 120 的困惑度(perplexity),large 模型则是低于 80,但它可能花费数小时来训练。

除此之外?

还有几个优化模型的技巧没有提到,包括:

  • 随时间降低学习率,
  • LSTM 层间 dropout.

继续学习和更改代码以进一步改善模型吧。

相关文章
|
7月前
|
安全 网络协议 算法
Nmap网络扫描工具详细使用教程
Nmap 是一款强大的网络发现与安全审计工具,具备主机发现、端口扫描、服务识别、操作系统检测及脚本扩展等功能。它支持多种扫描技术,如 SYN 扫描、ARP 扫描和全端口扫描,并可通过内置脚本(NSE)进行漏洞检测与服务深度枚举。Nmap 还提供防火墙规避与流量伪装能力,适用于网络管理、渗透测试和安全研究。
1115 1
|
9月前
|
JSON 监控 API
在线网络PING接口检测服务器连通状态免费API教程
接口盒子提供免费PING检测API,可测试域名或IP的连通性与响应速度,支持指定地域节点,适用于服务器运维和网络监控。
1502 0
|
机器学习/深度学习 人工智能 算法
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
1312 55
|
数据采集 存储 监控
Python 原生爬虫教程:网络爬虫的基本概念和认知
网络爬虫是一种自动抓取互联网信息的程序,广泛应用于搜索引擎、数据采集、新闻聚合和价格监控等领域。其工作流程包括 URL 调度、HTTP 请求、页面下载、解析、数据存储及新 URL 发现。Python 因其丰富的库(如 requests、BeautifulSoup、Scrapy)和简洁语法成为爬虫开发的首选语言。然而,在使用爬虫时需注意法律与道德问题,例如遵守 robots.txt 规则、控制请求频率以及合法使用数据,以确保爬虫技术健康有序发展。
1529 31
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【害虫识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
害虫识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了12种常见的害虫种类数据集【"蚂蚁(ants)", "蜜蜂(bees)", "甲虫(beetle)", "毛虫(catterpillar)", "蚯蚓(earthworms)", "蜚蠊(earwig)", "蚱蜢(grasshopper)", "飞蛾(moth)", "鼻涕虫(slug)", "蜗牛(snail)", "黄蜂(wasp)", "象鼻虫(weevil)"】 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Djan
728 1
基于Python深度学习的【害虫识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
|
域名解析 API PHP
VM虚拟机全版本网盘+免费本地网络穿透端口映射实时同步动态家庭IP教程
本文介绍了如何通过网络穿透技术让公网直接访问家庭电脑,充分发挥本地硬件性能。相比第三方服务受限于转发带宽,此方法利用自家宽带实现更高效率。文章详细讲解了端口映射教程,包括不同网络环境(仅光猫、光猫+路由器)下的设置步骤,并提供实时同步动态IP的两种方案:自建服务器或使用三方API接口。最后附上VM虚拟机全版本下载链接,便于用户在穿透后将服务运行于虚拟环境中,提升安全性与适用性。
919 7
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【蘑菇识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
蘑菇识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了9种常见的蘑菇种类数据集【"香菇(Agaricus)", "毒鹅膏菌(Amanita)", "牛肝菌(Boletus)", "网状菌(Cortinarius)", "毒镰孢(Entoloma)", "湿孢菌(Hygrocybe)", "乳菇(Lactarius)", "红菇(Russula)", "松茸(Suillus)"】 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Django框架搭建了一个Web网页平台可视化操作界面,
1278 11
基于Python深度学习的【蘑菇识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
|
机器学习/深度学习 人工智能 算法
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
宠物识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了37种常见的猫狗宠物种类数据集【'阿比西尼亚猫(Abyssinian)', '孟加拉猫(Bengal)', '暹罗猫(Birman)', '孟买猫(Bombay)', '英国短毛猫(British Shorthair)', '埃及猫(Egyptian Mau)', '缅因猫(Maine Coon)', '波斯猫(Persian)', '布偶猫(Ragdoll)', '俄罗斯蓝猫(Russian Blue)', '暹罗猫(Siamese)', '斯芬克斯猫(Sphynx)', '美国斗牛犬
704 29
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
|
监控 Linux PHP
【02】客户端服务端C语言-go语言-web端PHP语言整合内容发布-优雅草网络设备监控系统-2月12日优雅草简化Centos stream8安装zabbix7教程-本搭建教程非docker搭建教程-优雅草solution
【02】客户端服务端C语言-go语言-web端PHP语言整合内容发布-优雅草网络设备监控系统-2月12日优雅草简化Centos stream8安装zabbix7教程-本搭建教程非docker搭建教程-优雅草solution
525 20
|
前端开发 小程序 Java
uniapp-网络数据请求全教程
这篇文档介绍了如何在uni-app项目中使用第三方包发起网络请求
1026 3