TensorFlow中Embedding和One-Hot的区别

简介: TensorFlow中Embedding和One-Hot的区别

Embedding和One-Hot的区别?

推荐系统或者一些大型文本分类数据任务上,对于一些数据的维度非常高,而且因为稀疏型(分类)特征不能直接喂给模型,我们需要将其映射成稠密连续特征(数值型)。

假设现在有个字段为城市,分别取值为上海北京深圳,如果直接将这三个值输入到模型中,一般的模型是无法识别文本的,有些模型只能够识别数值,所以我们需要将其数值化,如果不考虑关联度,我们只需要将上海映射为0,北京映射为1,深圳映射为2,然后输入到模型中,但是这样会使原来相关无关联的数值存在大小关系。

还有一种就是较为常用的One-Hot编码,它是利用了物理中的或门,只有满足条件对应的电路才会被激活

上海:【1,0,0】
北京:【0,1,0】
深圳:【0,0,1】

这样就可以将原来的分类特征转化成数值型特征,但是这样做会有两个缺点:

  • 对于推荐系统这种数据,用户ID或者物品ID随便就有几百万,如果使用这种方式会造成维度爆炸,如果将全部分类特征进行One-Hot编码那么维度将会上升到几千万乃至几亿
  • 而且编码后的向量相互之间的相似度为0,这样会导致如果两个文本之间存在相似也会造成丢失,比如苹果和梨这两种水果我们需要将其编码如果只是01编码的化,会导致两者之间无关联,但是二者都是水果会存在一定的相似性,显然这种情况使用One-Hot编码不合适

所以Embedding应运而生,它是会将我们高维离散数据映射到低维稠密空间中,在推荐系统领域较为常用,而且在TensorFlow也已经作为一个层出现,它会将我们的离散值映射成稠密值,比如会将0映射成【0.2,0.35,0.62】

上海:【2.33,0.52,0.15,1.79】
北京:【0.28,0.31,0.14,0.10】
深圳:【1.94,0.30,0.52,0.92】

它会将上面几个文本转化成一个实值组成的向量,而且我们的Embedding可以进行训练,如果两个文本之间相似,那么它们对应的实值向量的相似度也会非常高。

TensorFlow中代码实现

# (32, 10)
input = np.random.randint(1000, size=(32, 10))
output = Embedding(input_dim=1000,
                   output_dim=64)(input)
print(output.shape)
>>> (32, 10, 64)

Embedding(input_dim,output_dim)

  • imput_dim:词汇量个数,就是你这个特征所有的类别数
  • output_dim:最终映射的维度

上面的例子我们首先定义了一个array,维度为【32,10】,就是32个样本,10个特征,然后我们Embedding参数中的input_dim输入为1000,说明我们所有语料库所有特征可以取的值总的类别有1000个,然后output_dim为64,就是会将每个实值映射成一个64维的一个向量,最终的输出维度为【32,10,64】

也就是原来每个样本有10个特征,10个数值,现在它会将这10个数,每个数映射成一个64维的向量,1个数对应1个64维向量,现在有10个数,那么每个样本从原来的10个数,变成了现在的10个64维向量,原来的每个数现在用一个64维连续向量进行描述。

它的机制是这样的,训练后,它的内部会有一个【1000,64】维的表(矩阵),因为我们设定语料库维1000,也就是总共有1000个可选的值,然后模型会进行训练,将这1000个离散值训练成一个64维的连续向量,这样就会形成一个1000个64维的数表,然后将我们数据进行映射时就会从这种表进行寻找。


目录
相关文章
|
3月前
|
机器学习/深度学习 人工智能 API
TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:1~5
TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:1~5
71 0
|
3月前
|
机器学习/深度学习 存储 人工智能
TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11(3)
TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11(3)
81 0
|
3月前
|
机器学习/深度学习 Dart TensorFlow
TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11(5)
TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11(5)
72 0
|
10天前
|
机器学习/深度学习 运维 监控
TensorFlow分布式训练:加速深度学习模型训练
【4月更文挑战第17天】TensorFlow分布式训练加速深度学习模型训练,通过数据并行和模型并行利用多机器资源,减少训练时间。优化策略包括配置计算资源、优化数据划分和减少通信开销。实际应用需关注调试监控、系统稳定性和容错性,以应对分布式训练挑战。
|
2月前
|
机器学习/深度学习 PyTorch TensorFlow
Python中的深度学习:TensorFlow与PyTorch的选择与使用
Python中的深度学习:TensorFlow与PyTorch的选择与使用
|
2月前
|
机器学习/深度学习 数据可视化 TensorFlow
基于tensorflow深度学习的猫狗分类识别
基于tensorflow深度学习的猫狗分类识别
63 1
|
3月前
|
机器学习/深度学习 PyTorch TensorFlow
【TensorFlow】深度学习框架概述&TensorFlow环境配置
【1月更文挑战第26天】【TensorFlow】深度学习框架概述&TensorFlow环境配置
|
3月前
|
机器学习/深度学习 存储 编解码
TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11(4)
TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11(4)
122 0
|
3月前
|
机器学习/深度学习 存储 算法框架/工具
TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11(2)
TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11(2)
48 0
|
3月前
|
机器学习/深度学习 存储 运维
TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11(1)
TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11(1)
56 0

热门文章

最新文章