TensorFlow中Embedding和One-Hot的区别

简介: 在推荐系统或者一些大型文本分类数据任务上,对于一些数据的维度非常高,而且因为稀疏型(分类)特征不能直接喂给模型,我们需要将其映射成稠密连续特征(数值型)。

Embedding和One-Hot的区别?


在推荐系统或者一些大型文本分类数据任务上,对于一些数据的维度非常高,而且因为稀疏型(分类)特征不能直接喂给模型,我们需要将其映射成稠密连续特征(数值型)。


假设现在有个字段为城市,分别取值为上海北京深圳,如果直接将这三个值输入到模型中,一般的模型是无法识别文本的,有些模型只能够识别数值,所以我们需要将其数值化,如果不考虑关联度,我们只需要将上海映射为0,北京映射为1,深圳映射为2,然后输入到模型中,但是这样会使原来相关无关联的数值存在大小关系。


还有一种就是较为常用的One-Hot编码,它是利用了物理中的或门,只有满足条件对应的电路才会被激活


上海:【1,0,0】

北京:【0,1,0】

深圳:【0,0,1】


这样就可以将原来的分类特征转化成数值型特征,但是这样做会有两个缺点:


  • 对于推荐系统这种数据,用户ID或者物品ID随便就有几百万,如果使用这种方式会造成维度爆炸,如果将全部分类特征进行One-Hot编码那么维度将会上升到几千万乃至几亿
  • 而且编码后的向量相互之间的相似度为0,这样会导致如果两个文本之间存在相似也会造成丢失,比如苹果和梨这两种水果我们需要将其编码如果只是01编码的化,会导致两者之间无关联,但是二者都是水果会存在一定的相似性,显然这种情况使用One-Hot编码不合适


所以Embedding应运而生,它是会将我们高维离散数据映射到低维稠密空间中,在推荐系统领域较为常用,而且在TensorFlow也已经作为一个层出现,它会将我们的离散值映射成稠密值,比如会将0映射成【0.2,0.35,0.62】


上海:【2.33,0.52,0.15,1.79】

北京:【0.28,0.31,0.14,0.10】

深圳:【1.94,0.30,0.52,0.92】


它会将上面几个文本转化成一个实值组成的向量,而且我们的Embedding可以进行训练,如果两个文本之间相似,那么它们对应的实值向量的相似度也会非常高。


TensorFlow中代码实现


# (32, 10)

input = np.random.randint(1000, size=(32, 10))


output = Embedding(input_dim=1000,

                  output_dim=64)(input)

print(output.shape)

>>> (32, 10, 64)


Embedding(input_dim,output_dim)


  • imput_dim:词汇量个数,就是你这个特征所有的类别数
  • output_dim:最终映射的维度


上面的例子我们首先定义了一个array,维度为【32,10】,就是32个样本,10个特征,然后我们Embedding参数中的input_dim输入为1000,说明我们所有语料库所有特征可以取的值总的类别有1000个,然后output_dim为64,就是会将每个实值映射成一个64维的一个向量,最终的输出维度为【32,10,64】


也就是原来每个样本有10个特征,10个数值,现在它会将这10个数,每个数映射成一个64维的向量,1个数对应1个64维向量,现在有10个数,那么每个样本从原来的10个数,变成了现在的10个64维向量,原来的每个数现在用一个64维连续向量进行描述。


它的机制是这样的,训练后,它的内部会有一个【1000,64】维的表(矩阵),因为我们设定语料库维1000,也就是总共有1000个可选的值,然后模型会进行训练,将这1000个离散值训练成一个64维的连续向量,这样就会形成一个1000个64维的数表,然后将我们数据进行映射时就会从这种表进行寻找。

目录
相关文章
|
8月前
|
机器学习/深度学习 PyTorch API
MindIE Torch快速上手
MindIE Torch 是一款高效的深度学习推理优化工具,支持 PyTorch 模型在 NPU 上的高性能部署。其核心特性包括:1) 子图与单算子混合执行,配合 torch_npu 实现高效推理;2) 支持 C++ 和 Python 编程语言,灵活适配不同开发需求;3) 兼容多种模式(TorchScript、ExportedProgram、torch.compile),覆盖广泛场景;4) 支持静态与动态 Shape 模型编译,满足多样化输入需求。通过简单易用的 API,开发者可快速完成模型加载、编译优化、推理执行及离线模型导出等全流程操作,显著提升开发效率与性能表现。
|
8月前
|
中间件 Go
Golang | Gin:net/http与Gin启动web服务的简单比较
总的来说,`net/http`和 `Gin`都是优秀的库,它们各有优缺点。你应该根据你的需求和经验来选择最适合你的工具。希望这个比较可以帮助你做出决策。
356 35
|
7月前
|
边缘计算 弹性计算 人工智能
魔搭社区大模型一键部署到阿里云边缘云(ENS)
随着大模型技术的快速发展,业界的关注点正逐步从模型训练往模型推理 转变。这一转变不仅反映了大模型在实际业务中的广泛应用需求,也体现了技术优化和工程化落地的趋势。
738 7
|
11月前
|
监控 安全 物联网
工厂人员定位管理系统方案:实现低成本高精度人员定位
蓝牙定位技术结合Lora技术,实现低成本、高效率的工厂人员定位管理,能够提升生产效率、保障安全、优化应急响应的关键工具。该系统能够实时获取工厂内人员的位置信息,为生产调度、安全监控、紧急疏散等提供精确、及时的数据支持。
580 5
|
数据采集 存储 前端开发
Puppeteer教程:使用CSS选择器点击和爬取动态数据
本文介绍如何使用Puppeteer结合CSS选择器爬取动态网页数据,以贝壳网的二手房价格为例,通过代理IP提高爬虫成功率。文章详细讲解了Puppeteer的安装和配置、代码实现及数据趋势分析,帮助读者掌握动态网页爬取技术。
474 1
Puppeteer教程:使用CSS选择器点击和爬取动态数据
|
Serverless 计算机视觉
实战| 轻松实现仰卧起坐检测与计数,手把手教学【附完整源码与详细讲解】
实战| 轻松实现仰卧起坐检测与计数,手把手教学【附完整源码与详细讲解】
|
安全 测试技术 开发者
通义千问2.5有哪些升级
通义千问2.5有哪些升级
1322 5
|
机器学习/深度学习 算法 数据挖掘
深度学习中常用损失函数介绍
选择正确的损失函数对于训练机器学习模型非常重要。不同的损失函数适用于不同类型的问题。本文将总结一些常见的损失函数,并附有易于理解的解释、用法和示例
1119 0
深度学习中常用损失函数介绍
|
存储 缓存 NoSQL
深入理解分布式缓存在现代后端系统中的应用与挑战
随着互联网技术的飞速发展,分布式缓存已成为提升后端系统性能的关键技术之一。本文将从数据导向和科学严谨的角度出发,探讨分布式缓存技术的原理、应用场景以及面临的主要挑战。通过对具体案例的分析和数据统计,我们旨在为读者提供一个全面而深入的理解框架,帮助开发者更好地设计和优化后端系统。 【7月更文挑战第20天】
285 0
|
Kubernetes 容器 Perl
K8S集群重新初始化--详细过程
K8S集群重新初始化--详细过程
1575 0