TensorFlow中Embedding和One-Hot的区别

简介: 在推荐系统或者一些大型文本分类数据任务上,对于一些数据的维度非常高,而且因为稀疏型(分类)特征不能直接喂给模型,我们需要将其映射成稠密连续特征(数值型)。

Embedding和One-Hot的区别?


在推荐系统或者一些大型文本分类数据任务上,对于一些数据的维度非常高,而且因为稀疏型(分类)特征不能直接喂给模型,我们需要将其映射成稠密连续特征(数值型)。


假设现在有个字段为城市,分别取值为上海北京深圳,如果直接将这三个值输入到模型中,一般的模型是无法识别文本的,有些模型只能够识别数值,所以我们需要将其数值化,如果不考虑关联度,我们只需要将上海映射为0,北京映射为1,深圳映射为2,然后输入到模型中,但是这样会使原来相关无关联的数值存在大小关系。


还有一种就是较为常用的One-Hot编码,它是利用了物理中的或门,只有满足条件对应的电路才会被激活


上海:【1,0,0】

北京:【0,1,0】

深圳:【0,0,1】


这样就可以将原来的分类特征转化成数值型特征,但是这样做会有两个缺点:


  • 对于推荐系统这种数据,用户ID或者物品ID随便就有几百万,如果使用这种方式会造成维度爆炸,如果将全部分类特征进行One-Hot编码那么维度将会上升到几千万乃至几亿
  • 而且编码后的向量相互之间的相似度为0,这样会导致如果两个文本之间存在相似也会造成丢失,比如苹果和梨这两种水果我们需要将其编码如果只是01编码的化,会导致两者之间无关联,但是二者都是水果会存在一定的相似性,显然这种情况使用One-Hot编码不合适


所以Embedding应运而生,它是会将我们高维离散数据映射到低维稠密空间中,在推荐系统领域较为常用,而且在TensorFlow也已经作为一个层出现,它会将我们的离散值映射成稠密值,比如会将0映射成【0.2,0.35,0.62】


上海:【2.33,0.52,0.15,1.79】

北京:【0.28,0.31,0.14,0.10】

深圳:【1.94,0.30,0.52,0.92】


它会将上面几个文本转化成一个实值组成的向量,而且我们的Embedding可以进行训练,如果两个文本之间相似,那么它们对应的实值向量的相似度也会非常高。


TensorFlow中代码实现


# (32, 10)

input = np.random.randint(1000, size=(32, 10))


output = Embedding(input_dim=1000,

                  output_dim=64)(input)

print(output.shape)

>>> (32, 10, 64)


Embedding(input_dim,output_dim)


  • imput_dim:词汇量个数,就是你这个特征所有的类别数
  • output_dim:最终映射的维度


上面的例子我们首先定义了一个array,维度为【32,10】,就是32个样本,10个特征,然后我们Embedding参数中的input_dim输入为1000,说明我们所有语料库所有特征可以取的值总的类别有1000个,然后output_dim为64,就是会将每个实值映射成一个64维的一个向量,最终的输出维度为【32,10,64】


也就是原来每个样本有10个特征,10个数值,现在它会将这10个数,每个数映射成一个64维的向量,1个数对应1个64维向量,现在有10个数,那么每个样本从原来的10个数,变成了现在的10个64维向量,原来的每个数现在用一个64维连续向量进行描述。


它的机制是这样的,训练后,它的内部会有一个【1000,64】维的表(矩阵),因为我们设定语料库维1000,也就是总共有1000个可选的值,然后模型会进行训练,将这1000个离散值训练成一个64维的连续向量,这样就会形成一个1000个64维的数表,然后将我们数据进行映射时就会从这种表进行寻找。

目录
相关文章
|
4月前
|
PyTorch 算法框架/工具
【chat-gpt问答记录】torch.tensor和torch.Tensor什么区别?
【chat-gpt问答记录】torch.tensor和torch.Tensor什么区别?
96 2
|
2月前
|
机器学习/深度学习 API 算法框架/工具
【Tensorflow+keras】Keras API三种搭建神经网络的方式及以mnist举例实现
使用Keras API构建神经网络的三种方法:使用Sequential模型、使用函数式API以及通过继承Model类来自定义模型,并提供了基于MNIST数据集的示例代码。
43 12
|
2月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
【Tensorflow+Keras】keras实现条件生成对抗网络DCGAN--以Minis和fashion_mnist数据集为例
如何使用TensorFlow和Keras实现条件生成对抗网络(CGAN)并以MNIST和Fashion MNIST数据集为例进行演示。
35 3
|
2月前
【Bert4keras】解决Key bert/embeddings/word_embeddings not found in checkpoint
在使用bert4keras进行预训练并加载模型时遇到的"bert/embeddings/word_embeddings not found in checkpoint"错误,并提供了通过重新生成权重模型来解决这个问题的方法。
47 3
|
2月前
|
机器学习/深度学习 API TensorFlow
【Tensorflow+keras】解决 Fail to find the dnn implementation.
在TensorFlow 2.0环境中使用双向长短期记忆层(Bidirectional LSTM)遇到“Fail to find the dnn implementation”错误时的三种解决方案。
50 0
|
4月前
|
存储 PyTorch 算法框架/工具
【chat-gpt问答记录】关于pytorch中的线性层nn.Linear()
【chat-gpt问答记录】关于pytorch中的线性层nn.Linear()
74 0
|
API 算法框架/工具
越来越火的tf.keras模型,这三种构建方式记住了,你就是大佬!!!
越来越火的tf.keras模型,这三种构建方式记住了,你就是大佬!!!
114 0
|
机器学习/深度学习 TensorFlow 算法框架/工具
TensorFlow HOWTO 4.1 多层感知机(分类)
TensorFlow HOWTO 4.1 多层感知机(分类)
66 0
|
搜索推荐 TensorFlow 算法框架/工具
TensorFlow中Embedding和One-Hot的区别
TensorFlow中Embedding和One-Hot的区别
147 0
|
数据挖掘 TensorFlow 算法框架/工具
TensorFlow教程(4)-Attention机制
TensorFlow教程(4)-Attention机制
408 0
TensorFlow教程(4)-Attention机制