背景介绍
DGL (Deep Graph Learning) 和 PyG (Pytorch Geometric) 是两个主流的图神经网络库,它们在API设计和底层实现上有一定差异,在不同场景下,研究人员会使用不同的依赖库,昇腾NPU对PyG图机器学习库的支持亲和度更高,因此有些时候需要做DGL接口的PyG替换。
SE3Transformer在RFdiffusion蛋白质设计模型中(GitHub - RosettaCommons/RFdiffusion: Code for running RFdiffusion)作为核心组件,负责处理蛋白质结构的几何信息。其架构基于图神经网络,通过SE(3)等变性实现对三维旋转和平移的不变性特征提取。本系列以RFDiffusion模型中的SE3Transformer为例,讲解如何将DGL中的接口替换为PyG实现。
在本文中,主要展示消息传递接口的PyG替换。
消息传递接口
一、边-节点消息传递 (EdgeSoftmax + Aggregation)
位置:
rfdiffusion/modules/equivariant_attention/modules.py 中的 TransformerLayer
输入:
节点特征: x , 形状为(N, F)
边特征: edge_attr , 形状为(E, F')
图结构: graph
输出:
更新的节点特征: 形状为(N, F_out)
DGL函数:
dgl.nn.EdgeSoftmax:对边特征进行归一化
dgl.function.copy_edge:复制边特征
dgl.function.sum:聚合消息
数学逻辑:
- 计算注意力分数\( a_{ij}=\mathrm{softmax}j(e{ij}) \)
- 消息聚合:\( hi^{\prime}=\sum{j\in\mathcal{N}(i)}a_{ij}\cdot h_j \)
PyG实现:
def edge_softmax_aggregation(x, edge_index, edge_attr):
# 计算源节点和目标节点索引
src, dst = edge_index
# 计算边softmax
exp_edge_attr = torch.exp(edge_attr)
# 按目标节点归一化
node_degree = scatter_add(exp_edge_attr, dst, dim=0, dim_size=x.size(0)) norm = node_degree[dst].clamp(min=1e-6)
norm_edge_attr = exp_edge_attr / norm
# 消息传递
message = norm_edge_attr * x[src]
# 聚合
out = scatter_add(message, dst, dim=0, dim_size=x.size(0))
return out
二、矢量特征消息传递
位置:
rfdiffusion/modules/equivariant_attention/modules.py 中的 AttentionBlockSE3
输入:
标量特征: feat_scalar , 形状为(N, F_s)
矢量特征: feat_vector , 形状为(N, F_v, 3)
图结构: graph
输出:
更新的标量和矢量特征
DGL函数:
dgl.nn.EdgeSoftmax:边特征softmax
g.send_and_recv:消息传递与聚合
数学逻辑:
1.\( m{ij}=f\mathrm{att}(h_i^s,h_j^s,h_i^v,h_j^v) \)
2.矢量特征旋转\( hj^v\cdot R{ij} \)
PyG实现关键点:
需要自定义消息传递函数实现等变性旋转操作处理批处理边索引