Comparing Multiple Machine Learning Models for DEL-Encoded Drug Prediction

Introduction: This article compares several machine learning models for predicting new drug candidates from DNA-encoded library (DEL) data. Each molecule in the dataset is assembled from three building blocks, and the task is to predict whether the assembled molecule binds a protein target (binds = 1) or not (binds = 0). The sections below describe the dataset format, the supporting libraries, and the full training and evaluation workflow.

Dataset Description

Each molecule in the dataset is assembled from three building blocks. The dataset records whether a molecule binds to a protein target: binds is labeled 1 if it binds and 0 otherwise.

The format is as follows (an illustrative record appears after the list):

  • id - a unique example_id used to identify the molecule-binding-target pair.
  • buildingblock1_smiles - the structure of the first building block, in SMILES.
  • buildingblock2_smiles - the structure of the second building block, in SMILES.
  • buildingblock3_smiles - the structure of the third building block, in SMILES.
  • molecule_smiles - the structure of the fully assembled molecule, in SMILES. This includes the three building blocks and the triazine core. Note that [Dy] is used as a stand-in for the DNA linker.
  • protein_name - the protein target name.
  • binds - the target column: a binary label indicating whether the molecule binds the protein. Not available for the test set.
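For illustration, a single training record might look like the following; the SMILES strings and protein name are hypothetical placeholders, not rows from the actual dataset:

# Hypothetical training record (all values are illustrative placeholders)
example_record = {
    "id": 1,
    "buildingblock1_smiles": "CC(=O)Nc1ccc(N)cc1",
    "buildingblock2_smiles": "OC(=O)c1ccccc1",
    "buildingblock3_smiles": "NCCO",
    "molecule_smiles": "CC(=O)Nc1ccc(NC(=O)c2ccccc2)cc1NCCO[Dy]",  # assembled molecule; [Dy] marks the DNA linker
    "protein_name": "HSA",
    "binds": 0,
}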

Libraries Used

  • rdkit - an open-source cheminformatics toolkit offering rich functionality for drug design, bioactivity prediction, chemical reaction prediction, and chemical data processing. Here it is used mainly to compute molecular fingerprints.
  • duckdb - an open-source embedded analytical database management system designed for data analysis and online analytical processing (OLAP). Here it is used for columnar analysis of the data files.
  • pysmiles - a library for working with molecular representations in SMILES format; a short usage sketch follows this list.
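As a quick illustration of pysmiles (a minimal sketch; the node attributes available depend on the pysmiles version), a SMILES string can be parsed into a networkx graph of atoms and bonds:

import pysmiles

# Parse a SMILES string into a networkx graph whose nodes are atoms
mol_graph = pysmiles.read_smiles('c1ccccc1O')  # phenol
print(mol_graph.nodes(data='element'))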

Algorithm Walkthrough

!pip install duckdb
!pip install pysmiles
!pip install rdkit

Data Loading

# Standard library imports
import re
import os
import unicodedata
import itertools

# Data handling libraries
import pandas as pd
import numpy as np

# Embedded analytical database
import duckdb

# Visualization and cheminformatics libraries
import pysmiles
import plotly
import plotly.express as px
import seaborn as sns
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from rdkit import RDLogger
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem.Draw import rdMolDraw2D

# Render molecule drawings as SVG in the notebook
from IPython.display import SVG
IPythonConsole.ipython_useSVG = True

# Configure the plotting style
sns.set_theme(style='whitegrid')
palette='viridis'

# Paths to the Parquet data files
data_train = '/input/train.parquet'
test_path = '/input/test.parquet'

# Connect to an in-memory DuckDB database
con = duckdb.connect()

# Sample 30,000 non-binding and 30,000 binding rows in random order, returned as a DataFrame
data = con.query(f"""(SELECT * FROM parquet_scan('{data_train}') 
WHERE binds = 0
ORDER BY random()
LIMIT 30000)
UNION ALL
(SELECT * FROM parquet_scan('{data_train}')
WHERE binds = 1
ORDER BY random() 
LIMIT 30000)""").df()

# Close the database connection
con.close()

# Save the sampled data to a CSV file
data.to_csv('/working/dataset.csv')
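A quick sanity check (a minimal sketch) confirms that the sample is balanced between the two classes:

# Verify the 30,000/30,000 class balance of the sampled data
print(data['binds'].value_counts())
print(data.shape)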

Data Preprocessing

In the preprocessing stage we perform several basic steps to prepare the data for analysis. First, data cleaning removes duplicates and handles missing values (a minimal sketch follows). Then, depending on the nature of the data, categorical variables are converted to numeric form with an appropriate encoding, such as one-hot or label encoding. In addition, numeric variables are normalized or standardized so they share a common scale, which is essential for many machine learning algorithms. These preprocessing steps put the data in a format suitable for the analysis models and improve the accuracy and efficiency of the subsequent analysis.
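A minimal sketch of the cleaning step described above; which columns to deduplicate on is a judgment call:

# Drop duplicate molecule/target pairs and rows with missing SMILES strings
data = data.drop_duplicates(subset=['molecule_smiles', 'protein_name'])
data = data.dropna(subset=['molecule_smiles'])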

# Use RDKit to convert each SMILES string into an RDKit molecule object
data['molecule'] = data['molecule_smiles'].apply(Chem.MolFromSmiles)

# Function that computes a Morgan (ECFP) fingerprint bit vector for a molecule
def generate_ecfp(molecule_data, radius=2, bits=1024):
    if molecule_data is None:
        return None
    return list(AllChem.GetMorganFingerprintAsBitVect(molecule_data, radius, nBits=bits))

# Generate a fingerprint for each molecule object
data['H1_ecfp'] = data['molecule'].apply(generate_ecfp)
from sklearn.preprocessing import OneHotEncoder

# One-hot encode the protein target names
encoder_onehot = OneHotEncoder(sparse_output=False)
encoder_onehot_fit = encoder_onehot.fit_transform(data['protein_name'].values.reshape(-1, 1))

# Concatenate each molecule's fingerprint with its protein one-hot vector to form a single feature vector
X = [ecfp + protein for ecfp, protein in zip(data['H1_ecfp'].tolist(), encoder_onehot_fit.tolist())]
y = data['binds'].tolist()
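Since scikit-learn repeatedly copies Python lists of lists, it can help to convert the combined features to a compact NumPy array before splitting (a minimal, optional sketch):

# Convert to compact NumPy arrays; fingerprint bits and one-hot flags are all 0/1
X = np.asarray(X, dtype=np.uint8)
y = np.asarray(y)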

Here we have separated the data into the feature matrix, built from "H1_ecfp" and the protein one-hot encoding, and the target variable "binds". Keeping features on a common scale is important for many machine learning algorithms, especially distance-based ones such as k-nearest neighbors (KNN) and clustering methods; in this case the fingerprint bits and one-hot flags are already uniform 0/1 values. A clean feature/target separation also makes the relationship between the fingerprints and "binds" easier to interpret and can surface hidden trends or patterns in the data that matter for predictive modeling. With appropriate normalization in place, the stability and accuracy of the models improve, and all variables contribute appropriately to the learning process.
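If the feature matrix did mix scales, a standard normalization step for distance-based models such as KNN might look like the following sketch; with the all-binary features used here it is optional:

from sklearn.preprocessing import StandardScaler

# Optional: put features on a common scale for distance-based models such as KNN
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)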

Model Training

# Split the features and labels into training and test sets (80/20)
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Progress bar library
from tqdm import tqdm

# Machine learning models
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import AdaBoostClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.naive_bayes import GaussianNB
from xgboost import XGBClassifier, plot_importance as plot_importance_xgb
from lightgbm import LGBMClassifier, plot_importance as plot_importance_lgbm


# Metrics and model evaluation
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score
from sklearn.metrics import roc_curve, auc, confusion_matrix, accuracy_score, classification_report

# Machine learning models to compare
models = {
  # Logistic regression
  "Logistic Regression": LogisticRegression(),

  # Naive Bayes
  "Naive Bayes": GaussianNB(),

  # k-nearest neighbors
  "KNN": KNeighborsClassifier(),

  # AdaBoost (builds a strong classifier by iterating weak classifiers)
  "Ada Boost": AdaBoostClassifier(),

  # Gradient boosting (improves accuracy by iteratively training decision trees)
  "Gradient Boosting Classifier": GradientBoostingClassifier(),

  # Decision tree
  "Decision Tree Classifier": DecisionTreeClassifier(max_depth=5,
  min_samples_split=2,
  random_state=105),

  # XGBoost (an optimized distributed gradient boosting library)
  "XGBoost": XGBClassifier(n_estimators=100,
  max_depth=250,
  learning_rate=0.1,
  subsample=0.8,
  colsample_bytree=0.8,
  random_state=42,
  tree_method='gpu_hist'),

  # LightGBM (a distributed gradient boosting framework based on decision tree algorithms)
  "LGBM": LGBMClassifier(boosting_type='gbdt',
  bagging_freq=5,
  verbose=0,
  device='gpu',
  num_leaves=31,
  max_depth=250,
  learning_rate=0.1,
  n_estimators=100)
}

# Train each model
for name, model in tqdm(models.items(), desc="Training models", total=len(models)):
  # Fit the model on the training set
  model.fit(X_train, y_train)

  # Estimate accuracy with 10-fold cross-validation on the training set
  score_training = cross_val_score(model, X_train, y_train, cv=10)

  # Predict on the held-out test set
  pred_model = model.predict(X_test)

  # Report progress and the cross-validation accuracy
  tqdm.write("Model: {} has Accuracy {:.2f}%".format(model.__class__.__name__, round(score_training.mean(), 2) * 100))

  print()
Training models:  12%|█▎        | 1/8 [01:40<11:44, 100.63s/it]
Model: LogisticRegression has Accuracy 87.00%

Training models:  25%|██▌       | 2/8 [03:00<08:50, 88.36s/it] 
Model: GaussianNB has Accuracy 74.00%

Training models:  38%|███▊      | 3/8 [05:12<09:02, 108.45s/it]
Model: KNeighborsClassifier has Accuracy 80.00%

Training models:  50%|█████     | 4/8 [13:33<17:33, 263.47s/it]
Model: AdaBoostClassifier has Accuracy 79.00%

Training models:  62%|██████▎   | 5/8 [43:18<40:36, 812.11s/it]
Model: GradientBoostingClassifier has Accuracy 84.00%

Training models:  75%|███████▌  | 6/8 [44:35<18:44, 562.10s/it]
Model: DecisionTreeClassifier has Accuracy 75.00%

Training models:  88%|████████▊ | 7/8 [52:01<08:44, 524.23s/it]
Model: XGBClassifier has Accuracy 91.00%

[LightGBM] [Warning] bagging_freq is set=5, subsample_freq=0 will be ignored. Current value: bagging_freq=5
[LightGBM] [Warning] Accuracy may be bad since you didn't explicitly set num_leaves OR 2^max_depth > num_leaves. (num_leaves=31).
Training models: 100%|██████████| 8/8 [53:43<00:00, 402.98s/it]
Model: LGBMClassifier has Accuracy 89.00%

CPU times: user 58min 38s, sys: 1min 6s, total: 59min 44s
Wall time: 53min 47s

Eight machine learning models were built for this project: logistic regression, naive Bayes, k-nearest neighbors (KNN), decision tree, AdaBoost, gradient boosting, XGBoost, and LightGBM. Each model was trained and evaluated on the same dataset to determine the best performer. In this evaluation XGBoost achieved the highest cross-validation accuracy at 91%, followed closely by LightGBM at 89%; both also performed well on other metrics such as precision, recall, and F1-score (a sketch for computing these appears below), indicating consistency and reliability across conditions. The performance of each model on different subsets of the data was also examined to check generalization and guard against overfitting, and on this analysis the two gradient-boosted tree models stood out not only in accuracy but also in generalization ability and stability. Taking all evaluation criteria together, the gradient-boosted approaches, XGBoost in particular, are the most recommended choice for future work in this setting.
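As referenced above, the classification_report and confusion_matrix imports can produce the per-class precision, recall, and F1 figures; a minimal sketch for the strongest model:

# Detailed test-set metrics for one of the trained models
best_model = models["XGBoost"]
y_pred = best_model.predict(X_test)
print(classification_report(y_test, y_pred))
print(confusion_matrix(y_test, y_pred))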

ROC Curves and AUC

for name, model in models.items():
  # Fit the model (refitting here keeps this cell self-contained)
  model.fit(X_train, y_train)

  # Predict on the test set
  y_pred = model.predict(X_test)
  print("Machine Learning Model:", name)

  # ROC curve: compare using the predicted probability of the positive class
  fpr, tpr, _ = roc_curve(y_test, model.predict_proba(X_test)[:, 1])
  roc_auc = auc(fpr, tpr)

  plt.figure()
  plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
  plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
  plt.xlim([0.0, 1.0])
  plt.ylim([0.0, 1.05])
  plt.xlabel('False Positive Rate')
  plt.ylabel('True Positive Rate')
  plt.title('Receiver Operating Characteristic - {}'.format(name))
  plt.legend(loc="lower right")
  plt.grid()
  plt.show()
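To compare the models directly, the individual curves can also be overlaid on a single figure (a sketch that reuses the already-fitted models):

# Overlay all ROC curves on one figure for side-by-side comparison
plt.figure(figsize=(8, 6))
for name, model in models.items():
    fpr, tpr, _ = roc_curve(y_test, model.predict_proba(X_test)[:, 1])
    plt.plot(fpr, tpr, lw=2, label='{} (AUC = {:.2f})'.format(name, auc(fpr, tpr)))
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Comparison Across Models')
plt.legend(loc='lower right')
plt.show()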