本文主要介绍K-means聚类模型原理及实践demo。
一、原理
K-means聚类是一种经典的、广泛使用的无监督学习算法,主要用于将数据集划分为多个类别或“簇”。其目标是将数据集中的每个点分配到K个聚类中心之一,使得簇内的点尽可能相似,而簇间的点尽可能不同。
K-means算法的基本步骤:
- 初始化:选择K个数据点作为初始聚类中心(质心)。
- 分配:将每个点分配到最近的聚类中心,形成K个簇。
- 更新:重新计算每个簇的聚类中心,通常是簇内所有点的均值。
- 迭代:重复步骤2和3,直到满足停止条件,如质心的变化小于某个阈值或达到预设的迭代次数。
K-means算法的关键点:
- K的选择:K的选择通常是基于经验或使用如肘部法则(Elbow Method)等方法确定的。
- 初始化方法:可以随机选择,也可以使用如K-means++等更高级的方法以提高性能。
- 收敛性:K-means算法在局部最优上是收敛的,可能不会找到全局最优解,因此可能需要多次运行以获得最佳结果。
- 性能度量:使用如轮廓系数(Silhouette Coefficient)等指标来评估聚类效果。
K-means算法的优缺点:
优点:
- 简单、直观,易于实现和理解。
- 训练速度快,适合处理大型数据集。
- 对于球形簇表现良好。
缺点:
- 对初始聚类中心敏感,可能导致局部最优解。
- 需要预先指定K值,但K值的选择通常不是显而易见的。
- 对噪声和异常值敏感。
- 只能发现球形簇,对于非球形簇可能效果不佳。
K-means聚类模型的应用场景:
- 图像分割
- 市场细分
- 异常检测
- 数据压缩
- 特征提取
K-means聚类是一种强大的工具,但需要根据具体问题和数据特性来适当使用。在实际应用中,可能需要与其他聚类算法或预处理步骤结合使用,以获得最佳效果。
二、举个栗子
使用scikit-learn中的内置数据集Iris来进行聚类。
预期效果
核心代码
# 导入必要的库
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
# 加载Iris数据集
iris = load_iris()
X = iris.data
# 选择要使用的聚类数目,这里我们选择3个聚类
k = 3
# 初始化KMeans对象
kmeans = KMeans(n_clusters=k, random_state=42)
# 执行KMeans聚类
kmeans.fit(X)
# 输出聚类中心
centroids = kmeans.cluster_centers_
# 输出每个数据点的聚类标签
labels = kmeans.labels_
# 可视化聚类结果(这里我们取前两个特征进行可视化,因为它们是二维的)
plt.scatter(X[:, 0], X[:, 1], c=labels, s=50, cmap='viridis')
plt.scatter(centroids[:, 0], centroids[:, 1], c='red', s=200, alpha=0.75, marker='X')
plt.title('K-means Clustering of Iris Dataset')
plt.xlabel('Sepal Length')
plt.ylabel('Sepal Width')
plt.show()
Iris数据集是一个非常著名且被广泛使用的多变量数据集,用于测试统计算法和机器学习模型,如分类、聚类和回归。这个数据集包含了150个样本,每个样本有4个特征,这些特征描述了鸢尾花(Iris)的三个不同属(setosa, versicolor, virginica)的度量(测量)。
具体来说,Iris数据集的每个样本包括以下特征:
- 花萼长度(Sepal Length):花萼的最大长度,单位通常是厘米。
- 花萼宽度(Sepal Width):花萼的宽度,单位是厘米。
- 花瓣长度(Petal Length):花瓣的最大长度,单位是厘米。
- 花瓣宽度(Petal Width):花瓣的宽度,单位是厘米。
这些特征的测量值是浮点数,范围大致如下:
- 花萼长度:4.3cm至7.9cm
- 花萼宽度:2.0cm至4.4cm
- 花瓣长度:1.0cm至6.9cm
- 花瓣宽度:0.1cm至2.5cm
除了这些特征外,Iris数据集还包含了每个样本对应的真实类别标签,这使得它成为监督学习算法的绝佳数据集。然而,由于K-means是一种无监督学习算法,它不使用这些标签信息,而是试图根据数据的特征发现数据的内在结构。
Iris数据集由于其简单性、易于理解性以及包含有限数量的类别和特征,常被用作教学和算法测试的基准。它允许研究人员和学生在没有大量数据预处理的情况下,快速地测试和比较不同算法的性能。
三、自定义实例
使用自定义的Excel文档作为数据集进行K-means聚类
预期效果
核心代码
# 导入必要的库
import pandas as pd
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
# 加载CSV数据集
# 假设CSV文件有两列,分别是Sepal Length和Sepal Width
# 请根据你的CSV文件的实际列名进行调整
df = pd.read_csv('demoDB.csv')
X = df.values
# 选择要使用的聚类数目,这里我们选择3个聚类
k = 3
# 初始化KMeans对象
kmeans = KMeans(n_clusters=k, random_state=42)
# 执行KMeans聚类
kmeans.fit(X)
# 输出聚类中心
centroids = kmeans.cluster_centers_
# 输出每个数据点的聚类标签
labels = kmeans.labels_
# 可视化聚类结果(这里我们取前两个特征进行可视化)
plt.scatter(X[:, 0], X[:, 1], c=labels, s=50, cmap='viridis')
plt.scatter(centroids[:, 0], centroids[:, 1], c='red', s=200, alpha=0.75, marker='X')
plt.title('K-means Clustering of Custom Dataset')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.show()
数据源
demoDB.csv
解决方案
K-means聚类算法可以应用于生活中的许多实际问题,尤其是在需要将数据分组或分类,但又没有明确分组标签的情况下。以下是一些例子,展示了如何使用K-means聚类算法解决实际问题:
1. 市场细分
企业经常使用K-means聚类来对客户进行细分,以便更好地了解他们的行为和偏好。通过分析客户的购买历史、年龄、性别和收入等特征,K-means可以帮助企业识别不同的客户群体,并为每个群体定制营销策略。
2. 社交网络分析
在社交网络分析中,K-means可以用来识别社区结构,即在社交网络中分组紧密连接的用户。通过分析用户的互动、兴趣和行为,K-means可以揭示社交网络中的不同社区。
3. 基因表达分析
在生物信息学中,K-means聚类可以用于基因表达数据的分析,以识别具有相似表达模式的基因。这有助于理解不同基因的功能和它们在疾病中的作用。
4. 图像压缩
K-means聚类可以用于图像压缩技术,如颜色量化。通过将图像的颜色聚类为几个代表颜色,K-means可以减少图像文件的大小,同时尽量保持其视觉质量。
5. 异常检测
在许多领域,如金融交易、网络安全或工业系统监控中,K-means可以用来检测异常或欺诈行为。通过分析正常行为的模式,K-means可以识别那些不符合常规模式的异常点。
应用实例:市场细分
假设我们想要使用K-means聚类算法对客户进行细分。以下是基于前面提供的代码模板,针对市场细分问题的示例:
# 导入必要的库
import pandas as pd
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
# 加载CSV数据集
# 假设CSV文件包含了客户的年龄、收入和购买频率等特征
df = pd.read_csv('customer_data.csv')
X = df.values # 假设所有列都是数值型特征
# 选择要使用的聚类数目,这里我们选择3个聚类,根据业务需求调整
k = 3
# 初始化KMeans对象
kmeans = KMeans(n_clusters=k, random_state=42)
# 执行KMeans聚类
kmeans.fit(X)
# 输出聚类中心
centroids = kmeans.cluster_centers_
# 输出每个数据点的聚类标签
labels = kmeans.labels_
# 可视化聚类结果,这里我们取年龄和收入进行可视化
plt.figure(figsize=(10, 6))
plt.scatter(X[:, 0], X[:, 1], c=labels, s=50, cmap='viridis') # 假设第0列是年龄,第1列是收入
plt.scatter(centroids[:, 0], centroids[:, 1], c='red', s=200, alpha=0.75, marker='X')
plt.title('K-means Clustering for Customer Segmentation')
plt.xlabel('Age')
plt.ylabel('Income')
plt.show()