图神经网络15-Text-Level-GNN:基于文本级GNN的文本分类模型

简介: 图神经网络15-Text-Level-GNN:基于文本级GNN的文本分类模型

论文题目:Text Level Graph Neural Network for Text Classification

论文地址:https://arxiv.org/pdf/1910.02356.pdf

论文代码:https://github.com/yenhao/text-level-gnn

发表时间:2019


论文简介与动机


1)TextGCN为整个数据集/语料库构建一个异构图(包括(待分类)文档节点和单词节点),边的权重是固定的(单词节点间的边权重是两个单词的PMI,文档-单词节点间的边权重是TF-IDF),固定权重限制了边的表达能力,而且为了获取一个全局表示不得不使用一个非常大的连接窗口。因此,构建的图非常大,而且边非常多,模型由很大的内存消耗。

2)上篇博客也提到了,TextGCN这种类型的模型,无法为新样本(文本)进行分类(在线测试),因为图的架构和参数依赖于语料库/数据集,训练结束后就不能再修改了。(除非将新文本加入到语料库中,更新图的结构,重新训练......一般不会这样做,总之该类模型不能为新文本进行分类)


本篇论文提出了一个新的基于GNN的模型来做文本分类,解决了上述两个问题:

1)为每个输入文本/数据(text-level)都单独构建一个图,文本中的单词作为节点;而不是给整个语料库/数据集(corpus-level)构建一个大图(每个文本和单词作为节点)。在每个文本中,使用一个非常小的滑动窗口,文本中的每个单词只与其左右的p个词有边相连(包括自己,自连接),而不是所有单词节点全连接。

2)相同单词节点的表示以及相同单词对之间边的权重全局(数据集/语料库中的所有文本/数据)共享,通过文本级别图的消息传播机制进行更新。

这样就可以消除单个输入文本和整个语料库/数据集的依赖负担,支持在线测试(新文本测试);而且上下文窗口更小,边数更少,内存消耗更小。


Text-Level-GNN模型


构建文本图


对于给定的一个包含l个词的文本记为,其中代表文本中第个单词的表示,初始化一个全局共享的词嵌入矩阵(使用预训练词向量初始化),每个单词/节点的初始表示从该嵌入矩阵中查询,嵌入矩阵作为模型参数在训练过程中更新。

为每个输入文本/数据构建一个图,把文本中的单词看作是节点,每个单词和它左右相邻的个单词有边相连(包括自己,自连接)。输入文本的图表示为:

其中N和E是文本图的节点集和边集,每个单词节点的表示,以及单词节点间边的权重分别来自两个全局共享矩阵(模型参数,训练过程中更新)。此外,对于训练集中出现次数少于k(k=2)次的边(词对)均匀地映射到一个"公共边",使得参数充分学习。


49.png


如上图所示:一个文本Text Level Graph为一个单独的文本“he is very proud of you.”。为了显示方便,在这个图中,为节点“very”(节点和边用红色表示)设置,为其他节点(用蓝色表示)设置。在实际情况下,会话期间的值是唯一的。图中的所有参数都来自图底部显示的全局共享表示矩阵。

与以往构建图的方法相比,该方法可以极大地减少图的节点和边的规模。这意味着文本级图形可以消耗更少的GPU内存。


消息传递机制


卷积可以从局部特征中提取信息。在图域中,卷积是通过频谱方法或非频谱方法实现的。在本文中,一种称为消息传递机制(MPM)的非频谱方法被用于卷积。MPM首先从相邻节点收集信息,并根据其原始表示形式和所收集的信息来更新其表示形式,其定义为:

50.png


其中是节点从其邻居接收到的消息;是一种归约函数,它将每个维上的最大值组合起来以形成一个新的向量作为输出。代表原始文本中的最近个单词的节点;是从节点到节点的边缘权重,它可以训练时更新;代表节点n先前的表示向量。节点n的可训练的变量,指示应该保留多少的信息。代表节点更新后的表示。

MPM使节点的表示受到邻域的影响,这意味着表示可以从上下文中获取信息。因此,即使对于一词多义,上下文中的精确含义也可以通过来自邻居的加权信息的影响来确定。此外,文本级图的参数取自全局共享矩阵,这意味着表示形式也可以像其他基于图的模型一样带来全局信息。

最后,使用文本中所有节点的表示来预测文本的标签:

其中是将向量映射到输出空间的矩阵,是文本的节点集,是偏差。

训练的目的是最小化真实标签和预测标签之间的交叉熵损失:

,其中是真实标签的one-hot向量表示。


实验结果


不同模型的对比实验


数据集采用了R8,R52和Ohsumed。R8和R52都是路透社21578数据集的子集。


51.png


p值影响


52.png


消融实验


(1)取消边之间的权重,性能变差,说明为边设置权重较好。

(2)mean取代max

(3)去掉预训练词嵌入


53.png


核心代码


获取邻居词:https://github.com/yenhao/text-level-gnn/blob/master/utils.py

def get_word_neighbors_mp(text_tokens:list, neighbor_distance:int) :
    print("\tGet word's neighbors")
    with mp.Pool(mp.cpu_count()) as p:
        return p.starmap(get_word_neighbor, map(lambda tokens: (tokens, neighbor_distance), text_tokens))
def get_word_neighbor(text_tokens: list, neighbor_distance: int) :
    """Get word token's adjacency neighbors with distance : neighbor_distance
    Args:
        text_tokens (list): A list of the tokens of sentences/texts from dataset.
        neighbor_distance (int): The adjacency distance to consider as a neighbor.
    Returns:
        list: A nested list with 2 dimensions, which is a list of neighbor word tokens (2nd dim) for all tokens (1nd dim)
    """
    text_len = len(text_tokens)
    edge_neighbors = []
    for w_idx in range(text_len):
        skip_neighbors = []
        # check before
        for sk_i in range(neighbor_distance):
            before_idx = w_idx -1 - sk_i
            skip_neighbors.append(text_tokens[before_idx] if before_idx > -1 else 0)
        # check after
        for sk_i in range(neighbor_distance):
            after_idx = w_idx +1 +sk_i
            skip_neighbors.append(text_tokens[after_idx] if after_idx < text_len else 0)
        edge_neighbors.append(skip_neighbors)
    return edge_neighbors


TextLevelGNN层:https://github.com/yenhao/text-level-gnn/blob/master/model.py

class TextLevelGNN(nn.Module):
    def __init__(self, num_nodes, node_feature_dim, class_num, embeddings=0, embedding_fix=False):
        super(TextLevelGNN, self).__init__()
        if type(embeddings) != int:
            print("\tConstruct pretrained embeddings")
            self.node_embedding = nn.Embedding.from_pretrained(embeddings, freeze=embedding_fix, padding_idx=0)
        else:
            self.node_embedding = nn.Embedding(num_nodes, node_feature_dim, padding_idx = 0)
        # self.edge_weights = nn.Embedding((num_nodes-1) * (num_nodes-1) + 1, 1, padding_idx=0) # +1 is padding
        self.edge_weights = nn.Embedding(num_nodes * num_nodes, 1) # +1 is padding
        self.node_weights = nn.Embedding(num_nodes, 1, padding_idx=0) # Nn, node weight for itself
        self.fc = nn.Sequential(
            nn.Linear(node_feature_dim, class_num, bias=True),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Softmax(dim=1)
        )
    def forward(self, X, NX, EW):
        """
        INPUT:
        -------
        X  [tensor](batch, sentence_maxlen)               : Nodes of a sentence
        NX [tensor](batch, sentence_maxlen, neighbor_distance*2): Neighbor nodes of each nodes in X
        EW [tensor](batch, sentence_maxlen, neighbor_distance*2): Neighbor weights of each nodes in X
        OUTPUT:
        -------
        y  [list] : Predicted Probabilities of each classes
        """
        ## Neighbor
        # Neighbor Messages (Mn)
        Mn = self.node_embedding(NX) # (BATCH, SEQ_LEN, NEIGHBOR_SIZE, EMBED_DIM)
        # EDGE WEIGHTS
        En = self.edge_weights(EW) # (BATCH, SEQ_LEN, NEIGHBOR_SIZE )
        # get representation of Neighbors
        Mn = torch.sum(En * Mn, dim=2) # (BATCH, SEQ_LEN, EMBED_DIM)
        # Self Features (Rn)
        Rn = self.node_embedding(X) # (BATCH, SEQ_LEN, EMBED_DIM)
        ## Aggregate information from neighbor
        # get self node weight (Nn)
        Nn = self.node_weights(X)
        Rn = (1 - Nn) * Mn + Nn * Rn
        # Aggragate node features for sentence
        X = Rn.sum(dim=1)
        y = self.fc(X)
        return y


结论


本文提出了一个新的基于图的文本分类模型,该模型使用文本级图而不是整个语料库的单个图。实验结果表明,我们的模型达到了最先进的性能,并且在内存消耗方面具有显着优势。

相关文章
|
4天前
|
网络协议 安全 网络安全
探索网络模型与协议:从OSI到HTTPs的原理解析
OSI七层网络模型和TCP/IP四层模型是理解和设计计算机网络的框架。OSI模型包括物理层、数据链路层、网络层、传输层、会话层、表示层和应用层,而TCP/IP模型则简化为链路层、网络层、传输层和 HTTPS协议基于HTTP并通过TLS/SSL加密数据,确保安全传输。其连接过程涉及TCP三次握手、SSL证书验证、对称密钥交换等步骤,以保障通信的安全性和完整性。数字信封技术使用非对称加密和数字证书确保数据的机密性和身份认证。 浏览器通过Https访问网站的过程包括输入网址、DNS解析、建立TCP连接、发送HTTPS请求、接收响应、验证证书和解析网页内容等步骤,确保用户与服务器之间的安全通信。
30 1
|
9天前
|
监控 安全 BI
什么是零信任模型?如何实施以保证网络安全?
随着数字化转型,网络边界不断变化,组织需采用新的安全方法。零信任基于“永不信任,永远验证”原则,强调无论内外部,任何用户、设备或网络都不可信任。该模型包括微分段、多因素身份验证、单点登录、最小特权原则、持续监控和审核用户活动、监控设备等核心准则,以实现强大的网络安全态势。
|
28天前
|
机器学习/深度学习 自然语言处理 数据可视化
【由浅到深】从神经网络原理、Transformer模型演进、到代码工程实现
阅读这个文章可能的收获:理解AI、看懂模型和代码、能够自己搭建模型用于实际任务。
108 11
|
2月前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于BP神经网络的苦瓜生长含水量预测模型matlab仿真
本项目展示了基于BP神经网络的苦瓜生长含水量预测模型,通过温度(T)、风速(v)、模型厚度(h)等输入特征,预测苦瓜的含水量。采用Matlab2022a开发,核心代码附带中文注释及操作视频。模型利用BP神经网络的非线性映射能力,对试验数据进行训练,实现对未知样本含水量变化规律的预测,为干燥过程的理论研究提供支持。
|
1月前
|
存储 网络协议 安全
30 道初级网络工程师面试题,涵盖 OSI 模型、TCP/IP 协议栈、IP 地址、子网掩码、VLAN、STP、DHCP、DNS、防火墙、NAT、VPN 等基础知识和技术,帮助小白们充分准备面试,顺利踏入职场
本文精选了 30 道初级网络工程师面试题,涵盖 OSI 模型、TCP/IP 协议栈、IP 地址、子网掩码、VLAN、STP、DHCP、DNS、防火墙、NAT、VPN 等基础知识和技术,帮助小白们充分准备面试,顺利踏入职场。
89 2
|
1月前
|
运维 网络协议 算法
7 层 OSI 参考模型:详解网络通信的层次结构
7 层 OSI 参考模型:详解网络通信的层次结构
186 1
|
2月前
|
网络协议 前端开发 Java
网络协议与IO模型
网络协议与IO模型
144 4
网络协议与IO模型
|
2月前
|
机器学习/深度学习 网络架构 计算机视觉
目标检测笔记(一):不同模型的网络架构介绍和代码
这篇文章介绍了ShuffleNetV2网络架构及其代码实现,包括模型结构、代码细节和不同版本的模型。ShuffleNetV2是一个高效的卷积神经网络,适用于深度学习中的目标检测任务。
109 1
目标检测笔记(一):不同模型的网络架构介绍和代码
|
1月前
|
网络协议 算法 网络性能优化
计算机网络常见面试题(一):TCP/IP五层模型、TCP三次握手、四次挥手,TCP传输可靠性保障、ARQ协议
计算机网络常见面试题(一):TCP/IP五层模型、应用层常见的协议、TCP与UDP的区别,TCP三次握手、四次挥手,TCP传输可靠性保障、ARQ协议、ARP协议
|
1月前
|
机器学习/深度学习 人工智能 算法
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
车辆车型识别,使用Python作为主要编程语言,通过收集多种车辆车型图像数据集,然后基于TensorFlow搭建卷积网络算法模型,并对数据集进行训练,最后得到一个识别精度较高的模型文件。再基于Django搭建web网页端操作界面,实现用户上传一张车辆图片识别其类型。
87 0
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型

热门文章

最新文章