基于昇腾适配基因表达预测模型Geneformer

简介: Geneformer被广泛应用于疾病建模、治疗靶点发掘、基因网络预测与调控分析、基因功能预测与剂量敏感性分析、单细胞转录组数据集成与标准化、遗传变异解释与GWAS靶点优先排序。该案例既有算法原理,也有手把手的昇腾部署教学,包含细胞分类、基因分类、提取细胞嵌入图、细胞多分类的微调任务

摘要

Geneformer被广泛应用于疾病建模、治疗靶点发掘、基因网络预测与调控分析、基因功能预测与剂量敏感性分析、单细胞转录组数据集成与标准化、遗传变异解释与GWAS靶点优先排序。该案例既有算法原理,也有手把手的昇腾部署教学,包含细胞分类、基因分类、提取细胞嵌入图、细胞多分类的微调任务

1 Geneformer介绍

GeneFormer是一种基于 Transformer 架构的深度学习模型,专为基因表达数据分析而设计。它将基因视为“词汇”,将整个基因组的表达谱视为“句子”,通过自监督学习捕捉基因间的复杂调控关系和生物学背景,在医学研究中展现出强大的应用潜力。借助GeneFormer,研究人员能够更有效地处理和理解大量的基因组数据,从而加速新药开发、疾病治疗等领域的研究进展。在基因序列分析、蛋白质结构预测疾病机制解析和药物发现等领域也具有突出的应用价值。
图片
图1:自监督大规模预训练迁移学习策略示意图

初始自监督大规模预训练,将预训练权重复制到每个微调任务的模型中,添加微调层,并使用有限的特定任务数据对每个下游任务进行微调。通过对可泛化的学习目标进行单次初始自监督大规模预训练,模型获得学习领域的基础知识,然后将其推广到众多不同于预训练学习目标的下游应用中,将知识迁移到新任务。

2 网络结构

图片
图2:Geneformer预训练架构图

预训练的 Geneformer 架构每个单细胞转录组被编码为秩值编码,然后经过六层 Transformer 编码器单元,参数如下:输入大小为 2,048(完全代表 Geneformer-30M 中 93% 的秩值编码),嵌入维度为 256,每层四个注意力头,前馈大小为 512。Geneformer 在 2,048 的输入大小上使用完全密集的自注意力机制。可提取的输出包括上下文基因和细胞嵌入、上下文注意力权重以及上下文预测。

2.1输入层

输入层针对基因表达数据的特性在数据预处理、嵌入表示(Embedding)和位置编码(Positional Encoding)等进行了专门优化。

数据预处理:

  1. 基因嵌入:对基因表达值进行归一化处理,消除不同基因表达水平之间的差异,并对缺失值进行合理填充或插值处理,以确保数据的完整性。
  2. 输入数据:通常包括基因表达矩阵(如单细胞RNA测序数据)和基因序列(如DNA序列)。基因表达矩阵是一个二维矩阵,其中行代表样本,列代表基因,每个元素代表对应基因在该样本中的表达值。基因序列则是由碱基A、T、C、G组成的字符串序列。

嵌入层
将基因表达值或基因序列映射到高维向量空间,以捕捉基因间的复杂关系,便于后续模型处理序列结构。维度设置需要根据具体任务和计算资源进行权衡,过低的维度可能导致信息丢失,而过高会增加计算复杂度。此外,嵌入层通常通过反向传播进行训练,使模型能够自动学习最优的基因嵌入表示,从而更好地适应任务需求。

位置编码
用于提供基因序列中各碱基的位置信息,帮助模型理解基因序列中碱基的顺序关系和位置依赖性,对于分析基因序列的功能和结构至关重要。

2.2Transformer层

GeneFormer的核心由多个 Transformer 层堆叠而来。通过多头自注意力、残差连接和前馈神经网络,从高维基因表达数据中提取复杂的调控模式。在保持标准的Transformer 结构的同时,针对基因表达数据的特性(高维度、稀疏性、基因共表达模式)进行了优化,使模型能够有效捕捉基因间的功能关联,为下游任务(微调)提供强有力的表征。

多头注意力
并行使用多个注意力头,每个头学习不同的交互模式,同时计算多组注意力权重,捕捉基因间的全局依赖关系(如协同表达的基因网络)。通过计算查询(Query)、键(Key)和值(Value)之间的点积来确定权重,并通过 Softmax 函数进行归一化,且总和为1。
图片

图片
将输入拆分为h个头,每个头单独计算后拼接。

前馈神经网络
由两层全连接层和激活函数组成,每个多头注意力层后接一个前馈神经网络层,对注意力层的输出进行非线性映射增强非线性表达能力,用于学习并保存知识。
图片
稀疏注意力
基因表达数据中,大部分基因表达值为0,可能采用局部稀疏注意力以降低计算开销。
图片
相对位置编码
由于基因在序列中的物理位置可能无关紧要,Geneformer 采用相对位置编码,仅编码基因间的相对顺序或距离,增强对基因序列位置的敏感性。
图片
i, j为基因在序列中的位置,k为最大相对距离。

层归一化与残差连接
层归一化稳定单细胞数据的高变异表达分布,残差连接保留原始基因表达信息,缓解梯度消失,加速收敛。
图片
μ和σ分别为样本内均值和方差,γ和β分别为可学习的缩放和平移参数
图片

2.3输出层

经过transformer层之后,张量被传入输出层,但Geneformer输出层的设计根据具体任务(如基因表达预测、分类或自监督预训练)有所不同,主要操作通常包括以下几个关键步骤:

线性变换: 使用全连接层,将Transformer最后一层输出的隐藏状态映射到目标维度(如基因数量或类别数)。
图片
是transformer最后一层的输出向量, 是基因表达预测的权重矩阵, 为偏置项。

激活函数: 根据任务需求不同调整使用的激活函数,回归任务可能使用ReLU或Softplus确保输出非负,分类任务使用Softmax(多分类)或Sigmoid(二分类)输出概率分布,对于线性输出,则没有激活函数。
图片
损失计算: 对于回归任务,使用均方误差(MSE)或负对数似然。分类任务,交叉熵损失。自监督任务(掩码基因预测),使用对比损失或遮蔽语言建模(MLM)类似的损失。
图片
是被遮蔽基因的秩编码值(如基因表达量在细胞内的排序分位数)细胞分类任务中的损失计算,交叉熵损失
图片
基因扰动预测,对比损失
图片
其中, 是基因敲除后的细胞表达谱,通过对比学习强化扰动前后的表达差异。

3 微调介绍

GeneFormer 先在大规模单细胞数据上预训练,结合特定任务的需求和数据特点,灵活选择冻结策略、调整输出头、引入适配器或领域特定模块。通过平衡预训练知识的保留与任务适配,高效实现模型优化。

网络结构的微调操作:

  1. 根据具体的下游任务,确定输入输出格式。即指定数据集。在输入层将数据预处理为与 GeneFormer 兼容的格式,加载预训练的 GeneFormer 权重。
  2. 选择冻结一定数目的transformer层,但不会全部冻结,会保留几层用于保留预训练模型的底层知识(如基因共现模式、 基础序列特征),防止小数据过拟合。
  3. 在预训练模型的基础上额外增加一个transformer层,用于学习新的知识。并在每一层插入小型适配器模块,保持预训练权重冻结,仅训练适配器参数,用于减少参数更新量,适用于小样本微调。
  4. 在输出层,也会根据具体的下游任务进行调整,仅训练最后一层transformer层及输出头。对于分类任务:替换最后的全局平均池化层 + 全连接层。回归任务:调整输出层为线性回归头。生成任务:添加解码器。

4 实验准备

4.1设备&组件

机器:
Atlas 800T A2
组件:
hdk:24.1.rc3
图片
cann:8.0.RC3
图片
python:3.10.16
图片
torch:2.1.0torch:2.1.0.post8
图片

4.2安装LFS

git lfs install

4.3下载源码

git clone https://huggingface.co/ctheodoris/Geneformer

4.4下载数据集

git clone https://gitee.com/hf-datasets/Genecorpus-30M.git
图片

4.5安装环境

requirements.txt里面torch的版本>=2.0.1即可,这里选用2.1.0版本的torch。

cd Geneformer
vi requirements.txt # 将torch>=2.0.1修改为torch==2.1.0。再:wq保存退出。
pip install .

4.6安装torch-npu

4.6.1下载

wget https://gitee.com/ascend/pytorch/releases/download/v6.0.rc3-pytorch2.1.0/torch_npu-2.1.0.post8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl

4.6.2安装

pip3 install torch_npu-2.1.0.post8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl

4.6.3验证npu是否可用

numpy报错,需降低至1.x版本
图片
图片
更换完numpy版本之后,再次验证
图片

5 微调

5.1微调1:细胞分类

5.1.1数据集&权重任务任务:

cell_classification
数据集:human_dcm_hcm_nf.dataset
预训练权重:gf-6L-30M-i2048

5.1.2新建微调脚本

cd /Geneformer/examples
mkdir cell_classification.py
vi cell_classification.py

将cell_classification.ipynb的代码复制过来。需注意修改权重路径和数据集路径。导入os包,将第10行的!mkdir $output_dir修改为os.makedirs(output_dir, exist_ok=True)
图片

5.1.3修改评估模型脚本

vi /root/miniconda3/envs/Genecorpus_py310/lib/python3.10/site-packages/geneformer/
evaluation_utils.py导入torch_npu包,替换相关cuda的api
图片

5.1.4微调前source cann

source /usr/local/Ascend/ascend-toolkit/set_env.sh

5.1.5开始微调

图片
再开一个窗口,命令行输入 npu-smi info查看显存占用率
图片
图片

5.1.6评估模型时报错

图片

vi /root/miniconda3/envs/Genecorpus_py310/lib/python3.10/site-packages/geneformer/evaluation_utils.py
# 在第86行classifier_predict函数内添加
device = torch.device('npu' if torch_npu.npu.is_available() else 'cpu')
# 将119、120、121三行中的.to(“cuda”)修改为.to(device)

图片
再重新运行
图片
输出精度0.9542330129066371
图片
输出文件
图片
混淆矩阵
图片
评估微调模型的预测结果
图片

5.2微调2:基因分类

5.2.1数据集&权重文件

任务:gene_classification
数据集:gc-30M_sample50k.dataset
预训练权重:gf-6L-30M-i2048

5.2.2新建微调脚本

cd /Geneformer/examples
touch gene_classification.py
vi gene_classification.py

将gene_classification.ipynb的代码复制过来。需注意修改权重路径和数据集路径
图片

5.2.3开始微调

图片
图片

5.2.4输出文件

图片

5.3微调3:绘制细胞嵌入图

5.3.1数据集&权重文件

任务:extract_and_plot_cell_embeddings
数据集:human_dcm_hcm_nf.dataset
预训练权重:gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224

5.3.2新建微调脚本

cd /Geneformer/examples
touch extract_and_plot_cell_embeddings.py
vi extract_and_plot_cell_embeddings.py

将extract_and_plot_cell_embeddings.ipynb的代码复制过来。需注意修改权重路径和数据集路径
图片

5.3.3开始微调

图片

5.3.4输出文件

图片

5.3.5细胞嵌入UMAP图

图片

5.3.6细胞嵌入heapmap图

图片

5.4微调4:多任务细胞分类

5.4.1数据集&权重文件

任务:multitask_cell_classification
数据集:human_dcm_hcm_nf.dataset
预训练权重:gf-6L-30M-i2048 5.4.2新建微调脚本

cd /Geneformer/examples
touch multitask_cell_classification.py
vi multitask_cell_classification.py

将multitask_cell_classification.ipynb的代码复制过来。需注意修改权重路径、数据集路径以及token_dictionary路径。
图片

5.4.3微调过程

图片

5.4.4输出

图片

6 参考文献

Theodoris, C. V., Xiao, L., Chopra, A., Chaffin, M. D., Al Sayed, Z. R., Hill, M. C., ... & Ellinor, P. T. (2023). Transfer learning enables predictions in network biology. Nature, 618(7965), 616-624. https://doi.org/10.1038/s41586-023-06139-9

相关文章
|
移动开发 JavaScript 前端开发
四种方式解决页面国际化问题——步骤详解
四种方式解决页面国际化问题——步骤详解
459 0
|
存储 人工智能 Java
【图文详解】基于Spring AI的旅游大师应用开发、多轮对话、文件持久化、拦截器实现
【图文详解】基于Spring AI的旅游大师应用开发、多轮对话、文件持久化、拦截器实现
602 0
|
4月前
|
缓存 API Android开发
【HarmonyOS next】ArkUI-X新闻热搜聚合App【进阶】
本项目基于ArkUI-X框架,将鸿蒙(HarmonyOS)下的新闻热搜聚合App无缝迁移至iOS平台。采用ArkUI开发,结合@kit.NetworkKit实现网络请求,利用@ObservedV2与@Trace装饰器进行数据绑定,适配iOS界面布局与权限配置,完成跨平台热榜应用构建。
118 0
|
11月前
|
数据采集 自然语言处理 搜索推荐
基于qwen2.5的长文本解析、数据预测与趋势分析、代码生成能力赋能esg报告分析
Qwen2.5是一款强大的生成式预训练语言模型,擅长自然语言理解和生成,支持长文本解析、数据预测、代码生成等复杂任务。Qwen-Long作为其变体,专为长上下文场景优化,适用于大型文档处理、知识图谱构建等。Qwen2.5在ESG报告解析、多Agent协作、数学模型生成等方面表现出色,提供灵活且高效的解决方案。
1043 49
|
10月前
|
存储 人工智能 监控
【AI系统】推理系统架构
本文深入探讨了AI推理系统架构,特别是以NVIDIA Triton Inference Server为核心,涵盖推理、部署、服务化三大环节。Triton通过高性能、可扩展、多框架支持等特点,提供了一站式的模型服务解决方案。文章还介绍了模型预编排、推理引擎、返回与监控等功能,以及自定义Backend开发和模型生命周期管理的最佳实践,如金丝雀发布和回滚策略,旨在帮助构建高效、可靠的AI应用。
778 15
|
小程序 前端开发 API
🍁商城类小程序实战(一):需求分析及开发前准备
🍁商城类小程序实战(一):需求分析及开发前准备
1576 2
🍁商城类小程序实战(一):需求分析及开发前准备
|
12月前
|
机器学习/深度学习 编解码 算法
轻量级网络论文精度笔记(三):《Searching for MobileNetV3》
MobileNetV3是谷歌为移动设备优化的神经网络模型,通过神经架构搜索和新设计计算块提升效率和精度。它引入了h-swish激活函数和高效的分割解码器LR-ASPP,实现了移动端分类、检测和分割的最新SOTA成果。大模型在ImageNet分类上比MobileNetV2更准确,延迟降低20%;小模型准确度提升,延迟相当。
366 1
轻量级网络论文精度笔记(三):《Searching for MobileNetV3》
|
存储 固态存储 Linux
在Linux中,ext4文件系统有何特性?如何检查文件系统的完整性?
在Linux中,ext4文件系统有何特性?如何检查文件系统的完整性?
|
Linux API 开发工具
使用Pygame库进行2D游戏开发的优缺点有哪些?
【6月更文挑战第10天】使用Pygame库进行2D游戏开发的优缺点有哪些?
301 1