数据分析入门系列教程-K-Means实战

简介: 上一节我们讲解了 K-Means 算法的原理,并且手动实现了一个 K-Means 算法函数,今天我们一起来完成相关的实战内容。

在Sklearn中使用K-Means


Sklearn 同样提供了非常完善的 K-Means 算法实现

from sklearn.cluster import KMeans
kmeans = KMeans()

再来看下可以传递给该类的主要参数都有哪些

参数 解释
n_clusters 即 K 值,默认为8
max_iter 最大迭代次数,如果聚类很难收敛的话,可以通过设置该参数来停止算法
n_init 初始化中心点的运算次数,默认为10。这里sklearn 会自动为我们进行迭代运算,找出合适的初始中心点。
init 初始值的选择方式。默认就是采用 k-means++ 的方式,或者还可以采用 random 完全随机的方式。

下面我们就进入实战部分,完整代码可以查看这里

https://github.com/zhouwei713/DataAnalyse/tree/master/K-Means


图片分割


首先我们先来看下图片


image.png

图片整体是两只不同颜色的鞋子,背景是深色的椅子和桌子。

首先是读取图片

读取图片,我们使用 PIL 库,这是一个很好的图片处理工具库

安装 PIL 库

pip install Pillow

通过 PIL 读取图片

from PIL import Image
img = Image.open('foot-small.jpg')
print(img)
>>>
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=559x497 at 0x15541750>

可以看到通过 image 的 open 函数可以读取图片,得到的是一个 PIL 的对象

接下来获取图片每个点的三通道值

width, height = img.size
data = []
for x in range(width):
    for y in range(height):
        r, g, b = img.getpixel((x, y))
        data.append([r, g, b])

我们知道,jpg 图片的每个像素点都是由 (r, g, b)这三个值组成的,称为通道值

接下来在进行数据规范化,可以加快聚类的收敛

mm = preprocessing.MinMaxScaler()
img_data = mm.fit_transform(data)
img_mat = np.mat(img_data)
print(img_mat)
>>>
[[0.69411765 0.76470588 0.85098039]
 [0.69803922 0.76862745 0.85490196]
 [0.70196078 0.77254902 0.85882353]
 ...
 [0.42352941 0.42352941 0.43137255]
 [0.48235294 0.48235294 0.49019608]
 [0.42745098 0.42352941 0.44313725]]

现在开始使用 K-Means 做聚类

kmeans = KMeans(n_clusters=2)
kmeans.fit(img_mat)
label = kmeans.predict(img_mat)

predict 得到的就是聚类的结果

最后再把得到的聚类结果赋值给新的图片

label = label.reshape([width, height])picture_mark = Image.new("L", (width, height))for x in range(width):
    for y in range(height):
        picture_mark.putpixel((x, y), int(256/(label[x][y]+1))-1)picture_mark.save("new-foot.jpg", "JPEG")

因为我们最终得到的 label 是0和1,所以还要手动转换成灰度值,把 label=0的设置成了255,把 label=1的设置成了127。


image.png

图片压缩


其实我们还可以使用 k-Means 算法来压缩图片,这次我们使用 matplotlib 来处理图片

import matplotlib.pyplot as plt
from numpy import reshape,uint8,flipud
from sklearn.cluster import KMeans
from copy import deepcopy
from PIL import Image

首先还是读取图片

img = plt.imread('foot-small.jpg')
print(img.shape)
>>>
(497, 559, 3)

可以看到,这张图片是497X559像素的

接下来我们通过 reshape 函数,重置下该矩阵的形状

pixel = reshape(img,(img.shape[0]*img.shape[1],3))
print(pixel.shape)
>>>
(277823, 3)

现在就转换成了一个二维的矩阵

进行 K-Means 训练并获取聚类类别和中心点

pixel_new = deepcopy(pixel)model = KMeans(n_clusters=5)
labels = model.fit_predict(pixel)
palette = model.cluster_centers_

cluster_centers_ 属性保存的是每个聚类中心点的坐标

接下来再把每个类别其他点的像素值替换成中心点的值

for i in range(len(pixel)):
    pixel_new[i,:] = palette[labels[i]]

最后保存图片

new_pic = reshape(pixel_new, (img.shape[0], img.shape[1],3))
images = Image.fromarray(new_pic)
images.save("foot-new.jpg")


image.png

可以看到,通过这种方法,可以很好的保留图片原有的颜色信息,并且大大减少了图片保存的数据量。

同时你应该也注意到了,我们在初始化 K-Means 类时,只是指定了 n_clusters 参数,对于 init 参数我们使用的是默认值,即 k-means++,所以对于使用 sklearn 工具来说,我们已经在选择初始点时进行了优化处理。


足球队聚类


下面我们再来看看如何对足球队进行聚类划分,其实如果和球队类推到人,那么就是对人的聚类划分,这个在营销领域就是非常常用的营销分析方法了。

我们先来看一下数据情况

team_data = pd.read_csv('football-team.csv')
print(team_data)
>>>
        国家  2019年国际排名  2019亚洲杯  2015亚洲杯
0       中国         73        8        7
1       日本         60        2        5
2       韩国         61        6        2
3       伊朗         34        3        6
4       沙特         67       15       10
5      伊拉克         91       14        4
6      卡塔尔        101        1       13
7      阿联酋         81        4        6
8   乌兹别克斯坦         88       10        8
9       泰国        122       13       17
10      越南        102        7       17
11      阿曼         87       16       12
12      巴林        116       12       11
13      朝鲜        110       24       14
14      澳洲         40        5        1
15     叙利亚         76       20       17
16      约旦        118        9        9
17    巴勒斯坦         96       18       16

我分别提取了上面18支亚洲足球队的当前世界排名以及2019,2015年亚洲杯的排名。

下面我们提取需要训练的数据

train_X = team_data[['2019年国际排名', '2019亚洲杯', '2015亚洲杯']]
print(train_X)
>>>
    2019年国际排名  2019亚洲杯  2015亚洲杯
0          73        8        7
1          60        2        5
2          61        6        2
3          34        3        6
4          67       15       10
5          91       14        4
6         101        1       13
7          81        4        6
8          88       10        8
9         122       13       17
10        102        7       17
11         87       16       12
12        116       12       11
13        110       24       14
14         40        5        1
15         76       20       17
16        118        9        9
17         96       18       16

因为数据数值大小的差异还是比较大的,所以需要使用数据规范化的方式,把训练数据规范化

from sklearn import preprocessing
mm = preprocessing.MinMaxScaler()
train_x = mm.fit_transform(train_X)

下面就是使用 K-Means 做聚类,我们先把类别设置为3,看看效果

kmeans = KMeans(n_clusters=3)# kmeans 算法
kmeans.fit(train_x)
predict_y = kmeans.predict(train_x)
# 合并聚类结果,插入到原数据中
result = pd.concat((team_data,pd.DataFrame(predict_y)),axis=1)
result.rename({0:u'聚类'},axis=1,inplace=True)
print(result)
>>>
        国家  2019年国际排名  2019亚洲杯  2015亚洲杯  聚类
0       中国         73        8        7   1
1       日本         60        2        5   1
2       韩国         61        6        2   1
3       伊朗         34        3        6   1
4       沙特         67       15       10   0
5      伊拉克         91       14        4   1
6      卡塔尔        101        1       13   2
7      阿联酋         81        4        6   1
8   乌兹别克斯坦         88       10        8   1
9       泰国        122       13       17   2
10      越南        102        7       17   2
11      阿曼         87       16       12   0
12      巴林        116       12       11   2
13      朝鲜        110       24       14   0
14      澳洲         40        5        1   1
15     叙利亚         76       20       17   0
16      约旦        118        9        9   2
17    巴勒斯坦         96       18       16   0

可以看到,我国和日本,韩国,伊朗,澳大利亚等聚类到了一起,通过这些年的观赛经验,这个还是有挺大误差的。

现在我们使用手肘法来确定最佳的 K 值

SS = []for k in range(2, 10):
    kmeans = KMeans(n_clusters=k).fit(train_x)
    SS.append(kmeans.inertia_)
plt.plot(range(2,10), SS)
plt.xlabel('K')
plt.ylabel('SS')

inertia_ 属性是每个点到聚类中心的聚类之和


image.png

通过上面的得出的图片,我们可以清楚的看到拐点是在 k=4的地方,所以我们选取k=4作为聚类的种类数量,再次重新聚类

kmeans = KMeans(n_clusters=4)# kmeans 算法
kmeans.fit(train_x)
predict_y = kmeans.predict(train_x)
# 合并聚类结果,插入到原数据中
result = pd.concat((team_data,pd.DataFrame(predict_y)),axis=1)
result.rename({0:u'聚类'},axis=1,inplace=True)
print(result)
>>>
        国家  2019年国际排名  2019亚洲杯  2015亚洲杯  聚类
0       中国         73        8        7   0
1       日本         60        2        5   2
2       韩国         61        6        2   2
3       伊朗         34        3        6   2
4       沙特         67       15       10   0
5      伊拉克         91       14        4   0
6      卡塔尔        101        1       13   3
7      阿联酋         81        4        6   0
8   乌兹别克斯坦         88       10        8   0
9       泰国        122       13       17   3
10      越南        102        7       17   3
11      阿曼         87       16       12   1
12      巴林        116       12       11   3
13      朝鲜        110       24       14   1
14      澳洲         40        5        1   2
15     叙利亚         76       20       17   1
16      约旦        118        9        9   3
17    巴勒斯坦         96       18       16   1

现在可以看出,韩日,伊朗和澳大利亚仍然聚类在一起,而中国则是与沙特,伊拉克等国在一起,现在的结果应该更加符合当前亚洲足球的整体水平了。

当然,你还可能发现,如果你运行多次 K-Means 算法,会得到不同的结果,这个就是上一节讲的,因为每次运行算法,初始值都是不同的,而不同的初始值,会得到不同的聚类结果。


总结


今天我们通过两个实战例子,再次加深了对于 K-Means 算法的理解,希望你可以结合代码,再好好的体会下。

K-Means 是无监督学习领域一个非常重要的算法,对于用户分层等领域都有很好的应用。

当然 K-Means 算法的缺点也十分明显,就是聚类个数 K 值需要提前指定,如果我们不知道当前要聚类成多少个类别,那么我们就需要多给几个 K 值,然后从中找出聚类效果最好的那个。


微信图片_20220520182554.png

练习题


我在 GitHub 上还上传了一个像素比较高的图片 foot.jpg,如果你的电脑比较好,是否可以尝试着对该图片进行相关的压缩操作呢?

相关文章
|
1月前
|
自然语言处理 小程序 数据挖掘
数据分析实战-Python实现博客评论数据的情感分析
数据分析实战-Python实现博客评论数据的情感分析
104 0
|
2月前
|
数据采集 存储 数据挖掘
Python 爬虫实战之爬拼多多商品并做数据分析
Python爬虫可以用来抓取拼多多商品数据,并对这些数据进行数据分析。以下是一个简单的示例,演示如何使用Python爬取拼多多商品数据并进行数据分析。
|
2月前
|
数据采集 存储 数据可视化
Python数据分析从入门到实践
Python数据分析从入门到实践
|
9天前
|
供应链 搜索推荐 数据挖掘
Pandas实战案例:电商数据分析的实践与挑战
【4月更文挑战第16天】本文通过一个电商数据分析案例展示了Pandas在处理销售数据、用户行为分析及商品销售趋势预测中的应用。在数据准备与清洗阶段,Pandas用于处理缺失值、重复值。接着,通过用户购买行为和商品销售趋势分析,构建用户画像并预测销售趋势。实践中遇到的大数据量和数据多样性挑战,通过分布式计算和数据标准化解决。未来将继续深入研究Pandas与其他先进技术的结合,提升决策支持能力。
|
9天前
|
存储 数据可视化 数据挖掘
实战案例:Pandas在金融数据分析中的应用
【4月更文挑战第16天】本文通过实例展示了Pandas在金融数据分析中的应用。案例中,一家投资机构使用Pandas加载、清洗股票历史价格数据,删除无关列并重命名,将日期设为索引。接着,数据被可视化以观察价格走势,进行基本统计分析了解价格分布,以及计算移动平均线来平滑波动。Pandas的便捷功能在金融数据分析中体现出高效率和实用性。
|
18天前
|
机器学习/深度学习 数据可视化 数据挖掘
利用Python进行数据分析与可视化:从入门到精通
本文将介绍如何使用Python语言进行数据分析与可视化,从基础概念到高级技巧一应俱全。通过学习本文,读者将掌握Python在数据处理、分析和可视化方面的核心技能,为实际项目应用打下坚实基础。
|
2月前
|
数据采集 存储 数据挖掘
Python 爬虫实战之爬拼多多商品并做数据分析
在上面的代码中,我们使用pandas库创建DataFrame存储商品数据,并计算平均价格和平均销量。最后,我们将计算结果打印出来。此外,我们还可以使用pandas库提供的其他函数和方法来进行更复杂的数据分析和处理。 需要注意的是,爬取拼多多商品数据需要遵守拼多多的使用协议和规定,避免过度请求和滥用数据。
|
2月前
|
机器学习/深度学习 数据可视化 数据挖掘
Python数据分析:从入门到实践
Python数据分析:从入门到实践
|
3月前
|
SQL 数据挖掘 数据库
SQL数据分析实战:从导入到高级查询的完整指南
SQL数据分析实战:从导入到高级查询的完整指南
61 0
|
6天前
|
机器学习/深度学习 数据挖掘 计算机视觉
python数据分析工具SciPy
【4月更文挑战第15天】SciPy是Python的开源库,用于数学、科学和工程计算,基于NumPy扩展了优化、线性代数、积分、插值、特殊函数、信号处理、图像处理和常微分方程求解等功能。它包含优化、线性代数、积分、信号和图像处理等多个模块。通过SciPy,可以方便地执行各种科学计算任务。例如,计算高斯分布的PDF,需要结合NumPy使用。要安装SciPy,可以使用`pip install scipy`命令。这个库极大地丰富了Python在科学计算领域的应用。
12 1