张量的拼接
TensorFlow中,张量拼接的操作主要包括:
tf.contact():将向量按指定维连起来,其余维度不变。
tf.stack() :将一组R维张量变为R+1维张量,拼接前后维度变化。
tf.concat(values, axis, name='concat'):
values:输入张量;
axis:指定拼接维度;
name:操作名称。
代码:
concat_sample_1 = tf.random.normal([4,100,100,3])
concat_sample_2 = tf.random.normal([40,100,100,3])
print("原始数据的尺寸分别为:",concat_sample_1.shape,concat_sample_2.shape)
concated_sample_1 = tf.concat([concat_sample_1,concat_sample_2],axis=0)
print("拼接后数据的尺寸:",concated_sample_1.shape)
输出:
原始数据的尺寸分别为: (4, 100, 100, 3) (40, 100, 100, 3)
拼接后数据的尺寸: (44, 100, 100, 3)
在原来矩阵基础上增加了一个维度,也是同样的道理,axis决定维度增加的位置。
tf.stack(values, axis=0, name='stack'):
values:输入张量;一组相同形状和数据类型的张量。
axis:指定拼接维度;
name:操作名称。
代码:
stack_sample_1 = tf.random.normal([100,100,3])
stack_sample_2 = tf.random.normal([100,100,3])
print("原始数据的尺寸分别为:",stack_sample_1.shape, stack_sample_2.shape)
拼接后维度增加。axis=0,则在第一个维度前增加维度。
stacked_sample_1 = tf.stack([stack_sample_1, stack_sample_2],axis=0)
print("拼接后数据的尺寸:",stacked_sample_1.shape)
输出:
原始数据的尺寸分别为: (100, 100, 3) (100, 100, 3)
拼接后数据的尺寸: (2, 100, 100, 3)