开发者社区> 致Great_VIP> 正文

图神经网络14-TextGCN:基于图神经网络的文本分类

简介: 图神经网络14-TextGCN:基于图神经网络的文本分类
+关注继续查看

论文题目:Graph Convolutional Networks for Text Classification

论文地址:https://arxiv.org/pdf/1809.05679.pdf

论文代码:https://github.com/yao8839836/text_gcn

发表时间:AAAI 2019

ps:注意这篇论文作者在2018年已经公开在arxiv,我们再此不讨论预训练模型的事情 ^_^


论文摘要与简介


文本分类是自然语言处理过程中一个非常重要和经典的问题,在论文和实践过程中可以说经久不衰的任务。或多或少接触NLP的同学,应该比较清楚目前文本分类的模型众多,比如Text-RNN(LSTM),Text-CNN等,但是当时很少有关于将神经网络用于文本分类的任务中。


本文提出一种将图卷积网络模型用于文本分类的模型,主要思路为基于词语共现以及文本单词之间的关系构建语料库中文本的Graph,然后将GCN学习文本的表示用于文本分类。通过多个基准数据集实验表明,Text-GCN无需额外的单词嵌入或者先验知识就能够取得由于最新的文本分类方法。另一方面,Text-GCN还能够学习和预测词语与文档的嵌入表示。


论文动机与相关工作


图形结构在自然语言处理任务的文本数据中有许多有趣的应用,如语义角色标记(Titov2017)、关系分类(Li,Jin和Luo2018)和机器翻译(Bastings等)。

传统上,针对文本分类的模型一直侧重于单词嵌入的有效性和用于文档嵌入的聚合单词嵌入。这些词嵌入可以是无监督的预训练嵌入(例如word2vec或Glove),然后将其输入分类器中。最近,诸如CNN和RNN的深度学习模型已经成为有用的文本编码器。在这两种情况下,文本表示都是从单词嵌入中学习的。本文作者建议同时学习单词和文档嵌入以进行文本分类。


在本文之前其实也有GCN用于文本分类的研究,但是大部分工作都是将文档或者词语看做节点,相比之下,本文在构建语料库图时,我们将文档和单词视为节点(因此是异构图),并且不需要文档间的关系。本文提出的Text-GCN,获取给定的文档和单词的语料库,并构造一个图形,其中文档和单词为节点(有关的详细构建,我们稍后讨论)。利用此构造的图,`Text-GCN·利用图卷积网络来学习更好的节点表示(单词和文档的表示)。然后可以将这些更新的表示形式输入到分类器中。


GCN:Graph Convolutional Networks


我们快速对GCN进行回顾下,原文可以查看Semi-supervised classification with graph convolutional networks

首先我们定义下GCN的输入Graph。图image,由节点集合image和边集合image组成。连接节点的边可以用林局长了表示image。如果image不为空,那么表示节点image和节点image之间存在关系,并且权重为image的具体数值。GCN还向节点添加自环,因此邻接矩阵变为:

image

另外,每个节点都可以有一个向量表示,这个向量可以认为是该节点的特征。因此我们还有一个矩阵image,其中image是第image个节点的节点特征,image代表特征向量的大小。对于模型训练需要一个权重矩阵image,其中image是输出维度大小。最后我们还需要引入一个节点的度对角矩阵image,对角线每个值代表了节点的度大小。

42.png


有了以上条件和基础,我们可以给出GCN层的公式表示了:

image=image

我们一步步解释下这个公式。其中image,代表了输入节点的隐层输出向量表示。另外注意image本质上是邻接矩阵,但是通过节点的度进行了归一化。

从上面可以看到,GCN本质上是学习了节点邻居和节点本身的节点表示形式(请记住自循环)。

GCN层允许节点从跳远的其他节点接收信息。GCN属于一类图形神经网络,称为消息传递网络,其中消息(在这种情况下,边缘权重乘以节点表示形式)在邻居之间传递。我们可以将这些消息传递网络视为帮助学习节点表示的方法,该节点表示法考虑了其图结构的附近邻居。因此,图的构造方式,即在哪些节点之间形成哪些边,非常重要。接下来我们讨论下图卷积网络如何用于文本分类,文本图如何构造。


Text-GCN:基于图神经网络的文本分类


43.png


文本Graph的构建


构造“文本”图的细节如下。首先,节点总数是文档image数加上不同词语image的个数。节点特征矩阵是恒等矩阵image每个节点表示都是一个one-hot向量。同样,邻接矩阵(文档和单词节点之间的边缘)定义如下


44.png


image是包含单词i和单词j的滑动窗口的数量,而image是包含单词i的滑动窗口的数量。image是滑动窗口的总数。


文本与词语之前的关系比较好刻画,文中直接采用我们常见的Tfidf来构建文档与词的边。对于词与词的关系采用PMIimage是两个单词节点之间的逐点互信息,用于查看两个单词的共现次数。用于计算共现的窗口大小是模型的超参数。在本文中,作者将其设置为20。直观地,图造尝试将相似的单词和文档放置在图形中彼此靠近的位置。


Text-GCN模型


完成文本Graph构建后,作者只需运行两层GCN,然后运行softmax函数来预测标签。公

式为:

image

对于损失函数,使用交叉熵损失。


论文实验


作者将他们的模型与 CNN、LSTM 变体和其他基线词或段落嵌入模型进行了比较。比较是在5个数据集上执行的。


45.png


  • 20NG超过18,000个文档平均分布在20个类别中
  • Ohsumed共有23种疾病类别的7000多种心血管疾病摘要
  • R52是Reuters-21578的子集,该文件于1987年出现在路透社新闻专栏中。大约10000个文档,涉及52个类别
  • R8与上述相同,但具有7500个文档和8个类别
  • MR电影评论数据集,包含10000条评论和两个类别:正面情绪和负面情绪

实验结果如下表所示:


46.png


从结果可以看出,TextGCN与CNN,LSTM和其他基准相比,每个数据集的效果最佳或接近最佳。这种性能纯粹来自上一节中定义的边缘和边缘权重,没有输入词向量,可见效果很好~。


GCN网络学习到的向量表示


为了获得对学习的表示的一些见解,作者展示了通过获取的文档嵌入的t-SNE可视化 TextGCN。我们可以看到,即使在应用了GCN的一层之后,文档嵌入也能够很好地区分自己。


47.png


更具体地讲,我们还可以使用中的嵌入来查看每个类的前10个单词的结果 TextGCN。我们可以看到该模型能够预测每个类别的相关词。


48.png


论文核心代码


  • Text Graph的构建

https://github.com/yao8839836/text_gcn/blob/master/build_graph.py

'''
Doc word heterogeneous graph
'''
# word co-occurence with context windows
window_size = 20
windows = []
for doc_words in shuffle_doc_words_list:
    words = doc_words.split()
    length = len(words)
    if length <= window_size:
        windows.append(words)
    else:
        # print(length, length - window_size + 1)
        for j in range(length - window_size + 1):
            window = words[j: j + window_size]
            windows.append(window)
            # print(window)
word_window_freq = {}
for window in windows:
    appeared = set()
    for i in range(len(window)):
        if window[i] in appeared:
            continue
        if window[i] in word_window_freq:
            word_window_freq[window[i]] += 1
        else:
            word_window_freq[window[i]] = 1
        appeared.add(window[i])
word_pair_count = {}
for window in windows:
    for i in range(1, len(window)):
        for j in range(0, i):
            word_i = window[i]
            word_i_id = word_id_map[word_i]
            word_j = window[j]
            word_j_id = word_id_map[word_j]
            if word_i_id == word_j_id:
                continue
            word_pair_str = str(word_i_id) + ',' + str(word_j_id)
            if word_pair_str in word_pair_count:
                word_pair_count[word_pair_str] += 1
            else:
                word_pair_count[word_pair_str] = 1
            # two orders
            word_pair_str = str(word_j_id) + ',' + str(word_i_id)
            if word_pair_str in word_pair_count:
                word_pair_count[word_pair_str] += 1
            else:
                word_pair_count[word_pair_str] = 1
row = []
col = []
weight = []
# pmi as weights
num_window = len(windows)
for key in word_pair_count:
    temp = key.split(',')
    i = int(temp[0])
    j = int(temp[1])
    count = word_pair_count[key]
    word_freq_i = word_window_freq[vocab[i]]
    word_freq_j = word_window_freq[vocab[j]]
    pmi = log((1.0 * count / num_window) /
              (1.0 * word_freq_i * word_freq_j/(num_window * num_window)))
    if pmi <= 0:
        continue
    row.append(train_size + i)
    col.append(train_size + j)
    weight.append(pmi)
# word vector cosine similarity as weights
'''
for i in range(vocab_size):
    for j in range(vocab_size):
        if vocab[i] in word_vector_map and vocab[j] in word_vector_map:
            vector_i = np.array(word_vector_map[vocab[i]])
            vector_j = np.array(word_vector_map[vocab[j]])
            similarity = 1.0 - cosine(vector_i, vector_j)
            if similarity > 0.9:
                print(vocab[i], vocab[j], similarity)
                row.append(train_size + i)
                col.append(train_size + j)
                weight.append(similarity)
'''
# doc word frequency
doc_word_freq = {}
for doc_id in range(len(shuffle_doc_words_list)):
    doc_words = shuffle_doc_words_list[doc_id]
    words = doc_words.split()
    for word in words:
        word_id = word_id_map[word]
        doc_word_str = str(doc_id) + ',' + str(word_id)
        if doc_word_str in doc_word_freq:
            doc_word_freq[doc_word_str] += 1
        else:
            doc_word_freq[doc_word_str] = 1
for i in range(len(shuffle_doc_words_list)):
    doc_words = shuffle_doc_words_list[i]
    words = doc_words.split()
    doc_word_set = set()
    for word in words:
        if word in doc_word_set:
            continue
        j = word_id_map[word]
        key = str(i) + ',' + str(j)
        freq = doc_word_freq[key]
        if i < train_size:
            row.append(i)
        else:
            row.append(i + vocab_size)
        col.append(train_size + j)
        idf = log(1.0 * len(shuffle_doc_words_list) /
                  word_doc_freq[vocab[j]])
        weight.append(freq * idf)
        doc_word_set.add(word)
node_size = train_size + vocab_size + test_size
adj = sp.csr_matrix(
    (weight, (row, col)), shape=(node_size, node_size))


  • Text GCN 实现

https://github.com/yao8839836/text_gcn/blob/master/models.py

from layers import *
from metrics import *
import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS
class Model(object):
    def __init__(self, **kwargs):
        allowed_kwargs = {'name', 'logging'}
        for kwarg in kwargs.keys():
            assert kwarg in allowed_kwargs, 'Invalid keyword argument: ' + kwarg
        name = kwargs.get('name')
        if not name:
            name = self.__class__.__name__.lower()
        self.name = name
        logging = kwargs.get('logging', False)
        self.logging = logging
        self.vars = {}
        self.placeholders = {}
        self.layers = []
        self.activations = []
        self.inputs = None
        self.outputs = None
        self.loss = 0
        self.accuracy = 0
        self.optimizer = None
        self.opt_op = None
    def _build(self):
        raise NotImplementedError
    def build(self):
        """ Wrapper for _build() """
        with tf.variable_scope(self.name):
            self._build()
        # Build sequential layer model
        self.activations.append(self.inputs)
        for layer in self.layers:
            hidden = layer(self.activations[-1])
            self.activations.append(hidden)
        self.outputs = self.activations[-1]
        # Store model variables for easy access
        variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name)
        self.vars = {var.name: var for var in variables}
        # Build metrics
        self._loss()
        self._accuracy()
        self.opt_op = self.optimizer.minimize(self.loss)
    def predict(self):
        pass
    def _loss(self):
        raise NotImplementedError
    def _accuracy(self):
        raise NotImplementedError
    def save(self, sess=None):
        if not sess:
            raise AttributeError("TensorFlow session not provided.")
        saver = tf.train.Saver(self.vars)
        save_path = saver.save(sess, "tmp/%s.ckpt" % self.name)
        print("Model saved in file: %s" % save_path)
    def load(self, sess=None):
        if not sess:
            raise AttributeError("TensorFlow session not provided.")
        saver = tf.train.Saver(self.vars)
        save_path = "tmp/%s.ckpt" % self.name
        saver.restore(sess, save_path)
        print("Model restored from file: %s" % save_path)
class MLP(Model):
    def __init__(self, placeholders, input_dim, **kwargs):
        super(MLP, self).__init__(**kwargs)
        self.inputs = placeholders['features']
        self.input_dim = input_dim
        # self.input_dim = self.inputs.get_shape().as_list()[1]  # To be supported in future Tensorflow versions
        self.output_dim = placeholders['labels'].get_shape().as_list()[1]
        self.placeholders = placeholders
        self.optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
        self.build()
    def _loss(self):
        # Weight decay loss
        for var in self.layers[0].vars.values():
            self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var)
        # Cross entropy error
        self.loss += masked_softmax_cross_entropy(self.outputs, self.placeholders['labels'],
                                                  self.placeholders['labels_mask'])
    def _accuracy(self):
        self.accuracy = masked_accuracy(self.outputs, self.placeholders['labels'],
                                        self.placeholders['labels_mask'])
    def _build(self):
        self.layers.append(Dense(input_dim=self.input_dim,
                                 output_dim=FLAGS.hidden1,
                                 placeholders=self.placeholders,
                                 act=tf.nn.relu,
                                 dropout=True,
                                 sparse_inputs=True,
                                 logging=self.logging))
        self.layers.append(Dense(input_dim=FLAGS.hidden1,
                                 output_dim=self.output_dim,
                                 placeholders=self.placeholders,
                                 act=lambda x: x,
                                 dropout=True,
                                 logging=self.logging))
    def predict(self):
        return tf.nn.softmax(self.outputs)
class GCN(Model):
    def __init__(self, placeholders, input_dim, **kwargs):
        super(GCN, self).__init__(**kwargs)
        self.inputs = placeholders['features']
        self.input_dim = input_dim
        # self.input_dim = self.inputs.get_shape().as_list()[1]  # To be supported in future Tensorflow versions
        self.output_dim = placeholders['labels'].get_shape().as_list()[1]
        self.placeholders = placeholders
        self.optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
        self.build()
    def _loss(self):
        # Weight decay loss
        for var in self.layers[0].vars.values():
            self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var)
        # Cross entropy error
        self.loss += masked_softmax_cross_entropy(self.outputs, self.placeholders['labels'],
                                                  self.placeholders['labels_mask'])
    def _accuracy(self):
        self.accuracy = masked_accuracy(self.outputs, self.placeholders['labels'],
                                        self.placeholders['labels_mask'])
        self.pred = tf.argmax(self.outputs, 1)
        self.labels = tf.argmax(self.placeholders['labels'], 1)
    def _build(self):
        self.layers.append(GraphConvolution(input_dim=self.input_dim,
                                            output_dim=FLAGS.hidden1,
                                            placeholders=self.placeholders,
                                            act=tf.nn.relu,
                                            dropout=True,
                                            featureless=True,
                                            sparse_inputs=True,
                                            logging=self.logging))
        self.layers.append(GraphConvolution(input_dim=FLAGS.hidden1,
                                            output_dim=self.output_dim,
                                            placeholders=self.placeholders,
                                            act=lambda x: x, #
                                            dropout=True,
                                            logging=self.logging))
    def predict(self):
        return tf.nn.softmax(self.outputs)


论文总结


本文提出了一个简单的GCN在文本分类应用中的有趣应用,并且确实显示了令人欣喜的结果。但是该模型确实具有局限性,因为它具有传导性(通常是GCN的局限性)。在训练过程中,模型将训练数据集中的每个单词和文档,包括测试集。尽管在训练过程中没有对测试集进行任何预测,但是该模型不能应用预测一个全新的文档。这导致了将来可能的工作,即如何将新文档合并到已经构建的图形中。总的来说,我认为本文显示了图神经网络的强大能力及其在我们可以定义和构建某种有用图结构的任何领域中的适用性。

版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。

相关文章
如何设置阿里云服务器安全组?阿里云安全组规则详细解说
阿里云安全组设置详细图文教程(收藏起来) 阿里云服务器安全组设置规则分享,阿里云服务器安全组如何放行端口设置教程。阿里云会要求客户设置安全组,如果不设置,阿里云会指定默认的安全组。那么,这个安全组是什么呢?顾名思义,就是为了服务器安全设置的。安全组其实就是一个虚拟的防火墙,可以让用户从端口、IP的维度来筛选对应服务器的访问者,从而形成一个云上的安全域。
18585 0
阿里云服务器如何登录?阿里云服务器的三种登录方法
购买阿里云ECS云服务器后如何登录?场景不同,阿里云优惠总结大概有三种登录方式: 登录到ECS云服务器控制台 在ECS云服务器控制台用户可以更改密码、更换系.
27731 0
阿里云服务器安全组设置内网互通的方法
虽然0.0.0.0/0使用非常方便,但是发现很多同学使用它来做内网互通,这是有安全风险的,实例有可能会在经典网络被内网IP访问到。下面介绍一下四种安全的内网互联设置方法。 购买前请先:领取阿里云幸运券,有很多优惠,可到下文中领取。
21936 0
阿里云服务器ECS登录用户名是什么?系统不同默认账号也不同
阿里云服务器Windows系统默认用户名administrator,Linux镜像服务器用户名root
15293 0
阿里云服务器端口号设置
阿里云服务器初级使用者可能面临的问题之一. 使用tomcat或者其他服务器软件设置端口号后,比如 一些不是默认的, mysql的 3306, mssql的1433,有时候打不开网页, 原因是没有在ecs安全组去设置这个端口号. 解决: 点击ecs下网络和安全下的安全组 在弹出的安全组中,如果没有就新建安全组,然后点击配置规则 最后如上图点击添加...或快速创建.   have fun!  将编程看作是一门艺术,而不单单是个技术。
19980 0
阿里云服务器怎么设置密码?怎么停机?怎么重启服务器?
如果在创建实例时没有设置密码,或者密码丢失,您可以在控制台上重新设置实例的登录密码。本文仅描述如何在 ECS 管理控制台上修改实例登录密码。
23524 0
腾讯云服务器 设置ngxin + fastdfs +tomcat 开机自启动
在tomcat中新建一个可以启动的 .sh 脚本文件 /usr/local/tomcat7/bin/ export JAVA_HOME=/usr/local/java/jdk7 export PATH=$JAVA_HOME/bin/:$PATH export CLASSPATH=.
14855 0
+关注
400
文章
0
问答
文章排行榜
最热
最新
相关电子书
更多
JS零基础入门教程(上册)
立即下载
性能优化方法论
立即下载
手把手学习日志服务SLS,云启实验室实战指南
立即下载