近日,做图像聚类,使用KMeans算法的过程中,遇到了一个写法:
# 这里centers是得到的质点集合,是一个矩阵,labels是得到的标签,也是一个矩阵 res = centers[labels.flatten()]
所以我就很好奇,centers是K*m的,labels是n*1的,得到的结果怎么会变成n*m的呢?
而且,关键是我搞不清楚这是numpy的用法,还是cv2的特殊用法,但从格式上来看,我觉得还是numpy的用法,因此,我写了一段测试代码。
import numpy as np centers=np.array([[1,2,3],[4,5,6],[7,8,9]],dtype='int32') labels=np.array([[1],[1],[1],[2],[2],[2],[0],[0],[0],[1],[1],[1],[2],[2],[2],[0],[0],[0]],dtype='int32') res=centers[labels] print(res)
这里解释一些:
centers是3*3的矩阵,labels是聚类得到的全部数据的标签,是n(数据量)*1的矩阵,这里的n只跟数据量有关,但其取值,因为是类别索引,因此必须<K。
这里的res则是得到n*3的矩阵
因此ndarray1[ndarray2]这种写法的含义,从我的理解是:
遍历ndarray2,以ndarray2的值作为索引取ndarray1的相应值,最终得到的数据条数则是和ndarray2的数相等,至于每条数据有几个值,则是又ndarray1决定的
主要是我没查到numpy的定义,因此这只是我自己浅显易懂的理解,而且,我不确定是不是这种写法有特定的条件,比如像这里,ndarray2得到的值是ndarray1的索引。这就有待再以后的使用中注意了。有明白的大佬,可以留言指点一下,再此先谢过了。