tensorflow杂项

简介: 杂项,后期整理1.tf.sampled_softmax_loss此函数和tf.nn.nce_loss是差不多的, 取样求lossdef sampled_softmax_loss( weights...

杂项,后期整理

1.tf.sampled_softmax_loss

此函数和tf.nn.nce_loss是差不多的, 取样求loss

def sampled_softmax_loss(
                          weights, #[num_classes, dim]
                         biases,  #[num_classes]
                         inputs,  #[batch_size, dim]
                         labels,  #[batch_size, num_true]
                         num_sampled,
                         num_classes,
                         num_true=1,
                         sampled_values=None,
                         remove_accidental_hits=True,
                         partition_strategy="mod",
                         name="sampled_softmax_loss")

关于参数labels:一般情况下,num_true为1, labels的shpae为[batch_size, 1]。假设我们有1000个类别, 使用one_hot形式的label的话, 我们的labels的shape是[batch_size, num_classes]。显然,如果num_classes非常大的话,会影响计算性能。所以,这里采用了一个稀疏的方式,即:使用3代表了[0,0,0,1,0….]

2.tf.contrib.layers.embed_sequence

tf.contrib.layers.embed_sequenceembed_sequence(
    ids,
    vocab_size=None,
    embed_dim=None,
    unique=False,
    initializer=None,
    regularizer=None,
    trainable=True,
    scope=None,
    reuse=None
)

一般用于seq2seq网络,可以完成对输入序列数据的嵌入工作。一般只需关注前三个参数即可。
ids:形状为[batch_size, seq_length],也就是输入数据
vocab_size:输入数据为字典的长度
embed_dim:想要的嵌入矩阵的维度

3.tf.contrib.layers.embed_sequence

tf.stride_slice(data, begin, end)
tf.slice(data, begin, end)
tf.stride_slice的end是开区间,tf.slice的end是闭区间。
一般有一个常用的小技巧是tf.stride_slice(data, [0, 0], [rows, -1]),可以截掉最后一列,很实用。

import tensorflow as tf
data = [[[1, 1, 1], [2, 2, 2]],
            [[3, 3, 3], [4, 4, 4]],
            [[5, 5, 5], [6, 6, 6]]]
x = tf.strided_slice(data,[0,0,0],[1,1,1])
with tf.Session() as sess:
    print(sess.run(x))
#就是第0行到第一行,然后弄完[[[1, 1, 1], [2, 2, 2]]],再第0行到第一行,弄完
#[[[1, 1, 1]]],再到第0列到第一列,弄完[[[1]]]

4.tf.contrib.seq2seq.TrainingHelper

A helper for use during training. Only reads inputs.
Returned sample_ids are the argmax of the RNN output logits.

helper = tf.contrib.seq2seq.TrainingHelper(
    input=input_vectors,
    sequence_length=input_lengths)

5.tf.contrib.seq2seq.BasicDecoder

用于构造一个decoder

decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=cell,
    helper=helper,
    initial_state=cell.zero_state(batch_size, tf.float32))

6.tf.contrib.seq2seq.dynamic_decode

用于构造一个动态decoder,返回的内容是:就是3个返回值
(final_outputs, final_state, final_sequence_lengths),final_outputs是一个元祖,包含两项(rnn_outputs, sample_id),他们的说明如下:
rnn_output:[batch_size, decoder_targets_length, vocab_size],保存是decoder中每个时间步的输出,很容易理解。
sample_id:[batch_size],保存最终的编码结果,可以表示最后的答案

#新版本的tensorflow是3个返回值,注意了
outputs, _ , _= tf.contrib.seq2seq.dynamic_decode(
   decoder=decoder,
   output_time_major=False,
   impute_finished=True,
   maximum_iterations=20)

7.tf.tile

通过复制input的multiples时间来创建新的张量。输出的张量的第i维是input[i] * multiples[i] ,通俗点就是沿着“i”维度input值被复制了multiples[i]次。比如:

a = tf.constant([[1, 2],[2, 3],[3, 4]], dtype=tf.float32)
tile_a_1 = tf.tile(a, [3,1])

with tf.Session() as sess:
    print(sess.run(tile_a_1))
#结果是:
[[1. 2.]
 [2. 3.]
 [3. 4.]
 [1. 2.]
 [2. 3.]
 [3. 4.]
 [1. 2.]
 [2. 3.]
 [3. 4.]]

8.tf.contrib.seq2seq.GreedyEmbeddingHelper

这是用于seq2seq中帮助建立Decoder的一个类,在预测时使用,示例代码如下:

helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
      embedding=embedding,
      #可以结合7,知道为啥tf.tile要batch_size了吧
      start_tokens=tf.tile([GO_SYMBOL], [batch_size]),
      end_token=END_SYMBOL)

9.tf.sequence_mask

函数原型

sequence_mask(
    lengths,
    maxlen=None, #maxlen:标量整数张量,返回张量的最后维度的大小;默认值是lengths中的最大值。
    dtype=tf.bool,
    name=None
)

返回一个表示每个单元的前N个位置的mask张量。如果lengths的形状为[d_1, d_2, ..., d_n],由此产生的张量mask有dtype类型和形状[d_1, d_2, ..., d_n, maxlen],并且:mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n])
例如:

tf.sequence_mask([1, 3, 2], 5)  # [[True, False, False, False, False],
                                #  [True, True, True, False, False],
                                #  [True, True, False, False, False]]

tf.sequence_mask([[1, 3],[2,0]])  # [[[True, False, False],
                                  #   [True, True, True]],
                                  #  [[True, True, False],
                                  #   [False, False, False]]]
#这个就是最大长度是3,所以3列

10.tf.contrib.seq2seq.sequence_loss

sequence_loss(
    logits,
    targets,
    weights,
    average_across_timesteps=True,
    average_across_batch=True,
    softmax_loss_function=None,
    name=None
)

用于计算seq2seq中的loss。当我们的输入是不定长的时候,weights参数常常使用我们tf.sequence_mask()中得到的mask。

11.tf.train.AdamOptimizer

训练需要优化器,这边主要不是讲优化器(真脑残,为啥一个函数分成两个来写)。其实可以用tf.train.Optimizer.minimize(),这个函数用于最小化loss,并更新var_list。这个函数可以拆分成两个函数实现同样的功能:

#该函数对var_list中的变量计算loss的梯度,为minimize()的第一部分,返回一个以#元祖(gradient, variavle)组成的列表
tf.train.Optimizer.compute_gradients(
    loss,var_list=None, gate_gradients=1,
    aggregation_method=None,
    colocate_gradients_with_ops=False, grad_loss=None)


#该函数将计算出的梯度应用到变量上,是函数minimize()的第二部分,返回一个应#用指定梯度的操作,对global_step做自增操作。
tf.train.Optimizer.apply_gradients(
    grads_and_vars, global_step=None, name=None
) 

#上述两个函数组合起来就能对loss进行优化
目录
相关文章
|
算法 前端开发 JavaScript
JS - 前端生成 UUID 四种方法
JS - 前端生成 UUID 四种方法
6531 0
|
Java 编译器
Java中环境变量 PATH 与 CLASSPATH 的区别
Java中环境变量 PATH 与 CLASSPATH 的区别
280 0
|
运维 程序员 数据库
如何用TCC方案轻松实现分布式事务一致性
TCC(Try-Confirm-Cancel)是一种分布式事务解决方案,将事务拆分为尝试、确认和取消三步,确保在分布式系统中实现操作的原子性。它旨在处理分布式环境中的数据一致性问题,通过预检查和资源预留来降低失败风险。TCC方案具有高可靠性和灵活性,但也增加了系统复杂性并可能导致性能影响。它需要为每个服务实现Try、Confirm和Cancel接口,并在回滚时确保资源正确释放。虽然有挑战,TCC在复杂的分布式系统中仍被广泛应用。
840 5
|
监控 Java 编译器
JVM常用命令及其用法,简直太全了!
JVM常用命令及其用法,简直太全了!
801 0
|
机器学习/深度学习 存储 PyTorch
【Pytorch】使用pytorch进行张量计算、自动求导和神经网络构建
【Pytorch】使用pytorch进行张量计算、自动求导和神经网络构建
435 1
|
机器学习/深度学习 存储 传感器
Unsupervised Learning | 对比学习——13篇论文综述
Unsupervised Learning | 对比学习——13篇论文综述
3060 0
Unsupervised Learning | 对比学习——13篇论文综述
|
编解码 TensorFlow 算法框架/工具
ConvNext模型复现--CVPR2022
ConvNet和Vision Transformer的ImageNet分类结果。我们证明了标准的 ConvNet 模型可以实现与分层视觉 Transformer 相同的可扩展性,同时在设计上要简单得多。
1899 0
ConvNext模型复现--CVPR2022
|
存储 tengine 算法
【端智能】MNN CPU性能优化年度小结
2020年5月,MNN发布了1.0.0版本,作为移动端/服务端/PC均适用的推理引擎,在通用性与高性能方面处于业界领先水平。
【端智能】MNN CPU性能优化年度小结
|
前端开发
controller层设计
MVC架构下,我们的web工程结构会分为三层,自下而上是dao层,service层和controller层。controller层为控制层,主要处理外部请求。调用service层,一般情况下,controller层不应该包含业务逻辑,controller的功能应该有以下五点: ⑴、接收请求并解析参数 ⑵、业务逻辑执行成功做出响应 ⑶、异常处理 ⑷、转换业务对象 ⑸、调用 Service 接口
|
缓存 JavaScript 开发工具
NPM 介绍
简介 NPM 是随同 NodeJS 一起安装的包管理工具,能解决 NodeJS 代码部署上的很多问题,常见的使用场景有以下几种: • 允许用户从NPM服务器下载别人编写的第三方包到本地使用。 • 允许用户从NPM服务器下载并安装别人编写的命令行程序到本地使用。 • 允许用户将自己编写的包或命令行程序上传到NPM服务器供别人使用。 由于新版的 nodejs 已经集成了 npm,所以之前 npm也一并安装好了。同样可以通过输入 "npm -v" 来测试是否成功安装。命令如下,出现版本提示表示安装成功:
1010 0