如何选择最佳的最近邻算法

简介: 如何选择最佳的最近邻算法

介绍一种通过数据驱动的方法,在自定义数据集上选择最快,最准确的ANN算法

640.png

人工神经网络背景

KNN是我们最常见的聚类算法,但是因为神经网络技术的发展出现了很多神经网络架构的聚类算法,例如 一种称为HNSW的ANN算法与sklearn的KNN相比,具有380倍的速度,同时提供了99.3%的相同结果。

为了测试更多的算法,我们整理了几种ANN算法,例如

  • Spotify’s ANNOY
  • Google’s ScaNN
  • Facebook’s Faiss
  • HNSW(Hierarchical Navigable Small World graphs)
  • 一些其他算法

作为数据科学家,我我们这里将制定一个数据驱动型决策来决定那种算法适合我们的数据。在本文中,我将演示一种数据驱动的方法,通过使用出色的an-benchmarks GitHub存储库,确定哪种ANN算法是自定义数据集的最佳选择。

640.png

下图是通过使用距离度量在glove-100 数据集上运行ANN基准而得到的图形。在此数据集上,scann算法在任何给定的Recall中具有最高的每秒查询数,因此在该数据集上具有最佳的算法。

640.png

总流程

这些是在自定义数据集上运行ann-benchmarks代码的步骤。

  • 在python 3.6环境中安装ann-benchmarks。
  • 将自定义嵌入数据框架上传到anbenchmarks / data目录。
  • 更新ann-benchmarks / ann-benchmarks / dataset.py,以读取并拆分新的自定义DataFrame。
  • 运行基准测试代码。
  • 绘制结果

1.在python 3.6环境中安装ann-benchmarks

此步骤的代码需要在终端中执行。我在使用anaconda进行环境设置。这将需要几分钟才能完成。您可以使用proc参数增加并发进程的数量,从而加快速度。我仅在安装完成后才升级pandas和scipy。

在撰写本文时,Ann基准仅支持Python 3.6。

condacreate-nannpython=3.6jupyterlab-ycondaactivateanngitclonehttps://github.com/erikbern/ann-benchmarks.gitcdann-benchmarks/pipinstall-rrequirements.txtpythoninstall.py--proc=8pipinstall--upgradepandasscipymkdirdata

可能出现的问题:

  • 未安装gcc:使用sudo apt-get install gcc安装GCC。
  • 权限问题:如果在运行python install.py时遇到任何权限问题,只需使用sudo / opt / conda / envs / ann / bin / python install.py即可运行它。使用sudo时,请记住在您的环境中提供anaconda python的完整路径。

2.上传自定义DataFrame

在此步骤中,将自定义数据框架文件复制到ann-benchmarks / data目录中。对于这篇文章,我的DataFrame与使用的带有FastText句子嵌入的[Amazon产品数据集]。但是,我只是随机抽样5万行,以确保基准测试能够在合理的时间内运行。以下是将嵌入数据框保存为正确目录中名为custom-euclidean.pkl的文件的代码,也是该数据框前5行的摘录。

df.to_pickle('ann-benchmarks/data/custom-euclidean.pkl')
df.head()

640.png

3.更新datasets.py以处理您的自定义DataFrame

我们需要更新ANN基准代码,编写我们的新的DataFrame处理代码。我们在ann-benchmarks / ann-benchmarks / datasets.py文件的末尾添加了一个新的function和dictionary元素。距离参数的允许选项是“euclidean”,“angular”,“hamming”或“jaccard”。距离度量的选择特定于您的问题。就我而言,我发现“欧式距离”提供了最好的结果

defcustom_dataset(out_fn, test_ratio, distance):
#Functiontohandleourcustomdatasetimportpandasaspd#ReadtheDataFrame#out_fnisoftheform'data/<dataset-name>.hdf5'df=pd.read_pickle(out_fn.split('.')[0]+'.pkl')
#ConvertsingleembeddingcolumntonumpylistoflistsX=pd.DataFrame(df['emb'].tolist()).to_numpy()    
#SplitTrainandTestX_train, X_test=train_test_split(X, test_size=test_ratio)
#WriteHDF5Outputwrite_output(X_train, X_test, out_fn, distance)
#Createanewdictionaryelementtocallournewfunction#20%ofrowsusedasTestSet#EuclideandistanceusedasmeasureforfindingneighborsDATASETS['custom-euclidean'] =lambdaout_fn: custom_dataset(out_fn, test_ratio=0.2, distance='euclidean')

4.运行基准测试代码

如果到目前为止一切顺利,我们现在可以通过从终端调用以下代码来运行基准测试。将并行性的值更改为要使用的尽可能多的CPU内核。我使用的是16核CPU,因此我选择parallelism = 14来为其他任务保留2核。这将需要一些时间才能完成。我的具有20%测试集的5万数据行了大约7个小时。

pythonrun.py--dataset='custom-euclidean'--parallelism=14

5.绘制结果

运行完成后,我们可以通过运行plot.py绘制结果。我们还可以使y轴以对数比例绘制。请注意,我在使用sudo时使用了Anaconda Python的完整路径,因为在尝试正常运行plot.py时遇到权限问题:python plot.py --dataset = custom-euclidean --y-log。您可以使用任何适合您的方法。

sudo/opt/conda/envs/ann/bin/pythonplot.py--dataset=custom-euclidean--y-log

结果图将作为png文件保存在结果目录中。对于我在本文中使用的5万行Amazon数据集,结果如下。

640.png

从该图中可以看出,通过在任意给定的Recall上每秒提供更高的查询,诸如NGT-onng,hnsw(nmslib),n2,hnswlib,SW-graph(nmslib)之类的算法明显优于其余算法。因此,我们可以在亚马逊产品数据集上为我们的项目进一步探索这些算法。

总结

总之,通过使用ann-benchmarks,并编写一些自定义的代码,我们可以 在自己的自定义数据集上测试大量的ANN算法,以缩小筛选范围,以进一步探索。这篇文章的所有代码都可以在我的Github存储库中找到。感谢您的阅读!

代码地址:https://github.com/stephenleo/adventures-with-ann/blob/main/ann_benchmarking.ipynb

目录
相关文章
|
存储 运维 算法
UUID和雪花(Snowflake)算法该如何选择?
UUID和雪花(Snowflake)算法该如何选择?
344 0
|
机器学习/深度学习 数据采集 算法
Py之scikit-learn:机器学习Sklearn库的简介、安装、使用方法(ML算法如何选择)、代码实现之详细攻略
Py之scikit-learn:机器学习Sklearn库的简介、安装、使用方法(ML算法如何选择)、代码实现之详细攻略
Py之scikit-learn:机器学习Sklearn库的简介、安装、使用方法(ML算法如何选择)、代码实现之详细攻略
|
18天前
|
算法
基于WOA算法的SVDD参数寻优matlab仿真
该程序利用鲸鱼优化算法(WOA)对支持向量数据描述(SVDD)模型的参数进行优化,以提高数据分类的准确性。通过MATLAB2022A实现,展示了不同信噪比(SNR)下模型的分类误差。WOA通过模拟鲸鱼捕食行为,动态调整SVDD参数,如惩罚因子C和核函数参数γ,以寻找最优参数组合,增强模型的鲁棒性和泛化能力。
|
24天前
|
机器学习/深度学习 算法 Serverless
基于WOA-SVM的乳腺癌数据分类识别算法matlab仿真,对比BP神经网络和SVM
本项目利用鲸鱼优化算法(WOA)优化支持向量机(SVM)参数,针对乳腺癌早期诊断问题,通过MATLAB 2022a实现。核心代码包括参数初始化、目标函数计算、位置更新等步骤,并附有详细中文注释及操作视频。实验结果显示,WOA-SVM在提高分类精度和泛化能力方面表现出色,为乳腺癌的早期诊断提供了有效的技术支持。
|
4天前
|
供应链 算法 调度
排队算法的matlab仿真,带GUI界面
该程序使用MATLAB 2022A版本实现排队算法的仿真,并带有GUI界面。程序支持单队列单服务台、单队列多服务台和多队列多服务台三种排队方式。核心函数`func_mms2`通过模拟到达时间和服务时间,计算阻塞率和利用率。排队论研究系统中顾客和服务台的交互行为,广泛应用于通信网络、生产调度和服务行业等领域,旨在优化系统性能,减少等待时间,提高资源利用率。
|
12天前
|
存储 算法
基于HMM隐马尔可夫模型的金融数据预测算法matlab仿真
本项目基于HMM模型实现金融数据预测,包括模型训练与预测两部分。在MATLAB2022A上运行,通过计算状态转移和观测概率预测未来值,并绘制了预测值、真实值及预测误差的对比图。HMM模型适用于金融市场的时间序列分析,能够有效捕捉隐藏状态及其转换规律,为金融预测提供有力工具。
|
20天前
|
算法
基于GA遗传算法的PID控制器参数优化matlab建模与仿真
本项目基于遗传算法(GA)优化PID控制器参数,通过空间状态方程构建控制对象,自定义GA的选择、交叉、变异过程,以提高PID控制性能。与使用通用GA工具箱相比,此方法更灵活、针对性强。MATLAB2022A环境下测试,展示了GA优化前后PID控制效果的显著差异。核心代码实现了遗传算法的迭代优化过程,最终通过适应度函数评估并选择了最优PID参数,显著提升了系统响应速度和稳定性。
|
12天前
|
机器学习/深度学习 算法 信息无障碍
基于GoogleNet深度学习网络的手语识别算法matlab仿真
本项目展示了基于GoogleNet的深度学习手语识别算法,使用Matlab2022a实现。通过卷积神经网络(CNN)识别手语手势,如&quot;How are you&quot;、&quot;I am fine&quot;、&quot;I love you&quot;等。核心在于Inception模块,通过多尺度处理和1x1卷积减少计算量,提高效率。项目附带完整代码及操作视频。