数据分析入门系列教程-K-Means实战

简介: 数据分析入门系列教程-K-Means实战

上一节我们讲解了 K-Means 算法的原理,并且手动实现了一个 K-Means 算法函数,今天我们一起来完成相关的实战内容。

在 Sklearn 中使用 K-Means

Sklearn 同样提供了非常完善的 K-Means 算法实现

from sklearn.cluster import KMeans
kmeans = KMeans()

再来看下可以传递给该类的主要参数都有哪些

参数 解释
n_clusters 即 K 值,默认为8
max_iter 最大迭代次数,如果聚类很难收敛的话,可以通过设置该参数来停止算法
n_init 初始化中心点的运算次数,默认为10。这里sklearn 会自动为我们进行迭代运算,找出合适的初始中心点。
init 初始值的选择方式。默认就是采用 k-means++ 的方式,或者还可以采用 random 完全随机的方式。

下面我们就进入实战部分,完整代码可以查看这里

https://github.com/zhouwei713/DataAnalyse/tree/master/K-Means

图片分割

首先我们先来看下图片


图片整体是两只不同颜色的鞋子,背景是深色的椅子和桌子。

首先是读取图片

读取图片,我们使用 PIL 库,这是一个很好的图片处理工具库

安装 PIL 库

pip install Pillow

通过 PIL 读取图片

from PIL import Image
img = Image.open('foot-small.jpg')
print(img)
>>>
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=559x497 at 0x15541750>

可以看到通过 image 的 open 函数可以读取图片,得到的是一个 PIL 的对象

接下来获取图片每个点的三通道值

width, height = img.size
data = []
for x in range(width):
    for y in range(height):
        r, g, b = img.getpixel((x, y))
        data.append([r, g, b])

我们知道,jpg 图片的每个像素点都是由 (r, g, b)这三个值组成的,称为通道值

接下来在进行数据规范化,可以加快聚类的收敛

mm = preprocessing.MinMaxScaler()
img_data = mm.fit_transform(data)
img_mat = np.mat(img_data)
print(img_mat)
>>>
[[0.69411765 0.76470588 0.85098039]
 [0.69803922 0.76862745 0.85490196]
 [0.70196078 0.77254902 0.85882353]
 ...
 [0.42352941 0.42352941 0.43137255]
 [0.48235294 0.48235294 0.49019608]
 [0.42745098 0.42352941 0.44313725]]

现在开始使用 K-Means 做聚类

kmeans = KMeans(n_clusters=2)
kmeans.fit(img_mat)
label = kmeans.predict(img_mat)

predict 得到的就是聚类的结果

最后再把得到的聚类结果赋值给新的图片

label = label.reshape([width, height])picture_mark = Image.new("L", (width, height))for x in range(width):
    for y in range(height):
        picture_mark.putpixel((x, y), int(256/(label[x][y]+1))-1)picture_mark.save("new-foot.jpg", "JPEG")

因为我们最终得到的 label 是0和1,所以还要手动转换成灰度值,把 label=0的设置成了255,把 label=1的设置成了127。


图片压缩

其实我们还可以使用 k-Means 算法来压缩图片,这次我们使用 matplotlib 来处理图片

import matplotlib.pyplot as plt
from numpy import reshape,uint8,flipud
from sklearn.cluster import KMeans
from copy import deepcopy
from PIL import Image

首先还是读取图片

img = plt.imread('foot-small.jpg')
print(img.shape)
>>>
(497, 559, 3)

可以看到,这张图片是497X559像素的

接下来我们通过 reshape 函数,重置下该矩阵的形状

pixel = reshape(img,(img.shape[0]*img.shape[1],3))
print(pixel.shape)
>>>
(277823, 3)

现在就转换成了一个二维的矩阵

进行 K-Means 训练并获取聚类类别和中心点

pixel_new = deepcopy(pixel)model = KMeans(n_clusters=5)
labels = model.fit_predict(pixel)
palette = model.cluster_centers_

cluster_centers_ 属性保存的是每个聚类中心点的坐标

接下来再把每个类别其他点的像素值替换成中心点的值

for i in range(len(pixel)):
    pixel_new[i,:] = palette[labels[i]]

最后保存图片

new_pic = reshape(pixel_new, (img.shape[0], img.shape[1],3))
images = Image.fromarray(new_pic)
images.save("foot-new.jpg")


可以看到,通过这种方法,可以很好的保留图片原有的颜色信息,并且大大减少了图片保存的数据量。

同时你应该也注意到了,我们在初始化 K-Means 类时,只是指定了 n_clusters 参数,对于 init 参数我们使用的是默认值,即 k-means++,所以对于使用 sklearn 工具来说,我们已经在选择初始点时进行了优化处理。

足球队聚类

下面我们再来看看如何对足球队进行聚类划分,其实如果和球队类推到人,那么就是对人的聚类划分,这个在营销领域就是非常常用的营销分析方法了。

我们先来看一下数据情况

team_data = pd.read_csv('football-team.csv')
print(team_data)
>>>
        国家  2019年国际排名  2019亚洲杯  2015亚洲杯
0       中国         73        8        7
1       日本         60        2        5
2       韩国         61        6        2
3       伊朗         34        3        6
4       沙特         67       15       10
5      伊拉克         91       14        4
6      卡塔尔        101        1       13
7      阿联酋         81        4        6
8   乌兹别克斯坦         88       10        8
9       泰国        122       13       17
10      越南        102        7       17
11      阿曼         87       16       12
12      巴林        116       12       11
13      朝鲜        110       24       14
14      澳洲         40        5        1
15     叙利亚         76       20       17
16      约旦        118        9        9
17    巴勒斯坦         96       18       16

我分别提取了上面18支亚洲足球队的当前世界排名以及2019,2015年亚洲杯的排名。

下面我们提取需要训练的数据

train_X = team_data[['2019年国际排名', '2019亚洲杯', '2015亚洲杯']]
print(train_X)
>>>
    2019年国际排名  2019亚洲杯  2015亚洲杯
0          73        8        7
1          60        2        5
2          61        6        2
3          34        3        6
4          67       15       10
5          91       14        4
6         101        1       13
7          81        4        6
8          88       10        8
9         122       13       17
10        102        7       17
11         87       16       12
12        116       12       11
13        110       24       14
14         40        5        1
15         76       20       17
16        118        9        9
17         96       18       16

因为数据数值大小的差异还是比较大的,所以需要使用数据规范化的方式,把训练数据规范化

from sklearn import preprocessing
mm = preprocessing.MinMaxScaler()
train_x = mm.fit_transform(train_X)

下面就是使用 K-Means 做聚类,我们先把类别设置为3,看看效果

kmeans = KMeans(n_clusters=3)# kmeans 算法
kmeans.fit(train_x)
predict_y = kmeans.predict(train_x)
# 合并聚类结果,插入到原数据中
result = pd.concat((team_data,pd.DataFrame(predict_y)),axis=1)
result.rename({0:u'聚类'},axis=1,inplace=True)
print(result)
>>>
        国家  2019年国际排名  2019亚洲杯  2015亚洲杯  聚类
0       中国         73        8        7   1
1       日本         60        2        5   1
2       韩国         61        6        2   1
3       伊朗         34        3        6   1
4       沙特         67       15       10   0
5      伊拉克         91       14        4   1
6      卡塔尔        101        1       13   2
7      阿联酋         81        4        6   1
8   乌兹别克斯坦         88       10        8   1
9       泰国        122       13       17   2
10      越南        102        7       17   2
11      阿曼         87       16       12   0
12      巴林        116       12       11   2
13      朝鲜        110       24       14   0
14      澳洲         40        5        1   1
15     叙利亚         76       20       17   0
16      约旦        118        9        9   2
17    巴勒斯坦         96       18       16   0

可以看到,我国和日本,韩国,伊朗,澳大利亚等聚类到了一起,通过这些年的观赛经验,这个还是有挺大误差的。

现在我们使用手肘法来确定最佳的 K 值

SS = []for k in range(2, 10):
    kmeans = KMeans(n_clusters=k).fit(train_x)
    SS.append(kmeans.inertia_)
plt.plot(range(2,10), SS)
plt.xlabel('K')
plt.ylabel('SS')

inertia_ 属性是每个点到聚类中心的聚类之和


通过上面的得出的图片,我们可以清楚的看到拐点是在 k=4的地方,所以我们选取k=4作为聚类的种类数量,再次重新聚类

kmeans = KMeans(n_clusters=4)# kmeans 算法
kmeans.fit(train_x)
predict_y = kmeans.predict(train_x)
# 合并聚类结果,插入到原数据中
result = pd.concat((team_data,pd.DataFrame(predict_y)),axis=1)
result.rename({0:u'聚类'},axis=1,inplace=True)
print(result)
>>>
        国家  2019年国际排名  2019亚洲杯  2015亚洲杯  聚类
0       中国         73        8        7   0
1       日本         60        2        5   2
2       韩国         61        6        2   2
3       伊朗         34        3        6   2
4       沙特         67       15       10   0
5      伊拉克         91       14        4   0
6      卡塔尔        101        1       13   3
7      阿联酋         81        4        6   0
8   乌兹别克斯坦         88       10        8   0
9       泰国        122       13       17   3
10      越南        102        7       17   3
11      阿曼         87       16       12   1
12      巴林        116       12       11   3
13      朝鲜        110       24       14   1
14      澳洲         40        5        1   2
15     叙利亚         76       20       17   1
16      约旦        118        9        9   3
17    巴勒斯坦         96       18       16   1

现在可以看出,韩日,伊朗和澳大利亚仍然聚类在一起,而中国则是与沙特,伊拉克等国在一起,现在的结果应该更加符合当前亚洲足球的整体水平了。

当然,你还可能发现,如果你运行多次 K-Means 算法,会得到不同的结果,这个就是上一节讲的,因为每次运行算法,初始值都是不同的,而不同的初始值,会得到不同的聚类结果。

总结

今天我们通过两个实战例子,再次加深了对于 K-Means 算法的理解,希望你可以结合代码,再好好的体会下。

K-Means 是无监督学习领域一个非常重要的算法,对于用户分层等领域都有很好的应用。

当然 K-Means 算法的缺点也十分明显,就是聚类个数 K 值需要提前指定,如果我们不知道当前要聚类成多少个类别,那么我们就需要多给几个 K 值,然后从中找出聚类效果最好的那个。


练习题

我在 GitHub 上还上传了一个像素比较高的图片 foot.jpg,如果你的电脑比较好,是否可以尝试着对该图片进行相关的压缩操作呢?

相关文章
|
25天前
|
机器学习/深度学习 数据可视化 数据挖掘
使用Python进行数据分析的入门指南
本文将引导读者了解如何使用Python进行数据分析,从安装必要的库到执行基础的数据操作和可视化。通过本文的学习,你将能够开始自己的数据分析之旅,并掌握如何利用Python来揭示数据背后的故事。
|
1月前
|
机器学习/深度学习 数据可视化 数据挖掘
使用Python进行数据分析的入门指南
【10月更文挑战第42天】本文是一篇技术性文章,旨在为初学者提供一份关于如何使用Python进行数据分析的入门指南。我们将从安装必要的工具开始,然后逐步介绍如何导入数据、处理数据、进行数据可视化以及建立预测模型。本文的目标是帮助读者理解数据分析的基本步骤和方法,并通过实际的代码示例来加深理解。
54 3
|
1月前
|
消息中间件 数据挖掘 Kafka
Apache Kafka流处理实战:构建实时数据分析应用
【10月更文挑战第24天】在当今这个数据爆炸的时代,能够快速准确地处理实时数据变得尤为重要。无论是金融交易监控、网络行为分析还是物联网设备的数据收集,实时数据处理技术都是不可或缺的一部分。Apache Kafka作为一款高性能的消息队列系统,不仅支持传统的消息传递模式,还提供了强大的流处理能力,能够帮助开发者构建高效、可扩展的实时数据分析应用。
91 5
|
27天前
|
数据可视化 数据挖掘
R中单细胞RNA-seq数据分析教程 (3)
R中单细胞RNA-seq数据分析教程 (3)
31 3
R中单细胞RNA-seq数据分析教程 (3)
|
1月前
|
SQL 数据挖掘 Python
R中单细胞RNA-seq数据分析教程 (1)
R中单细胞RNA-seq数据分析教程 (1)
39 5
R中单细胞RNA-seq数据分析教程 (1)
|
1月前
|
机器学习/深度学习 数据挖掘
R中单细胞RNA-seq数据分析教程 (2)
R中单细胞RNA-seq数据分析教程 (2)
49 0
R中单细胞RNA-seq数据分析教程 (2)
|
1月前
|
数据采集 数据可视化 数据挖掘
深入浅出:使用Python进行数据分析的基础教程
【10月更文挑战第41天】本文旨在为初学者提供一个关于如何使用Python语言进行数据分析的入门指南。我们将通过实际案例,了解数据处理的基本步骤,包括数据的导入、清洗、处理、分析和可视化。文章将用浅显易懂的语言,带领读者一步步掌握数据分析师的基本功,并在文末附上完整的代码示例供参考和实践。
|
1月前
|
并行计算 数据挖掘 大数据
Python数据分析实战:利用Pandas处理大数据集
Python数据分析实战:利用Pandas处理大数据集
|
2月前
|
数据采集 机器学习/深度学习 数据可视化
深入浅出:用Python进行数据分析的入门指南
【10月更文挑战第21天】 在信息爆炸的时代,掌握数据分析技能就像拥有一把钥匙,能够解锁隐藏在庞大数据集背后的秘密。本文将引导你通过Python语言,学习如何从零开始进行数据分析。我们将一起探索数据的收集、处理、分析和可视化等步骤,并最终学会如何利用数据讲故事。无论你是编程新手还是希望提升数据分析能力的专业人士,这篇文章都将为你提供一条清晰的学习路径。
|
2月前
|
数据挖掘 索引 Python
Python数据分析篇--NumPy--入门
Python数据分析篇--NumPy--入门
42 0