三元组损失Triplet loss 详解

本文涉及的产品
实时计算 Flink 版,1000CU*H 3个月
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
实时数仓Hologres,5000CU*H 100GB 3个月
简介: 在这篇文章中,我们将以简单的技术术语解析三元组损失及其变体批量三元组损失,并提供一个相关的例子来帮助你理解这些概念。

深度神经网络在识别模式和进行预测方面表现出色,但在涉及图像识别任务时,它们常常难以区分相似个体的图像。三元组损失是一种强大的训练技术,可以解决这个问题,它通过学习相似度度量,在高维空间中将相似图像准确地嵌入到彼此接近的位置。 在这篇文章中,我们将以简单的技术术语解析三元组损失及其变体批量三元组损失,并提供一个相关的例子来帮助你理解这些概念。

三元组损失

三元组损失是一种用于训练神经网络的损失函数,可以用于执行诸如人脸识别或目标分类等任务。三元组损失的目标是在高维嵌入空间(也称为特征空间)中学习一种相似度度量,在这个空间中,相似对象(例如,同一个人的图像)的表示彼此接近,而不相似对象的表示则相距较远。

三元组损失的核心概念是使用三元组,它由一个锚点样本、一个正样本和一个负样本组成。锚点样本和正样本是相似的实例,而负样本则是不相似的。算法学习以这样一种方式嵌入这些样本:锚点样本与正样本之间的距离小于锚点样本与负样本之间的距离。

在实践中,三元组损失通常与一种称为孪生网络的神经网络架构一起使用,该架构在处理相同输入的两个或多个分支之间共享权重。这种共享表示允许网络在嵌入空间中学习一个稳健的相似度度量。

当锚点样本和正样本在嵌入空间中不够接近,或者锚点样本和负样本太接近时,三元组损失函数会对网络进行惩罚。这鼓励网络学习输入数据的有意义表示,捕捉相关样本之间的相似性。

三元组损失的例子

假设有一组不同人的照片,我们想训练一个人脸识别系统。目标是识别两张图像是否属于同一个人。三元组损失可以用来学习一个相似度度量,使系统能够准确识别人脸。

一个三元组由三张照片组成:一个锚点、一个正样本和一个负样本。锚点是特定人的照片,正样本是同一个人的另一张照片,负样本是不同人的照片。

在训练过程中,网络会呈现三元组,三元组损失函数计算锚点、正样本和负样本嵌入(高维特征表示)之间的距离。如果锚点和正样本嵌入之间的距离太大,或者锚点和负样本嵌入之间的距离太小,三元组损失函数就会惩罚网络。

通过基于这个损失函数迭代调整网络的权重,网络学会将相似的人脸(即锚点和正样本)嵌入到嵌入空间中彼此接近的位置,而不相似的人脸(即锚点和负样本)则被分开。

例如,如果同一个人的两张照片(锚点和正样本)的嵌入彼此接近,系统就能准确识别它们属于同一个人。相反,如果不同人的照片(锚点和负样本)的嵌入相距较远,系统就能自信地将它们归类为属于不同的个体。

批量三元组损失

批量三元组损失是传统三元组损失的一种变体,它在训练过程中对数据批次进行操作。在标准三元组损失中,一个批次由三张图像组成:一个锚点、一个正样本和一个负样本。目标是学习一个相似度度量,例如能够准确识别人脸。

而批量三元组损失,不是一次处理一个三元组,而是在一个批次中一起处理多个三元组。这种方法在计算上可能更高效,并且可以利用现代 GPU 的能力更快地训练深度神经网络。

在训练过程中,网络会呈现一批三元组,三元组损失函数计算每个三元组内锚点、正样本和负样本嵌入(高维特征表示)之间的距离。如果锚点和正样本嵌入之间的距离太大,或者锚点和负样本嵌入之间的距离太小,批量三元组损失函数就会惩罚网络。

通过基于这个损失函数迭代调整网络的权重,网络学会将相似的特征(即锚点和正样本)嵌入到嵌入空间中彼此接近的位置,而不相似的特征(即锚点和负样本)则被分开。

例如,如果同一个人的两张照片(锚点和正样本)的嵌入彼此接近,系统就能准确识别它们属于同一个人。相反,如果不同人的照片(锚点和负样本)的嵌入相距较远,系统就能自信地将它们归类为属于不同的个体。

批量三元组损失是一种有效的方法,用于训练深度神经网络进行人脸识别和其他需要相似度度量的应用。

批量三元组损失的例子

假设你是机场的一名安保人员,你的任务是在安检站识别经过的个人。我们有一个手持设备,一次显示三张照片:一个锚点、一个正样本和一个负样本。目标是快速确定锚点照片中的人是否与正样本照片中的人相同,如果不同,还需要识别负样本照片中的人。

这个场景可以被构建为一个批量三元组损失问题。手持设备本质上是在执行一个使用批量三元组损失训练的深度神经网络。锚点、正样本和负样本图像是网络的输入,输出是一组嵌入(高维特征表示),捕捉图像之间的相似性。网络被训练以最小化同一个人的嵌入之间的距离(正对),同时最大化不同人的嵌入之间的距离(负对)。

在这个安保场景中,当手持设备向你呈现一批三元组时,网络计算每个三元组内锚点、正样本和负样本图像嵌入之间的距离。如果锚点和正样本图像的嵌入之间的距离很小,你就可以自信地说它们属于同一个人。如果距离很大,你就可以将负样本图像中的人识别为一个不同的个体。

通过使用批量三元组损失和大型图像数据集训练网络,它学会将相似的图像(即同一个人的图像)嵌入到嵌入空间中彼此接近的位置,而不相似的图像(即不同人的图像)则被分开。

总结

本文介绍了三元组损失,这是一种用于训练深度神经网络的技术,主要应用于图像识别任务。三元组损失通过学习高维嵌入空间中的相似度度量,使相似图像的表示彼此接近,不相似图像的表示相距较远。

三元组损失的核心概念是使用由锚点、正样本和负样本组成的三元组进行训练。网络学习将锚点与正样本的距离最小化,同时最大化与负样本的距离。而批量三元组损失,这是一种在单个批次中处理多个三元组的变体,提高了计算效率。

https://avoid.overfit.cn/post/77f8b2530e5a473da038d4ebcd086258

作者:Jyoti Dabass, Ph.D

目录
相关文章
|
自然语言处理 算法 数据挖掘
自蒸馏:一种简单高效的优化方式
背景知识蒸馏(knowledge distillation)指的是将预训练好的教师模型的知识通过蒸馏的方式迁移至学生模型,一般来说,教师模型会比学生模型网络容量更大,模型结构更复杂。对于学生而言,主要增益信息来自于更强的模型产出的带有更多可信信息的soft_label。例如下右图中,两个“2”对应的hard_label都是一样的,即0-9分类中,仅“2”类别对应概率为1.0,而soft_label
自蒸馏:一种简单高效的优化方式
|
索引 Python
|
算法 数据库 计算机视觉
Dataset之COCO数据集:COCO数据集的简介、下载、使用方法之详细攻略
Dataset之COCO数据集:COCO数据集的简介、下载、使用方法之详细攻略
|
9月前
|
搜索推荐 物联网 PyTorch
Qwen2.5-7B-Instruct Lora 微调
本教程介绍如何基于Transformers和PEFT框架对Qwen2.5-7B-Instruct模型进行LoRA微调。
10433 34
Qwen2.5-7B-Instruct Lora 微调
|
8月前
|
机器学习/深度学习 自然语言处理 搜索推荐
自注意力机制全解析:从原理到计算细节,一文尽览!
自注意力机制(Self-Attention)最早可追溯至20世纪70年代的神经网络研究,但直到2017年Google Brain团队提出Transformer架构后才广泛应用于深度学习。它通过计算序列内部元素间的相关性,捕捉复杂依赖关系,并支持并行化训练,显著提升了处理长文本和序列数据的能力。相比传统的RNN、LSTM和GRU,自注意力机制在自然语言处理(NLP)、计算机视觉、语音识别及推荐系统等领域展现出卓越性能。其核心步骤包括生成查询(Q)、键(K)和值(V)向量,计算缩放点积注意力得分,应用Softmax归一化,以及加权求和生成输出。自注意力机制提高了模型的表达能力,带来了更精准的服务。
10340 46
|
11月前
|
计算机视觉 Python
Jetson 学习笔记(十):Picamera或者Jetcam打开树莓派CSI摄像头
本文介绍了在Jetson Nano上使用picamera和jetcam库打开树莓派CSI摄像头的方法。由于使用opencv获取CSI摄像头图像延迟高,作者推荐使用picamera,能达到20-30fps。文章提供了安装步骤、基础代码示例,并记录了一些有用的博客地址。
243 2
|
7月前
|
人工智能 自然语言处理 Linux
OSUM:告别ASR单一功能,西工大开源的语音大模型会「读心」!识别+情感分析+年龄预测等8大任务1个模型全搞定
OSUM 是西北工业大学开发的开源语音理解模型,支持语音识别、情感分析、说话者性别分类等多种任务,基于 ASR+X 训练策略,具有高效和泛化能力强的特点。
599 8
OSUM:告别ASR单一功能,西工大开源的语音大模型会「读心」!识别+情感分析+年龄预测等8大任务1个模型全搞定
|
机器学习/深度学习 人工智能 运维
[ICLR2024]基于对比稀疏扰动技术的时间序列解释框架ContraLSP
《Explaining Time Series via Contrastive and Locally Sparse Perturbations》被机器学习领域顶会ICLR 2024接收。该论文提出了一种创新的基于扰动技术的时间序列解释框架ContraLSP,该框架主要包含一个学习反事实扰动的目标函数和一个平滑条件下稀疏门结构的压缩器。论文在白盒时序预测,黑盒时序分类等仿真数据,和一个真实时序数据集分类任务中进行了实验,ContraLSP在解释性能上超越了SOTA模型,显著提升了时间序列数据解释的质量。
|
11月前
|
机器学习/深度学习 人工智能 自然语言处理
前端大模型入门(三):编码(Tokenizer)和嵌入(Embedding)解析 - llm的输入
本文介绍了大规模语言模型(LLM)中的两个核心概念:Tokenizer和Embedding。Tokenizer将文本转换为模型可处理的数字ID,而Embedding则将这些ID转化为能捕捉语义关系的稠密向量。文章通过具体示例和代码展示了两者的实现方法,帮助读者理解其基本原理和应用场景。
2942 1
|
存储 JSON 自然语言处理
数据标注工具 doccano | 命名实体识别(Named Entity Recognition,简称NER)
标注数据保存在同一个文本文件中,每条样例占一行且存储为json格式,其包含以下字段 • id: 样本在数据集中的唯一标识ID。 • text: 原始文本数据。 • entities: 数据中包含的Span标签,每个Span标签包含四个字段: • id: Span在数据集中的唯一标识ID。 • start_offset: Span的起始token在文本中的下标。 • end_offset: Span的结束token在文本中下标的下一个位置。 • label: Span类型。 • relations: 数据中包含的Relation标签,每个Relation标签包含四个字段: • id: (Span
705 0