CNN-RNN中文文本分类,基于TensorFlow 实现

简介:

使用卷积神经网络以及循环神经网络进行中文文本分类

CNN做句子分类的论文可以参看:

https://arxiv.org/abs/1408.5882

还可以去读dennybritz大牛的博客:

http://www.wildml.com/2015/12/implementing-a-cnn-for-text-classification-in-tensorflow/

以及字符级CNN的论文:

https://arxiv.org/abs/1509.01626

本文是基于TensorFlow在中文数据集上的简化实现,使用了字符级CNN和RNN对中文文本进行分类,达到了较好的效果。

使用THUCNews的一个子集进行训练与测试,数据集请自行到THUCTC:一个高效的中文文本分类工具包

下载,请遵循数据提供方的开源协议。

本次训练使用了其中的10个分类,每个分类6500条数据。

类别如下:

体育, 财经, 房产, 家居, 教育, 科技, 时尚, 时政, 游戏, 娱乐

数据集划分如下:

 ●  训练集: 5000*10
 ●  验证集: 500*10
 ●  测试集: 1000*10

从原数据集生成子集的过程请参看helper下的两个脚本。其中,copy_data.sh用于从每个分类拷贝6500个文件,cnews_group.py用于将多个文件整合到一个文件中。执行该文件后,得到三个数据文件:

 ●  cnews.train.txt: 训练集(50000条)
 ●  cnews.val.txt: 验证集(5000条)

 ●  cnews.test.txt: 测试集(10000条)

预处理

data/cnews_loader.py为数据的预处理文件。

 ●  read_file() : 读取文件数据;
 ●  build_vocab() : 构建词汇表,使用字符级的表示,这一函数会将词汇表存储下来,避免每一次重复处理;
 ●  read_vocab() : 读取上一步存储的词汇表,转换为 {词:id} 表示;
 ●  read_category() : 将分类目录固定,转换为 {类别: id} 表示;
 ●  to_words() : 将一条由id表示的数据重新转换为文字;
 ●  process_file() : 将数据集从文字转换为固定长度的id序列表示;
 ●  batch_iter() : 为神经网络的训练准备经过shuffle的批次的数据。

经过数据预处理,数据的格式如下:

de348a9410f3e0ec617f7891f4292b34b4400030

CNN模型

具体参看cnn_model.py的实现。

大致结构如下:

18d311743a97f07b86de5f9651ff95c50fa56359

训练与验证

运行 python run_cnn.py train,可以开始训练。

3fa6224a96884301e095adc2d732429ea71c056e

在验证集上的最佳效果为94.12%,且只经过了3轮迭代就已经停止。

准确率和误差如图所示:

9ebac8c04c67c54c146c9ab9eb6f470f5db5209c

测试

运行 python run_cnn.py test 在测试集上进行测试。

e820e5f7847abca4e663f65c48fc28a89ff5c2bb

在测试集上的准确率达到了96.04%,且各类的precision, recall和f1-score都超过了0.9。

从混淆矩阵也可以看出分类效果非常优秀。

RNN循环神经网络

配置项

RNN可配置的参数如下所示,在rnn_model.py中。

883fd9dff2e973a3b72b8bea392be5464f1411da

RNN模型

具体参看rnn_model.py的实现。

大致结构如下:

ffa81196f2992d6a79f681d873ef176be1a9d942

训练与验证

这部分的代码与 run_cnn.py极为相似,只需要将模型和部分目录稍微修改。

运行 python run_rnn.py train,可以开始训练。

若之前进行过训练,请把tensorboard/textrnn删除,避免TensorBoard多次训练结果重叠。

897eec88e9ab55de9ac27077aca1886df3eb1bbe

在验证集上的最佳效果为91.42%,经过了8轮迭代停止,速度相比CNN慢很多。

准确率和误差如图所示:

8d6f7d774722192775312e52bad6e70e69e4df1b

测试

运行 python run_rnn.py test 在测试集上进行测试。

230b5b317750af4ebd04d2070bdf499a4aee98dd

在测试集上的准确率达到了94.22%,且各类的precision, recall和f1-score,除了家居这一类别,都超过了0.9。

从混淆矩阵可以看出分类效果非常优秀。

对比两个模型,可见RNN除了在家居分类的表现不是很理想,其他几个类别较CNN差别不大。

还可以通过进一步的调节参数,来达到更好的效果。

为方便预测,repo 中 predict.py 提供了 CNN 模型的预测方法。


原文发布时间为:2018-10-18

本文来自云栖社区合作伙伴“大数据挖掘DT机器学习”,了解相关信息可以关注“大数据挖掘DT机器学习”。

相关文章
|
3月前
|
机器学习/深度学习 算法 TensorFlow
文本分类识别Python+卷积神经网络算法+TensorFlow模型训练+Django可视化界面
文本分类识别Python+卷积神经网络算法+TensorFlow模型训练+Django可视化界面
54 0
文本分类识别Python+卷积神经网络算法+TensorFlow模型训练+Django可视化界面
|
4月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
【Python深度学习】Tensorflow+CNN进行人脸识别实战(附源码和数据集)
【Python深度学习】Tensorflow+CNN进行人脸识别实战(附源码和数据集)
97 0
|
5月前
|
机器学习/深度学习 数据挖掘 TensorFlow
keras tensorflow 搭建CNN-LSTM神经网络的住宅用电量预测 完整代码数据
keras tensorflow 搭建CNN-LSTM神经网络的住宅用电量预测 完整代码数据
51 0
|
6月前
|
机器学习/深度学习 算法 TensorFlow
【深度学习】实验14 使用CNN完成MNIST手写体识别(TensorFlow)
【深度学习】实验14 使用CNN完成MNIST手写体识别(TensorFlow)
63 0
|
8月前
|
机器学习/深度学习 存储 自然语言处理
基于 LSTM 进行多类文本分类( TensorFlow 2.0)
基于 LSTM 进行多类文本分类( TensorFlow 2.0)
|
机器学习/深度学习 TensorFlow 算法框架/工具
优达学城深度学习之六——TensorFlow实现卷积神经网络
优达学城深度学习之六——TensorFlow实现卷积神经网络
优达学城深度学习之六——TensorFlow实现卷积神经网络
|
机器学习/深度学习 存储 TensorFlow
直观理解并使用Tensorflow实现Seq2Seq模型的注意机制(下)
直观理解并使用Tensorflow实现Seq2Seq模型的注意机制
160 0
直观理解并使用Tensorflow实现Seq2Seq模型的注意机制(下)
|
机器学习/深度学习 数据采集 自然语言处理
直观理解并使用Tensorflow实现Seq2Seq模型的注意机制(上)
直观理解并使用Tensorflow实现Seq2Seq模型的注意机制
232 1
直观理解并使用Tensorflow实现Seq2Seq模型的注意机制(上)
|
机器学习/深度学习 算法 测试技术
从零开始构建:使用CNN和TensorFlow进行人脸特征检测
从零开始构建:使用CNN和TensorFlow进行人脸特征检测
157 0
从零开始构建:使用CNN和TensorFlow进行人脸特征检测
|
机器学习/深度学习 数据可视化 TensorFlow
使用TensorFlow Probability实现最大似然估计
TensorFlow Probability是一个构建在TensorFlow之上的Python库。它将我们的概率模型与现代硬件(例如GPU)上的深度学习结合起来。
119 1