使用度量学习进行特征嵌入:交叉熵和监督对比损失的效果对比

本文涉及的产品
函数计算FC,每月15万CU 3个月
简介: 使用度量学习进行特征嵌入:交叉熵和监督对比损失的效果对比

分类是机器学习中最简单,最常见的任务之一。例如,在计算机视觉中,您希望能够微调普通卷积神经网络(CNN)的最后一层,以将样本正确分类为某些类别(类)。但是,有几种根本不同的方法可以实现这一目标。

Metric learning(度量学习)是其中之一,今天我想与大家分享如何正确使用它。为了使事情变得实用,我们将研究监督式对比学习(SupCon),它是对比学习的一部分,而后者又是度量学习的一部分,但稍后会介绍更多。

通常如何进行分类

在进行度量学习之前,首先了解通常如何解决分类任务。卷积神经网络是当今实用计算机视觉最重要的思想之一,它由两部分组成:编码器和头部(在这种情况下为分类器)。

640.png

首先-拍摄图像并计算一组特征,这些特征可以捕获该图像的重要信息。这是通过卷积和池化操作完成的(这就是为什么它被称为卷积神经网络)。之后,将这些特征解压缩到单个向量中,并使用常规的全连接神经网络执行分类。在实践中,您采用在大型数据集(例如ImageNet)上预先训练的某种模型(例如ResNet,DenseNet,EfficientNet等),并根据您的任务(仅最后一层或整个模型)进行微调)。

然而,这里有几点需要注意。首先,通常只关心网络FC部分的输出。也就是说,你取它的输出,并把它们提供给损失函数,以保持模型学习。换句话说,您并不真正关心网络中间发生了什么(例如,来自编码器的特性)。其次,通常你用一些基本的损失函数来训练这些东西,比如交叉熵。

640.png

为了更好地理解这个2步过程(encoder + FC),你可以这样想:encoder将图像映射到一些高维空间(例如,在ResNet18的情况下,我们讨论的是512维,而对于Resnet101 - 2048)。在此之后,FC的目标是在这些代表样本的点之间画一条线,以便将它们映射到类。这两种东西是同时训练的。因此,你试图优化特征,同时“在高维空间中画线”。

这种方法有什么问题吗?嗯,没什么,真的。它实际上运行得很好。但这并不意味着没有别的办法。

度量学习 Metric learning

现代机器学习中最有趣的想法之一(至少对我来说是这样)叫做度量学习(或深度度量学习)。简单地说:如果我们不去关注FC层的输出,而是更仔细地研究编码器生成的特性会怎样?如果我们设法用一些损耗函数来优化这些特性,而不是使用网络输出进行优化,会怎么样呢?这就是度量学习的意义所在:用编码器生成好的特性(嵌入)。

“好”是什么意思呢?好吧,如果你想一下,在计算机视觉的例子中,你想对相似的图像有相似的特征,而对截然不同的图像有截然不同的特征。

监督对比学习 Supervised Contrastive Learning

640.png

好的,假设在度量学习中,我们关心的只是“好”特征。但是监督式对比学习有什么意义呢?老实说,这种特定方法没有什么特别之处。这是最近的一篇论文,提出了一些不错的技巧,以及一个有趣的2步方法

  1. 训练一个好的编码器,该编码器能够为图像生成良好的特征。
  2. 冻结编码器,添加FC层,然后进行训练。

您可能想知道常规分类器训练有什么区别。不同之处在于,在常规培训中,您需要同时训练编码器和FC。另一方面,在这里,您首先训练一个不错的编码器,然后将其冻结(不再训练),然后仅训练FC。这种逻辑背后的想法是,如果我们设法首先为图像生成真正好的特征,则应该很容易优化FC(正如我们前面提到的,其目标是优化分离样本的行)。

训练过程的细节

让我们深入了解SupCon实施的细节。

640.png

在查看训练循环之前,您应该了解的一件事是要训练哪种模型。这非常简单:编码器(例如ResNet,DenseNet,EffNet等),但没有常规的FC层进行分类。

这里不是分类头,而是投影头。投影头是一个由2个FC层组成的序列,它将编码器的特征映射到一个较低的维度空间(通常是128维度,你甚至可以在上面的图片中看到这个值)。使用投影头的原因是,与来自编码器的几千个特征相比,使用128个精心选择的特征更容易让模型学习。

  1. 构造一批N个图像。与其他度量学习方法不同,您不需要太关心这些样本的选择。能拿多少就拿多少,剩下的由损失来处理。
  2. 将这些图像以成对的方式转发给网络,其中一对图像被构造为[augmentation(image_i), augmentation(image_i)],得到embeddings。并进行标准化。
  3. 以某个图像做为锚点。在批处理中找到同一个类的所有图像。把它们作为正样本。找到所有不同类的图像。把他们当作负样本。
  4. 将SupCon损失应用于第二步归一化嵌入,使正样本彼此靠近,同时使负样本更远离。
  5. 第一阶段训练完成后,删除投影头,并在编码器顶部添加FC(就像在常规分类训练中一样)。开始第二阶段训练的冻结编码器,并微调FC的训练。

这里要记住几件事。首先,在训练完成后,去掉投影头,使用投影头之前的特征是会获得更好的效果。作者解释说,由于我们降低了嵌入的大小,导致信息丢失。其次,增强的选择很重要。作者提出了裁剪和色彩抖动的组合。Supcon一次处理批处理中的所有图像(因此,无需构造对或三元组)。而且批处理中的图像越多,模型学习起来就越容易(因为SupCon具有隐式的正负硬挖掘质量)。第四,你可以在第4步停止。这意味着可以通过嵌入来进行分类,而不需要任何FC层。为了做到这一点,计算所有训练样本的嵌入。然后,在验证时,对每个样本计算一个嵌入,将其与每个训练嵌入进行比较(例如余弦距离),采用其类别。

PyTorch实现

实际上,在PyTorch中有一个SupCon的半官方实现。不幸的是,它包含了非常恼人的隐藏bug。最严重的一个问题是:repo的创造者使用了他自己的resnet实现,由于其中的一些bug,批量大小比普通的torchvision模型低两倍。最重要的是,repo没有验证或可视化,所以你不知道什么时候停止训练。在我的repo中,我修复了所有这些问题,并为稳定的训练增加了更多的技巧。

更准确地说,在我的实现包含了以下功能:

  • 使用albumentations进行扩增
  • Yaml配置
  • t-SNE可视化
  • 使用AMI、NMI、mAP、precision_at_1等PyTorch度量学习进行2步验证(用于投影头前后的特性)。
  • 指数移动平均更稳定的训练,随机移动平均更好的泛化和整体性能。
  • 自动混合精度训练,以便能够训练更大的批大小(大约是2的倍数)。
  • 标签平滑损失,LRFinder为第二阶段的训练(FC)。
  • 支持timm模型和jettify优化器
  • 固定种子,使训练具有确定性。
  • 保存基于验证的权重,日志-定期。txt文件,以及TensorBoard日志。

例子是使用Cifar10和Cifar100数据集来进行测试的,但是添加自己的数据集非常简单。为了运行整个数据处理管道,请执行以下操作:

pythontrain.py--config_nameconfigs/train/train_supcon_resnet18_cifar10_stage1.ymlpythonswa.py--config_nameconfigs/train/swa_supcon_resnet18_cifar100_stage1.ymlpythontrain.py--config_nameconfigs/train/train_supcon_resnet18_cifar10_stage2.ymlpythonswa.py--config_nameconfigs/train/swa_supcon_resnet18_cifar100_stage2.yml

之后,你可以检查可视化t-SNE结果。例如,对于Cifar10和Cifar100,大概是下面这样:

640.png

Cifar10 t-SNE,  SupCon 损失

640.png

Cifar10 t-SNE, Cross Entropy 损失

640.png

Cifar100 t-SNE,  SupCon 损失

640.png

Cifar10 t-SNE, Cross Entropy 损失

总结

度量学习是一个非常强大的东西。但是要达到常规CE / LabelSmoothing可以提供的准确性水平非常困难。此外,在训练期间它在计算上也可能是昂贵的并且不稳定的。我在各种任务(分类,超出分布的预测,对新类的泛化等)上测试了SupCon和其他度量指标损失,使用诸如SupCon之类的优势尚不确定。

那有什么意义?我个人认为有两件事。第一,SupCon(和其他度量学习方法)仍然可以提供比CE更结构化的集群,因为它直接优化了该属性。第二,多一个你可以尝试的技能/工具仍然是非常有益的。因此,通过更好的扩展集或不同的数据集(可能使用更细粒度的类),SupCon 可能会产生更好的结果,而不仅仅是与常规分类训练相当。

相关实践学习
【文生图】一键部署Stable Diffusion基于函数计算
本实验教你如何在函数计算FC上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。函数计算提供一定的免费额度供用户使用。本实验答疑钉钉群:29290019867
建立 Serverless 思维
本课程包括: Serverless 应用引擎的概念, 为开发者带来的实际价值, 以及让您了解常见的 Serverless 架构模式
目录
相关文章
|
8月前
|
机器学习/深度学习
为什么在二分类问题中使用交叉熵函数作为损失函数
为什么在二分类问题中使用交叉熵函数作为损失函数
312 2
|
机器学习/深度学习 人工智能 测试技术
使用随机森林分类器对基于NDRE(归一化差异水体指数)的特征进行分类
使用随机森林分类器对基于NDRE(归一化差异水体指数)的特征进行分类
108 1
|
3月前
|
机器学习/深度学习 调度 知识图谱
TimeDART:基于扩散自回归Transformer 的自监督时间序列预测方法
近年来,深度神经网络成为时间序列预测的主流方法。自监督学习通过从未标记数据中学习,能够捕获时间序列的长期依赖和局部特征。TimeDART结合扩散模型和自回归建模,创新性地解决了时间序列预测中的关键挑战,在多个数据集上取得了最优性能,展示了强大的泛化能力。
119 0
TimeDART:基于扩散自回归Transformer 的自监督时间序列预测方法
|
3月前
|
机器学习/深度学习 自然语言处理
交叉熵损失
【10月更文挑战第2天】
|
6月前
|
机器学习/深度学习
交叉熵损失函数的使用目的(很肤浅的理解)
交叉熵损失函数的使用目的(很肤浅的理解)
|
8月前
|
机器学习/深度学习
用SPSS估计HLM多层(层次)线性模型模型
用SPSS估计HLM多层(层次)线性模型模型
|
8月前
|
机器学习/深度学习 数据采集 算法
乳腺癌预测:特征交叉+随机森林=成功公式?
乳腺癌预测:特征交叉+随机森林=成功公式?
102 0
乳腺癌预测:特征交叉+随机森林=成功公式?
|
vr&ar
用于非线性时间序列预测的稀疏局部线性和邻域嵌入(Matlab代码实现)
用于非线性时间序列预测的稀疏局部线性和邻域嵌入(Matlab代码实现)
135 0
用于非线性时间序列预测的稀疏局部线性和邻域嵌入(Matlab代码实现)
|
机器学习/深度学习 监控
使用2D卷积技术进行时间序列预测(上)
使用2D卷积技术进行时间序列预测
262 1
使用2D卷积技术进行时间序列预测(上)
|
机器学习/深度学习
使用2D卷积技术进行时间序列预测(下)
使用2D卷积技术进行时间序列预测
487 1
使用2D卷积技术进行时间序列预测(下)