·阅读摘要:
本文提出基于Seq2Seq模型,提出CNN-RNN模型应用于多标签文本分类。论文表示CNN-RNN模型在大型数据集上表现的效果很好,在小数据集效果不好。
·参考文献:
[1] Ensemble Application of Convolutional and Recurrent Neural Networks for Multi-label Text Categorization
[2] Seq2Seq模型讲解,参考博客:【多标签文本分类】代码详解Seq2Seq模型
本文的收获有三:
1、CNN-RNN模型;
2、多标签数据集Reuters-21578;
3、多标签评价指标:one-error 、hamming loss、Precision、Recall、F1
[1] CNN-RNN模型图
如下图:模型很简单,左边是一个TextCNN模型,右边是一个解码器Decoder。
【注一】:在理解Seq2Seq的基础上,CNN-RNN模型很好理解。
[2] 多标签数据集Reuters-21578
多标签数据集比较难得,获取数据集Reuters-21578
,可以使用如下代码:
import nltk import pandas as pd nltk.download('reuters') nltk.download('punkt') # Extract fileids from the reuters corpus fileids = reuters.fileids() # Initialize empty lists to store categories and raw text categories = [] text = [] # Loop through each file id and collect each files categories and raw text for file in fileids: categories.append(reuters.categories(file)) text.append(reuters.raw(file)) # Combine lists into pandas dataframe. reutersDf is the final dataframe. reutersDf = pd.DataFrame({'ids':fileids, 'categories':categories, 'text':text})
[3] 多标签文本分类评价指标
one-error
:统计top1的预测标签不在实际标签中的实例的比例;
hamming loss
:计算预测标签和相关标签的对称差异,并计算其差异在标签空间中的分数;