PGL图学习之基于GNN模型新冠疫苗任务[系列九]
项目链接:https://aistudio.baidu.com/aistudio/projectdetail/5123296?contributionType=1
# 加载一些需要用到的模块,设置随机数
import json
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
from utils.config import prepare_config, make_dir
from utils.logger import prepare_logger, log_to_file
from data_parser import GraphParser
seed = 123
np.random.seed(seed)
random.seed(seed)
数据EDA
# https://www.kaggle.com/c/stanford-covid-vaccine/data
# 加载训练用的数据
df = pd.read_json('../data/data179441/train.json', lines=True)
# 查看一下数据集的内容
sample = df.loc[0]
print(sample)
index 400
id id_2a7a4496f
sequence GGAAAGCCCGCGGCGCCGGGCGCCGCGGCCGCCCAGGCCGCCCGGC...
structure .....(((...)))((((((((((((((((((((.((((....)))...
predicted_loop_type EEEEESSSHHHSSSSSSSSSSSSSSSSSSSSSSSISSSSHHHHSSS...
signal_to_noise 0
SN_filter 0
seq_length 107
seq_scored 68
reactivity_error [146151.225, 146151.225, 146151.225, 146151.22...
deg_error_Mg_pH10 [104235.1742, 104235.1742, 104235.1742, 104235...
deg_error_pH10 [222620.9531, 222620.9531, 222620.9531, 222620...
deg_error_Mg_50C [171525.3217, 171525.3217, 171525.3217, 171525...
deg_error_50C [191738.0886, 191738.0886, 191738.0886, 191738...
reactivity [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_Mg_pH10 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_pH10 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_Mg_50C [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_50C [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
Name: 0, dtype: object
例如 deg_50C、deg_Mg_50C 这样的值全为0的行,就是我们需要预测的。
structure一行,数据中的括号是为了构成边用的。
本案例要预测RNA序列不同位置的降解速率,训练数据中提供了多个ground值,标签包括以下几项:reactivity, deg_Mg_pH10, and deg_Mg_50
reactivity - (1x68 vector 训练集,1x91测试集) 一个浮点数数组,与seq_scores有相同的长度,是前68个碱基的反应活性值,按顺序表示,用于确定RNA样本可能的二级结构。
deg_Mg_pH10 - (训练集 1x68向量,1x91测试集)一个浮点数数组,与seq_scores有相同的长度,是前68个碱基的反应活性值,按顺序表示,用于确定在高pH (pH 10)下的降解可能性。
deg_Mg_50 - (训练集 1x68向量,1x91测试集)一个浮点数数组,与seq_scores有相同的长度,是前68个碱基的反应活性值,按顺序表示,用于确定在高温(50摄氏度)下的降解可能性。
# 利用GraphParser构造图结构的数据
args = prepare_config("./config.yaml", isCreate=False, isSave=False)
parser = GraphParser(args) # GraphParser类来自data_parser.py
gdata = parser.parse(sample) # GraphParser里最主要的函数就是parse(self, sample)
{'nfeat': array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 1., 0., ..., 0., 0., 0.],
...,
[1., 0., 0., ..., 0., 0., 0.],
[1., 0., 0., ..., 0., 0., 0.],
[1., 0., 0., ..., 0., 0., 0.]], dtype=float32),
'edges': array([[ 0, 1],
[ 1, 0],
[ 1, 2],
...,
[142, 105],
[106, 142],
[142, 106]]),
'efeat': array([[ 0., 0., 0., 1., 1.],
[ 0., 0., 0., -1., 1.],
[ 0., 0., 0., 1., 1.],
...,
[ 0., 1., 0., 0., 0.],
[ 0., 1., 0., 0., 0.],
[ 0., 1., 0., 0., 0.]], dtype=float32),
'labels': array([[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
...,
[ 0. , 0.9213, 0. ],
[ 6.8894, 3.5097, 5.7754],
[ 0. , 1.8426, 6.0642],
...,
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ]], dtype=float32),
'mask': array([[ True],
[ True],
......
[False]])}
nfeat —— 节点特征
edges —— 边
efeat —— 边特征
labels —— 节点标签有三种,所以这可以看成是一个多分类任务
图数据可视化
# 图数据可视化
fig = plt.figure(figsize=(24, 12))
nx_G = nx.Graph()
nx_G.add_nodes_from([i for i in range(len(gdata['nfeat']))])
nx_G.add_edges_from(gdata['edges'])
node_color = ['g' for _ in range(sample['seq_length'])] + \
['y' for _ in range(len(gdata['nfeat']) - sample['seq_length'])]
options = {
"node_color": node_color,
}
pos = nx.spring_layout(nx_G, iterations=400, k=0.2)
nx.draw(nx_G, pos, **options)
plt.show()
模型训练&预测
# 我们在 layer.py 里定义了一个新的 gnn 模型(my_gnn),消息传递的过程中加入了边的特征(edge_feat)
# 然后修改 model.py 里的 GNNModel
# 使用修改后的模型,运行 main.py。为节省时间,设置 epochs = 100
!python main.py --config config.yaml
结果返回的是 MCRMSE 和 loss
{'MCRMSE': 0.5496759, 'loss': 0.3025484172316889}
[DEBUG] 2022-11-25 17:50:42,468 [ trainer.py: 66]: {'MCRMSE': 0.5496759, 'loss': 0.3025484172316889}
[DEBUG] 2022-11-25 17:50:42,468 [ trainer.py: 73]: write to tensorboard ../checkpoints/covid19/eval_history/eval
[DEBUG] 2022-11-25 17:50:42,469 [ trainer.py: 73]: write to tensorboard ../checkpoints/covid19/eval_history/eval
[INFO] 2022-11-25 17:50:42,469 [ trainer.py: 76]: [Eval:eval]:MCRMSE:0.5496758818626404 loss:0.3025484172316889
[INFO] 2022-11-25 17:50:42,602 [monitored_executor.py: 606]: ********** Stop Loop ************
[DEBUG] 2022-11-25 17:50:42,607 [monitored_executor.py: 199]: saving step 12500 to ../checkpoints/covid19/model_12500
!python main.py --mode infer