Bag of Tricks for Efficient Text Classification 论文阅读及实战

简介: Bag of Tricks for Efficient Text Classification 论文阅读及实战

一、Fasttext算法综述


  • Fasttext是Facebook AI Research2016年推出的文本分类和词训练工具,其源码已经托管在Github上。Fasttext最大的特点是模型简单,只有一层的隐层以及输出层,因此训练速度非常快,在普通的CPU上可以实现分钟级别的训练,比深度模型的训练要快几个数量级。同时,在多个标准的测试数据集上,Fasttext在文本分类的准确率上,和现有的一些深度学习的方法效果相当或接近。


二、原理介绍及优化策略


1.Fasttext算法的主要功能有两个:


  • (1)训练词向量:词向量的训练相对与word2vec来说增加了subwords特性。subwords其实就是一个词的字符级的n-gram。例如单词"hello",长度至少为3的char-level的ngram有”hel”,”ell”,”llo”,”hell”,”ello”以及本身”hello”。每个n-gram都可以用一个dense的向量zg表示,于是整个单词”hello”就可以表示表示为:
    16.png

                  公式一.png


  • 具体细节可以参考论文Enriching Word Vectors with Subword Information那么把每个word,拆成若干个char-level的ngram表示有什么好处呢?答案是:丰富了词表示的层次。例如:”english-born”和”china-born”,从单词层面上看,是两个不同的单词,但是如果用char-level的n-gram来表示,都有相同的后缀”born”。因此这种表示方法可以学习到当两个词有相同的词缀时,其语义也具有一定的相似性。这种方法对英语等西方语言来说可能是奏效的。因为英语中很多相同前缀和后缀的单词,语义上确实有所相近。但对于中文来说,这种方法可能会有些问题。比如说,”原来”和”原则”,虽有相同前缀,但意义相去甚远。可能对中文来说,按照偏旁部首等字形的方式拆解可能会更有意义一些。


  • (2)文本分类:Fasttext的另一个功能是做文本分类。主要的原理在论文Bag of Tricks for Efficient Text Classification中有所阐述。其模型结构简单来说,就是一层word embedding的隐层+输出层。结构如下图所示:
    15.png

                         图1 Fasttext模型.png
14.png

                            图2 CBOW模型.png


图1是Fasttext的网络结构,其中W1到Wn表示document中每个词的word embedding表示。文章则可以用所有词的embedding累加后的均值表示,即:

13.png

                   公式二.png


最后从隐层再经过一次的非线性变换得到输出层的label。对比word2vec中的图2中的CBOW模型(continuous bag of word),可以发现两个模型之前非常的相似。不同之处在于,fasttext模型最后预测的是文章的类别label,而CBOW模型预测的是窗口中间的词w(t),Fasttext是有监督的学习,CBOW是无监督的学习。另外,CBOW模型中输入层只包括当前窗口内除中心词外的所有上下词汇,而fasttext模型中输入层是文章中的所有词。


2.网络结构介绍及模型的优化trick


2.1 模型结构介绍:


  • 和word2vec类似,fasttext本质上也可以看成是一个浅层的神经网络,因此其前向传播过程可描述如下:
    12.png

              前向传播.png


其中,z是最后输出层的输入向量,Wo表示从隐藏层到输出层的权重

  • 由于模型的最后输出是要预测的文本属于某个类别的概率,因此很自然的使用标准的softmax层,见下图:

11.png

                                              标准softmax.png


标准的softmax层用于多分类任务中,它将多个神经元的输出映射到(0,1)之间,可以当作概率来理解,从而进行多分类。在最后选取输出节点时,选取概率最大的节点作为最终的预测值。定义的交叉熵损失函数CE(y_target,y_predict)如下:
  10.png

                     损失函数.png


2.2 模型优化的几个trick


  • 2.2.1 当类别数较少时,直接套用softmax层并没有效率问题,但是当类别很多时,softmax层的计算就比较费时。为了加快训练过程,Fasttext同样也采用了和word2vec类似的方法。一种方法是使用hierarchical softmax,当类别数为K,word embedding的维度大小为d时,计算复杂度可以从O(Kd)降到O(dlog(K))。

9.jpg

                    分层softmax.jpg

因为在标准的softmax回归中,要计算y=j时的softmax概率P(y=i),需要对所有的K个概率做归一化,这在类别很大是会很耗时,于是分层softmax诞生了。分层softmax的基本思想:使用树的层级结构替代扁平化的标准softmax,使得在计算P(y=j)时,只需要计算一路路径上的所有节点的概率值即可,不需在意其他的节点。树的结构是根据类别的频数构造的霍夫曼树。K个不同的类别组成所有的叶子节点,K-1个内部节点作为内部参数,从根节点到某个叶子节点和边构成一条路径,路径的长度表示为L(yj)。于是,P(yj)可以写成下式:
  8.png

                                      类别概率.png

其中:
  • sigmoid():表示sigmoid函数
  • LC(n):表示n节点的左孩子
  • [[x]]是一个特殊的函数,定义如下:
    7.png

          特殊函数x.png

  • theta(yj,l):表示中间节点n(yj,l)的参数,X是softmax层的输入
    上图中,高亮的节点和边是从根节点到y2的路径,路径长度是L(y2)=4,所以P(y2)可表示为:
    6.png

                                         最后的概率.png


于是,从根节点到叶子节点y2,实际上作出3次的二分类的逻辑回归,通过分层的softmax,计算复杂度从O(K)降到O(log(K))。


  • 2.2.2 另一种方法是采用negative sampling,即每次从除当前label(正样本)外的其他label(负样本)中选择几个作为负样本,作为出现负样本的概率加到损失函数中,用公式可表达为:
    5.png

                                优化后的损失函数.png


其中,hi是第i个样本的隐藏层神经元数量,uj是Wo中第j行向量

  • 2.2.3 n-gram特征解决Fasttext模型丢失词汇顺序的问题,因为隐层是通过简单的求和取平均得到的。为了弥补这个不足,Fasttext增加了N-gram的特征。具体做法:把N-gram当成一个词,也用embedding向量来表示,在计算隐层时,把N-gram的embedding向量也加进去求和取平均。举个例子来说,假设某篇文章只有3个词,W1,W2,W3,N-gram的N=2即是bigram,w1、w2、w3以及w12、w23分别表示词W1、W2、W3和bigram W1-W2,W2-W3的embedding向量,那么文章的隐层可表示为:
    4.png

                       n-gram.png  

通过反向传播算法,就可以同时学到词的Embeding和n-gram的Embedding,具体的实现上,由于n-gram的量远比word大的多,完全存下所有的n-gram也不现实。Fasttext采用了Hash桶的方式,把所有的n-gram都哈希到buckets个桶中,哈希到同一个桶的所有n-gram共享一个embedding vector。如下图所示:
 
3.png

                                            Hash Buckets.png
      图中Win是Embedding矩阵,每行代表一个word或N-gram的词向量,其中前V行是word embeddings,后Buckets行是n-grams embeddings。每个n-gram经哈希函数哈希到0-bucket-1的位置,得到对应的embedding向量。用哈希的方式既能保证查找时O(1)的效率,又可能把内存消耗控制在O(bucket×dim)范围内。不过这种方法潜在的问题是存在哈希冲突,不同的n-gram可能会共享同一个embedding。如果桶大小取的足够大,这种影响会很小。


  • 2.2.4 对计算复杂度比较高的运算,Fasttext都采用了预计算的方法,先计算好值,使用的时候再查表,这是典型的空间或时间的优化思路。比如sigmoid函数的计算,源代码如下:
    2.png

sigmoid函数的计算.png


  • 2.2.5 在Negative Sampling中,Fasttext也采用了和word2vec类似的方法,即按照每个词的词频进行随机负采样,词频越大的词,被采样的概率越大。每个词被采样的概率并不是简单的按照词频在总量的占比,而是对词频先取根号,再算占比,公式如下:
    1.png

            负采样.png      

 其中,fw表示单词w的词频。取根号的目的是降低高频词汇的采样频率,同时增加低频词的采样频率。


三、Fasttext算法实战(代码运行在Linux系统中)


1第1步:获取分类文本,文本直接使用清华大学新闻文本,输出格式是:样本+样本标签
 2所使用的训练集和测试集已经分词好了,每个样本后会用Tab键隔开,打上样本标签,例如__label__sports
 3"""
 4
 5# 空缺
 6
 7
 8"""
 9第2步:利用fasttext进行分类,使用fasttext包
10"""
11
12# 训练模型
13classifier = fasttext.supervised(
14    "./dataset/news_fasttext_train.txt", "news_fasttext.model", label_prefix="__label__")
15# 训练好的模型
16# classifier = fasttext.load_model('new_fasttext.model.bin', label_prefix="__label__")
17
18# 测试模型
19result = classifier.test("./dataset/news_fasttext_test.txt")
20print("准确率:", result.precision)
21print("召回率:", result.recall)
22
23# fasttext只是对整个文本提供precision和recall,要统计不同的分类结果,需要自己实现
24
25classifier = fasttext.load_model(
26    'news_fasttext.model.bin', label_prefix='__label__')
27labels_right = []
28texts = []
29with open("./dataset/news_fasttext_test.txt") as f:
30    for line in f:
31        line = line.strip()
32        labels_right.append(line.split('\t')[1].replace("__label__", ""))
33        texts.append(line.split('\t')[0])
34
35labels_predict = [e[0] for e in classifier.predict(texts)]  # 预测标签值
36text_labels = list(set(labels_right))  # 实际标签值
37text_predict_labels = list(set(labels_predict))  # 去重后,预测标签值
38
39# print("预测标签值:", text_predict_labels)
40# print("真实标签值:", text_labels)
41
42A = dict.fromkeys(text_labels, 0)  # 预测正确的各个类的数目,真阳性TP
43B = dict.fromkeys(text_labels, 0)  # 真实测试数据集中各个类的数目
44C = dict.fromkeys(text_predict_labels, 0)  # 预测结果中各个类的数目,所有的预测结果
45for i in range(0, len(labels_right)):
46    B[labels_right[i]] += 1
47    C[labels_predict[i]] += 1
48    if labels_right[i] == labels_predict[i]:  # 判断是否是真阳性TP样本
49        A[labels_right[i]] += 1
50
51print("真阳性样本TP的类别数目A:", A)
52print("测试数据集中各个类别的数目B:", B)
53print("预测结果中各个类别的数目C:", C)
54# 计算准确率,召回率,F值
55for key in B:
56    try:
57        r = float(A[key]) / float(B[key])   # 召回率
58        p = float(A[key]) / float(C[key])  # 准确率
59        f = p * r * 2 / (p + r)   # f1值
60        print("%s:\t p:%f\t r:%f\t f:%f" % (key, p, r, f))
61    except:
62        print("错误:", key, "正确:", A.get(key, 0), "real:",
63              B.get(key, 0), "预测:", C.get(key, 0))


四、参考资料


[1] Bag of Tricks for Efficient Text Classification
[2] Enriching Word Vectors with Subword Information
[3] http://albertxiebnu.github.io/fasttext/
[4] https://www.jiqizhixin.com/articles/2018-12-03-6
[5] https://www.jianshu.com/p/ffa51250ba2e
[6] 训练集数据下载 https://pan.baidu.com/s/1jH7wyOY
[7] 测试集数据下载 https://pan.baidu.com/s/1slGlPgx

相关文章
|
人工智能 搜索推荐 云栖大会
解密!通义智文-你的AI阅读助手!
通义智文是基于通义大模型的AI阅读助手,网页阅读、论文阅读、图书阅读和自由阅读,用AI帮你读得多、读得快、读得懂。 通过文档场景化阅读、结构化导读、给我灵感、多文档处理等亮点功能和文档智能大小模型协同的核心技术。让AI帮你更准确,更深入,更专业的读懂文档,沉淀专属知识资产。 产品已于2023年10月31日在云栖大会正式对外发布,现免费公测全面开放。
2528 1
解密!通义智文-你的AI阅读助手!
|
人工智能 搜索推荐 算法
爱思唯尔的KBS——模板、投稿、返修、接收的总结
爱思唯尔的KBS——模板、投稿、返修、接收的总结
3593 3
|
监控 Unix Windows
Zabbix【部署 04】 Windows系统安装配置agent及agent2
Zabbix【部署 04】 Windows系统安装配置agent及agent2
1598 0
|
3月前
|
固态存储 关系型数据库 数据库
从Explain到执行:手把手优化PostgreSQL慢查询的5个关键步骤
本文深入探讨PostgreSQL查询优化的系统性方法,结合15年数据库优化经验,通过真实生产案例剖析慢查询问题。内容涵盖五大关键步骤:解读EXPLAIN计划、识别性能瓶颈、索引优化策略、查询重写与结构调整以及系统级优化配置。文章详细分析了慢查询对资源、硬件成本及业务的影响,并提供从诊断到根治的全流程解决方案。同时,介绍了索引类型选择、分区表设计、物化视图应用等高级技巧,帮助读者构建持续优化机制,显著提升数据库性能。最终总结出优化大师的思维框架,强调数据驱动决策与预防性优化文化,助力优雅设计取代复杂补救,实现数据库性能质的飞跃。
478 0
|
3月前
|
机器学习/深度学习 数据采集 人工智能
WebDancer:从零训练一个 DeepResearch 类智能体
WebDancer 是一款具备 Agentic 能力的智能体,能在开放网页环境中自主提问、搜索、推理并验证答案。它通过多步推理、信息整合与交叉验证解决复杂问题,如医学文献分析或政策追踪。WebDancer 采用 CRAWLQA 和 E2HQA 数据合成策略生成高质量训练数据,并结合 SFT(监督微调)+ RL(强化学习)双阶段训练方法,提升模型在动态环境中的适应性和泛化能力。其核心技术包括 ReAct 行为框架和 DAPO 强化学习算法,确保路径优化与策略稳定性。未来,WebDancer 将接入 Browser 工具链,拓展至代码沙盒、长文本写作等应用场景,进一步向通用智能体演进。
1049 27
|
5月前
|
机器学习/深度学习 人工智能 运维
AI“捕风捉影”:深度学习如何让网络事件检测更智能?
AI“捕风捉影”:深度学习如何让网络事件检测更智能?
129 8
|
10月前
|
人工智能 安全 Cloud Native
|
机器学习/深度学习 自然语言处理 算法
深度学习基础知识:介绍深度学习的发展历程、基本概念和主要应用
深度学习基础知识:介绍深度学习的发展历程、基本概念和主要应用
6809 0
|
关系型数据库 MySQL Java
解决com.mysql.cj.jdbc.exceptions.PacketTooBigException: Packet for query is too large
这篇文章提供了解决MySQL JDBC驱动中`com.mysql.cj.jdbc.exceptions.PacketTooBigException: Packet for query is too large`错误的步骤,主要是通过增加配置文件中的`max_allowed_packet`参数值并重启服务来允许更大的数据包传输。
解决com.mysql.cj.jdbc.exceptions.PacketTooBigException: Packet for query is too large
|
人工智能 编解码 机器人
硬核解读Stable Diffusion(3)
硬核解读Stable Diffusion