使用libsvm实现文本分类

简介:

文本分类,首先它是分类问题,应该对应着分类过程的两个重要的步骤,一个是使用训练数据集训练分类器,另一个就是使用测试数据集来评价分类器的分类精度。然而,作为文本分类,它还具有文本这样的约束,所以对于文本来说,需要额外的处理过程,我们结合使用libsvm从宏观上总结一下,基于libsvm实现文本分类实现的基本过程,如下所示:

  1. 选择文本训练数据集和测试数据集:训练集和测试集都是类标签已知的;
  2. 训练集文本预处理:这里主要包括分词、去停用词、建立词袋模型(倒排表);
  3. 选择文本分类使用的特征向量(词向量):最终的目标是使得最终选出的特征向量在多个类别之间具有一定的类别区分度,可以使用相关有效的技术去实现特征向量的选择,由于分词后得到大量的词,通过选择降维技术能很好地减少计算量,还能维持分类的精度;
  4. 输出libsvm支持的量化的训练样本集文件:类别名称、特征向量中每个词元素分别到数字编号的映射转换,以及基于类别和特征向量来量化文本训练集,能够满足使用libsvm训练所需要的数据格式;
  5. 测试数据集预处理:同样包括分词(需要和训练过程中使用的分词器一致)、去停用词、建立词袋模型(倒排表),但是这时需要加载训练过程中生成的特征向量,用特征向量去排除多余的不在特征向量中的词(也称为降维);
  6. 输出libsvm支持的量化的测试样本集文件:格式和训练数据集的预处理阶段的输出相同。
  7. 使用libsvm训练文本分类器:使用训练集预处理阶段输出的量化的数据集文件,这个阶段也需要做很多工作(后面会详细说明),最终输出分类模型文件
  8. 使用libsvm验证分类模型的精度:使用测试集预处理阶段输出的量化的数据集文件,和分类模型文件来验证分类的精度。
  9. 分类模型参数寻优:如果经过libsvm训练出来的分类模型精度很差,可以通过libsvm自带的交叉验证(Cross Validation)功能来实现参数的寻优,通过搜索参数取值空间来获取最佳的参数值,使分类模型的精度满足实际分类需要。

基于上面的分析,分别对上面每个步骤进行实现,最终完成一个分类任务。

数据集选择

我们选择了搜狗的语料库,可以参考后面的链接下载语料库文件。这里,需要注意的是,分别准备一个训练数据集和一个测试数据集,不要让两个数据集有交叉。例如,假设有C个类别,选择每个分类的下的N篇文档作为训练集,总共的训练集文档数量为C*N,剩下的每一类下M篇作为测试数据集使用,测试数据集总共文档数等于C*M。

数据集文本预处理

我们选择使用ICTCLAS分词器,使用该分词器可以不需要预先建立自己的词典,而且分词后已经标注了词性,可以根据词性对词进行一定程度过滤(如保留名词,删除量词、叹词等等对分类没有意义的词汇)。
下载ICTCLAS软件包,如果是在Win7 64位系统上使用Java实现分词,选择如下两个软件包:

  • 20131115123549_nlpir_ictclas2013_u20131115_release.zip
  • 20130416090323_Win-64bit-JNI-lib.zip

将第二个软件包中的NLPIR_JNI.dll文件拷贝到C:\Windows\System32目录下面,将第一个软件包中的Data目录和NLPIR.dll、NLPIR.lib、NLPIR.h、NLPIR.lib文件拷贝到Java工程根目录下面。
对于其他操作系统,可以到ICTCLAS网站(http://ictclas.nlpir.org/downloads)下载对应版本的软件包。
下面,我们使用Java实现分词,定义分词器接口,以便切换其他分词器实现时,容易扩展,如下所示:

1 package org.shirdrn.document.processor.common;
2
3 import java.io.File;
4 import java.util.Map;
5
6 public interface DocumentAnalyzer {
7 Map<String, Term> analyze(File file);
8 }

增加一个外部的停用词表,这个我们直接封装到抽象类AbstractDocumentAnalyzer中去了,该抽象类就是从一个指定的文件或目录读取停用词文件,将停用词加载到内存中,在分词的过程中对词进行进一步的过滤。然后基于上面的实现,给出包裹ICTCLAS分词器的实现,代码如下所示:

01 package org.shirdrn.document.processor.analyzer;
02
03 import java.io.BufferedReader;
04 import java.io.File;
05 import java.io.FileInputStream;
06 import java.io.IOException;
07 import java.io.InputStreamReader;
08 import java.util.HashMap;
09 import java.util.Map;
10
11 import kevin.zhang.NLPIR;
12
13 import org.apache.commons.logging.Log;
14 import org.apache.commons.logging.LogFactory;
15 import org.shirdrn.document.processor.common.DocumentAnalyzer;
16 import org.shirdrn.document.processor.common.Term;
17 import org.shirdrn.document.processor.config.Configuration;
18
19 public class IctclasAnalyzer extends AbstractDocumentAnalyzer implementsDocumentAnalyzer {
20
21 private static final Log LOG = LogFactory.getLog(IctclasAnalyzer.class);
22 private final NLPIR analyzer;
23
24 public IctclasAnalyzer(Configuration configuration) {
25 super(configuration);
26 analyzer = new NLPIR();
27 try {
28 boolean initialized = NLPIR.NLPIR_Init(".".getBytes(charSet), 1);
29 if(!initialized) {
30 throw new RuntimeException("Fail to initialize!");
31 }
32 } catch (Exception e) {
33 throw new RuntimeException("", e);
34 }
35 }
36
37 @Override
38 public Map<String, Term> analyze(File file) {
39 String doc = file.getAbsolutePath();
40 LOG.info("Process document: file=" + doc);
41 Map<String, Term> terms = new HashMap<String, Term>(0);
42 BufferedReader br = null;
43 try {
44 br = new BufferedReader(new InputStreamReader(newFileInputStream(file), charSet));
45 String line = null;
46 while((line = br.readLine()) != null) {
47 line = line.trim();
48 if(!line.isEmpty()) {
49 byte nativeBytes[] = analyzer.NLPIR_ParagraphProcess(line.getBytes(charSet), 1);
50 String content = new String(nativeBytes, 0, nativeBytes.length, charSet);
51 String[] rawWords = content.split("\\s+");
52 for(String rawWord : rawWords) {
53 String[] words = rawWord.split("/");
54 if(words.length == 2) {
55 String word = words[0];
56 String lexicalCategory = words[1];
57 Term term = terms.get(word);
58 if(term == null) {
59 term = new Term(word);
60 // TODO set lexical category
61 term.setLexicalCategory(lexicalCategory);
62 terms.put(word, term);
63 }
64 term.incrFreq();
65 LOG.debug("Got word: word=" + rawWord);
66 }
67 }
68 }
69 }
70 } catch (IOException e) {
71 e.printStackTrace();
72 } finally {
73 try {
74 if(br != null) {
75 br.close();
76 }
77 } catch (IOException e) {
78 LOG.warn(e);
79 }
80 }
81 return terms;
82 }
83
84 }

它是对一个文件进行读取,然后进行分词,去停用词,最后返回的Map包含了的集合,此属性包括词性(Lexical Category)、词频、TF等信息。
这样,遍历数据集目录和文件,就能去将全部的文档分词,最终构建词袋模型。我们使用Java中集合来存储文档、词、类别之间的关系,如下所示:

01 private int totalDocCount;
02 private final List<String> labels = new ArrayList<String>();
03 // Map<类别, 文档数量>
04 private final Map<String, Integer> labelledTotalDocCountMap = new HashMap<String, Integer>();
05 // Map<类别, Map<文档 ,Map<词, 词信息>>>
06 private final Map<String, Map<String, Map<String, Term>>> termTable =
07 new HashMap<String, Map<String, Map<String, Term>>>();
08 // Map<词 ,Map<类别, Set<文档>>>
09 private final Map<String, Map<String, Set<String>>> invertedTable =
10 new HashMap<String, Map<String, Set<String>>>();

基于训练数据集选择特征向量

上面已经构建好词袋模型,包括相关的文档和词等的关系信息。现在我们来选择用来建立分类模型的特征词向量,首先要选择一种度量,来有效地选择出特征词向量。基于论文《A comparative study on feature selection in text categorization》,我们选择基于卡方统计量(chi-square statistic, CHI)技术来实现选择,这里根据计算公式:

其中,公式中各个参数的含义,说明如下:

  • N:训练数据集文档总数
  • A:在一个类别中,包含某个词的文档的数量
  • B:在一个类别中,排除该类别,其他类别包含某个词的文档的数量
  • C:在一个类别中,不包含某个词的文档的数量
  • D:在一个类别中,不包含某个词也不在该类别中的文档的数量

要想进一步了解,可以参考这篇论文。
使用卡方统计量,为每个类别下的每个词都进行计算得到一个CHI值,然后对这个类别下的所有的词基于CHI值进行排序,选择出最大的topN个词(很显然使用堆排序算法更合适);最后将多个类别下选择的多组topN个词进行合并,得到最终的特征向量。
其实,这里可以进行一下优化,每个类别下对应着topN个词,在合并的时候可以根据一定的标准,将各个类别都出现的词给出一个比例,超过指定比例的可以删除掉,这样可以使特征向量在多个类别分类过程中更具有区分度。这里,我们只是做了个简单的合并。
我们看一下,用到的存储结构,使用Java的集合来存储:

1 // Map<label, Map<word, term>>
2 private final Map<String, Map<String, Term>> chiLabelToWordsVectorsMap = newHashMap<String, Map<String, Term>>(0);
3 // Map<word, term>, finally merged vector
4 private final Map<String, Term> chiMergedTermVectorMap = new HashMap<String, Term>(0);

下面,实现特征向量选择计算的实现,代码如下所示:

001 package org.shirdrn.document.processor.component.train;
002
003 import java.util.Iterator;
004 import java.util.Map;
005 import java.util.Map.Entry;
006 import java.util.Set;
007
008 import org.apache.commons.logging.Log;
009 import org.apache.commons.logging.LogFactory;
010 import org.shirdrn.document.processor.common.AbstractComponent;
011 import org.shirdrn.document.processor.common.Context;
012 import org.shirdrn.document.processor.common.Term;
013 import org.shirdrn.document.processor.utils.SortUtils;
014
015 public class FeatureTermVectorSelector extends AbstractComponent {
016
017 private static final Log LOG = LogFactory.getLog(FeatureTermVectorSelector.class);
018 private final int keptTermCountEachLabel;
019
020 public FeatureTermVectorSelector(Context context) {
021 super(context);
022 keptTermCountEachLabel = context.getConfiguration().getInt("processor.each.label.kept.term.count", 3000);
023 }
024
025 @Override
026 public void fire() {
027 // compute CHI value for selecting feature terms
028 // after sorting by CHI value
029 for(String label : context.getVectorMetadata().getLabels()) {
030 // for each label, compute CHI vector
031 LOG.info("Compute CHI for: label=" + label);
032 processOneLabel(label);
033 }
034
035 // sort and select CHI vectors
036 Iterator<Entry<String, Map<String, Term>>> chiIter =
037 context.getVectorMetadata().chiLabelToWordsVectorsIterator();
038 while(chiIter.hasNext()) {
039 Entry<String, Map<String, Term>> entry = chiIter.next();
040 String label = entry.getKey();
041 LOG.info("Sort CHI terms for: label=" + label + ", termCount=" + entry.getValue().size());
042 Entry<String, Term>[] a = sort(entry.getValue());
043 for (int i = 0; i < Math.min(a.length, keptTermCountEachLabel); i++) {
044 Entry<String, Term> termEntry = a[i];
045 // merge CHI terms for all labels
046 context.getVectorMetadata().addChiMergedTerm(termEntry.getKey(), termEntry.getValue());
047 }
048 }
049 }
050
051 @SuppressWarnings("unchecked")
052 private Entry<String, Term>[] sort(Map<String, Term> terms) {