张量排序
TensorFlow中,张量排序的操作主要包括:
tf.sort():按照升序或者降序对张量进行排序,返回排序后的张量。
tf.argsort():按照升序或者降序对张量进行排序,但返回的是索引。
tf.nn.top_k():返回前k个最大值。
tf.sort/argsort(input, direction, axis):
input:输入张量;
direction:排列顺序,可为DESCENDING 降序或者ASCENDING(升序)。默认为ASCENDING(升序);
axis:按照axis维度进行排序。默认 axis=-1 最后一个维度。
代码:
sort_sample_1 = tf.random.shuffle(tf.range(10))
print("输入张量:",sort_sample_1.numpy())
sorted_sample_1 = tf.sort(sort_sample_1, direction="ASCENDING")
print("生序排列后的张量:",sorted_sample_1.numpy())
sorted_sample_2 = tf.argsort(sort_sample_1,direction="ASCENDING")
print("生序排列后,元素的索引:",sorted_sample_2.numpy())
输出:
输入张量: [1 8 7 9 6 5 4 2 3 0]
生序排列后的张量: [0 1 2 3 4 5 6 7 8 9]
生序排列后,元素的索引: [9 0 7 8 6 5 4 2 1 3]
tf.nn.top_k(input,K,sorted=TRUE):
input:输入张量;
K:需要输出的前k个值及其索引。
sorted: sorted=TRUE表示升序排列;sorted=FALSE表示降序排列。
返回两个张量:
values:也就是每一行的最大的k个数字
indices:这里的下标是在输入的张量的最后一个维度的下标
代码:
values, index = tf.nn.top_k(sort_sample_1,5)
print("输入张量:",sort_sample_1.numpy())
print("升序排列后的前5个数值:", values.numpy())
print("升序排列后的前5个数值的索引:", index.numpy())
输出:
输入张量: [1 8 7 9 6 5 4 2 3 0]
升序排列后的前5个数值: [9 8 7 6 5]
升序排列后的前5个数值的索引: [3 1 2 4 5]