TensorFlow2.0(4):填充与复制

简介: TensorFlow2.0(4):填充与复制

1 tf.pad()


tf.pad函数主要是用来对tensor的大小进行扩展,包括水平、垂直、深度(通道)等,方法定义如下:


pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0)

输入参数:


  • tensor:输入的tensor
  • paddings:设置填充的大小
  • mode:填充方式,默认是CONSTANT,还有REFLECT和SYMMETRIC
  • name:名称
  • constant_values:CONSTANT填充方式的填充值,默认是0


参数paddings必须是形状为(n, 2)的一个list,这里的n是tensor的秩,也就是维度大小。例如当tensor为一个shape为(12,)的tensor时,paddings必须是形如[[x,y]]的一个list,x表示在第一维度前填充值的个数,y表示在第一维度后填充值的个数:


import tensorflow as tf


a = tf.range(1,13)


a


<tf.Tensor: id=3, shape=(12,), dtype=int32, numpy=array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12], dtype=int32)>


tf.pad(a, [[3,0]])  # 3表示在第一维度前填充3个0,0表示不填充


<tf.Tensor: id=5, shape=(15,), dtype=int32, numpy=
array([ 0,  0,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12],
      dtype=int32)>


当tensor是二维时,paddings必须是shape为(2,2)的list:


a = tf.reshape(a, [3, 4])


a


<tf.Tensor: id=7, shape=(3, 4), dtype=int32, numpy=
array([[ 1,  2,  3,  4],
       [ 5,  6,  7,  8],
       [ 9, 10, 11, 12]], dtype=int32)>


tf.pad(a, [[1,1],[3,0]], constant_values=3)  # 第一维度前后各填充一行,第二维度前填充两行,后不填充,填充值为3


<tf.Tensor: id=10, shape=(5, 7), dtype=int32, numpy=
array([[ 3,  3,  3,  3,  3,  3,  3],
       [ 3,  3,  3,  1,  2,  3,  4],
       [ 3,  3,  3,  5,  6,  7,  8],
       [ 3,  3,  3,  9, 10, 11, 12],
       [ 3,  3,  3,  3,  3,  3,  3]], dtype=int32)>


对于3维tensor,paddings是一个shape为(3,2)的list:


a = tf.reshape(a, [2, 2, 3])


a


<tf.Tensor: id=12, shape=(2, 2, 3), dtype=int32, numpy=
array([[[ 1,  2,  3],
        [ 4,  5,  6]],
       [[ 7,  8,  9],
        [10, 11, 12]]], dtype=int32)>


tf.pad(a, [[1, 0],[1,1],[1,0]])  # 第一维度前填充1块数据,后不填充,第二维度前后各填充1行,第三维度前填充1列,后不填充


<tf.Tensor: id=14, shape=(3, 4, 4), dtype=int32, numpy=
array([[[ 0,  0,  0,  0],
        [ 0,  0,  0,  0],
        [ 0,  0,  0,  0],
        [ 0,  0,  0,  0]],
       [[ 0,  0,  0,  0],
        [ 0,  1,  2,  3],
        [ 0,  4,  5,  6],
        [ 0,  0,  0,  0]],
       [[ 0,  0,  0,  0],
        [ 0,  7,  8,  9],
        [ 0, 10, 11, 12],
        [ 0,  0,  0,  0]]], dtype=int32)>


a = tf.range(1,13)


a = tf.reshape(a,[3,4])


a


<tf.Tensor: id=20, shape=(3, 4), dtype=int32, numpy=
array([[ 1,  2,  3,  4],
       [ 5,  6,  7,  8],
       [ 9, 10, 11, 12]], dtype=int32)>


当指定填充模式mode为'REFLECT'时,指的是以各维度边缘为对称轴进行填充(不包括边缘数据也就是对称轴本身),且填充的规模不能大于该维度原有规模-1:


tf.pad(a, [[2,1],[3,1]],mode='REFLECT')  # 对第二个维度填充时,如果大于3就回产生异常,因为3已经可以把第二维度所有数据复制一遍


<tf.Tensor: id=22, shape=(6, 8), dtype=int32, numpy=
array([[12, 11, 10,  9, 10, 11, 12, 11],
       [ 8,  7,  6,  5,  6,  7,  8,  7],
       [ 4,  3,  2,  1,  2,  3,  4,  3],
       [ 8,  7,  6,  5,  6,  7,  8,  7],
       [12, 11, 10,  9, 10, 11, 12, 11],
       [ 8,  7,  6,  5,  6,  7,  8,  7]], dtype=int32)>


SYMMETRIC填充模式与REFLECT填充模式一样,都是以边缘为对称轴进行赋值填充,不过SYMMETRIC模式会对对称轴进行赋值,所以指定的规模最大可以为原规模:


tf.pad(a, [[2,1],[4,1]],mode='SYMMETRIC')  # 这时候对第二个维度填充规模可以为4,但是超过4旧货产生异常


<tf.Tensor: id=24, shape=(6, 9), dtype=int32, numpy=
array([[ 8,  7,  6,  5,  5,  6,  7,  8,  8],
       [ 4,  3,  2,  1,  1,  2,  3,  4,  4],
       [ 4,  3,  2,  1,  1,  2,  3,  4,  4],
       [ 8,  7,  6,  5,  5,  6,  7,  8,  8],
       [12, 11, 10,  9,  9, 10, 11, 12, 12],
       [12, 11, 10,  9,  9, 10, 11, 12, 12]], dtype=int32)>


2 tile()


tile()方法对指定维度进行复制,定义如下:


tile(input, multiples, name=None):


  • input:需要复制的tensor
  • multiples: 各维度需要复制的次数,0表示去除数据,1表示不复制,2表示复制一次

参数multiples是一个长度与tensor的秩相等的list,例如当tensor的shape为(12,)时,multiples的shape也必须为只有一个元素的list,例如multiples=[2],表示对第一维度复制1次:


a = tf.range(12)


tf.tile(a,[2])


<tf.Tensor: id=33, shape=(24,), dtype=int32, numpy=
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,  0,  1,  2,  3,  4,
        5,  6,  7,  8,  9, 10, 11], dtype=int32)>


当tensor的shape为(3,4)时,multiples是一个包含两个元素的list:


a = tf.reshape(a, [3,4])


tf.tile(a, [2,3])  # 第一维度复制1次,第二维度复制2次


<tf.Tensor: id=37, shape=(6, 12), dtype=int32, numpy=
array([[ 0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3],
       [ 4,  5,  6,  7,  4,  5,  6,  7,  4,  5,  6,  7],
       [ 8,  9, 10, 11,  8,  9, 10, 11,  8,  9, 10, 11],
       [ 0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3],
       [ 4,  5,  6,  7,  4,  5,  6,  7,  4,  5,  6,  7],
       [ 8,  9, 10, 11,  8,  9, 10, 11,  8,  9, 10, 11]], dtype=int32)>


当tensor的shape为(2,2,3时,multiples是一个包含3个元素list:


a = tf.reshape(a, [2,2,3])


tf.tile(a, [2,1,2])


<tf.Tensor: id=41, shape=(4, 2, 6), dtype=int32, numpy=
array([[[ 0,  1,  2,  0,  1,  2],
        [ 3,  4,  5,  3,  4,  5]],
       [[ 6,  7,  8,  6,  7,  8],
        [ 9, 10, 11,  9, 10, 11]],
       [[ 0,  1,  2,  0,  1,  2],
        [ 3,  4,  5,  3,  4,  5]],
       [[ 6,  7,  8,  6,  7,  8],
        [ 9, 10, 11,  9, 10, 11]]], dtype=int32)>


相关文章
|
7月前
|
存储 PyTorch 算法框架/工具
PyTorch 中的 Tensor:属性、数据生成和基本操作
PyTorch 中的 Tensor:属性、数据生成和基本操作
223 0
|
TensorFlow 算法框架/工具
【tensorflow】TF1.x保存与读取.pb模型写法介绍
由于TF里面的概念比较接地气,所以用tf1.x保存.pb模型时总是怕有什么操作漏掉了,会造成保存的模型是缺少变量数据或者没有保存图,所以先明确一下:用TF1.x保存模型时只需要保存模型的输入输出的变量(多输入就保存多个),不需要保存中间的变量;用TF1.x加载模型时只需要加载保存的模型,然后读一下输入输出变量(多输入就读多个),不需要初始化(反而会重置掉变量的值)。
166 0
|
2月前
|
存储 并行计算 PyTorch
探索PyTorch:模型的定义和保存方法
探索PyTorch:模型的定义和保存方法
|
4月前
|
缓存 Linux TensorFlow
更改 TensorFlow Hub 模型的缓存位置
更改 TensorFlow Hub 模型的缓存位置
46 0
|
7月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
OpenCV读取tensorflow 2.X模型的方法:将SavedModel转为frozen graph
【2月更文挑战第22天】本文介绍基于Python的tensorflow库,将tensorflow与keras训练好的SavedModel格式神经网络模型转换为frozen graph格式,从而可以用OpenCV库在C++等其他语言中将其打开的方法~
153 1
OpenCV读取tensorflow 2.X模型的方法:将SavedModel转为frozen graph
|
7月前
|
PyTorch 算法框架/工具 异构计算
pytorch 模型保存与加载
pytorch 模型保存与加载
52 0
|
7月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
TensorFlow的保存与加载模型
【4月更文挑战第17天】本文介绍了TensorFlow中模型的保存与加载。保存模型能节省训练时间,便于部署和复用。在TensorFlow中,可使用`save_model_to_hdf5`保存模型结构,`save_weights`保存权重,或转换为SavedModel格式。加载时,通过`load_model`恢复结构,`load_weights`加载权重。注意模型结构一致性、环境依赖及自定义层的兼容性问题。正确保存和加载能有效利用模型资源,提升效率和准确性。
|
PyTorch 算法框架/工具 Python
pytorch保存参数及模型的两种方式
pytorch保存参数及模型的两种方式
611 0
|
算法 安全 TensorFlow
Tensorflow源码解析3 -- TensorFlow核心对象 - Graph
# 1 Graph概述 计算图Graph是TensorFlow的核心对象,TensorFlow的运行流程基本都是围绕它进行的。包括图的构建、传递、剪枝、按worker分裂、按设备二次分裂、执行、注销等。因此理解计算图Graph对掌握TensorFlow运行尤为关键。 # 2 默认Graph ### 默认图替换 之前讲解Session的时候就说过,一个Session只能r
3826 0