阿旭机器学习实战【3】KNN算法进行年收入预测

简介: 阿旭机器学习实战【3】KNN算法进行年收入预测

问题描述


使用KNN算法训练模型,然后使用模型预测一个人的年收入是否大于50。


读取数据集并查看数据


# 导入相应库
import pandas as pd
from pandas import Series,DataFrame
import numpy as np
df = pd.read_csv("./adults.txt")
df.head()


image.png


该数据集包含14个特征:分别为age ;workclass ;final_weight ;education ;education_num ;marital_status ;occupation ;relationship ;race ;sex ;capital_gain ;capital_loss ;hours_per_week ;native_country


其中数据集最后一列:salary表示这个人的年收入


特征工程


分割特征与标签


# 特征数据
data = df.iloc[:,:-1].copy()
data.head()


image.png


# 标签数据
target = df[["salary"]].copy()
target.head()

image.png


对非数值特征进行量化


由于KNN算法只能对数值类型的值进行计算,因此需要对非数值特征进行量化处理


把字符串类型的特征属性进行量化


对workclass职业这一特征进行量化


# 查看总共有多少个职业
ws = data.workclass.unique()
ws
array(['State-gov', 'Self-emp-not-inc', 'Private', 'Federal-gov',
       'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked'],
      dtype=object)


可以看出总共有9类职业:包括未知的“?”。下面我们使用0-8这9个数字,分别对9种职业进行编码


# 定义转化函数
def convert_ws(item):
    # np.argwhere函数会返回,相应职业对应的索引
    return np.argwhere(ws==item)[0,0]
# 将职业转化为职业列表中索引值
data.workclass = data.workclass.map(convert_ws)
# 查看职业转化后的数据
data.head()


image.png


np.argwhere函数会返回相应职业对应的索引, np.argwhere(ws==“?”)[0,0],返回值为5


对其他字符串特征属性进行量化


与上述职业量化过程相同


# 需要进行量化的属性
cols = ['education',"marital_status","occupation","relationship","race","sex","native_country"]
# 使用遍历的方式对各列属性进行量化
def convert_item(item):
    return np.argwhere(uni == item)[0,0]
for col in cols:
    uni = data[col].unique()
    data[col] = data[col].map(convert_item)
# 查看对所有列进行量化后的数据
data.head()


image.png


建模与评估


好了,以上我们已经将所有特征进行了量化处理,下面就可以使用KNN算法进行建模了


from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
# 创建模型
knn = KNeighborsClassifier(n_neighbors=8)
# 划分训练集与测试集
x_train,x_test,y_train,y_test = train_test_split(data,target,test_size=0.01)
# 对模型进行训练
knn.fit(x_train,y_train)
# 使用测试集查看模型的准确度
knn.score(x_test,y_test)


0.7822085889570553
• 1


模型优化


我们可以看到,如果不对上述所有的特征数值进行处理,直接使用KNN模型进行训练的话,模型的准确率仅为78%


下面我们对特征数据进行归一化处理,然后再使用KNN模型进行建模与测试,看看结果如何。

# 把所有的数据归一化
# 创建归一化函数
def func(x):
    return (x-min(x))/(max(x)-min(x))
# 对特征数据进行归一化处理
data[data.columns] = data[data.columns].transform(func)
data.head()

image.png


# 划分训练集与测试集
x_train,x_test,y_train,y_test = train_test_split(data,target,test_size=0.01)
# 创建模型
knn = KNeighborsClassifier(n_neighbors=8)
# 训练模型
knn.fit(x_train,y_train)
# 使用测试集查看模型的准确度
knn.score(x_test,y_test)


0.8374233128834356
• 1


我们可以发现,将所有数据进行归一化处理后,准确率从78%提升到了84%,还是比较不错的。


当然还有一些其他的处理方式对模型进行优化,后续博文会持续更新,欢迎关注。


总结


这篇文章主要介绍了以下几点内容:


  1. 如何对字符串类型的数据进行量化处理
  2. 使用KNN模型对人的年收入进行预测
  3. 模型优化:对数据进行归一化处理之后,有利于提高模型准确度。


相关文章
|
29天前
|
机器学习/深度学习 算法 数据挖掘
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构。本文介绍了K-means算法的基本原理,包括初始化、数据点分配与簇中心更新等步骤,以及如何在Python中实现该算法,最后讨论了其优缺点及应用场景。
95 4
|
8天前
|
算法
PAI下面的gbdt、xgboost、ps-smart 算法如何优化?
设置gbdt 、xgboost等算法的样本和特征的采样率
22 2
|
25天前
|
机器学习/深度学习 算法 数据挖掘
C语言在机器学习中的应用及其重要性。C语言以其高效性、灵活性和可移植性,适合开发高性能的机器学习算法,尤其在底层算法实现、嵌入式系统和高性能计算中表现突出
本文探讨了C语言在机器学习中的应用及其重要性。C语言以其高效性、灵活性和可移植性,适合开发高性能的机器学习算法,尤其在底层算法实现、嵌入式系统和高性能计算中表现突出。文章还介绍了C语言在知名机器学习库中的作用,以及与Python等语言结合使用的案例,展望了其未来发展的挑战与机遇。
43 1
|
1月前
|
机器学习/深度学习 自然语言处理 算法
深入理解机器学习算法:从线性回归到神经网络
深入理解机器学习算法:从线性回归到神经网络
|
1月前
|
机器学习/深度学习 算法
深入探索机器学习中的决策树算法
深入探索机器学习中的决策树算法
41 0
|
1天前
|
机器学习/深度学习 算法
基于改进遗传优化的BP神经网络金融序列预测算法matlab仿真
本项目基于改进遗传优化的BP神经网络进行金融序列预测,使用MATLAB2022A实现。通过对比BP神经网络、遗传优化BP神经网络及改进遗传优化BP神经网络,展示了三者的误差和预测曲线差异。核心程序结合遗传算法(GA)与BP神经网络,利用GA优化BP网络的初始权重和阈值,提高预测精度。GA通过选择、交叉、变异操作迭代优化,防止局部收敛,增强模型对金融市场复杂性和不确定性的适应能力。
102 80
|
20天前
|
算法
基于WOA算法的SVDD参数寻优matlab仿真
该程序利用鲸鱼优化算法(WOA)对支持向量数据描述(SVDD)模型的参数进行优化,以提高数据分类的准确性。通过MATLAB2022A实现,展示了不同信噪比(SNR)下模型的分类误差。WOA通过模拟鲸鱼捕食行为,动态调整SVDD参数,如惩罚因子C和核函数参数γ,以寻找最优参数组合,增强模型的鲁棒性和泛化能力。
|
26天前
|
机器学习/深度学习 算法 Serverless
基于WOA-SVM的乳腺癌数据分类识别算法matlab仿真,对比BP神经网络和SVM
本项目利用鲸鱼优化算法(WOA)优化支持向量机(SVM)参数,针对乳腺癌早期诊断问题,通过MATLAB 2022a实现。核心代码包括参数初始化、目标函数计算、位置更新等步骤,并附有详细中文注释及操作视频。实验结果显示,WOA-SVM在提高分类精度和泛化能力方面表现出色,为乳腺癌的早期诊断提供了有效的技术支持。
|
6天前
|
供应链 算法 调度
排队算法的matlab仿真,带GUI界面
该程序使用MATLAB 2022A版本实现排队算法的仿真,并带有GUI界面。程序支持单队列单服务台、单队列多服务台和多队列多服务台三种排队方式。核心函数`func_mms2`通过模拟到达时间和服务时间,计算阻塞率和利用率。排队论研究系统中顾客和服务台的交互行为,广泛应用于通信网络、生产调度和服务行业等领域,旨在优化系统性能,减少等待时间,提高资源利用率。
|
14天前
|
存储 算法
基于HMM隐马尔可夫模型的金融数据预测算法matlab仿真
本项目基于HMM模型实现金融数据预测,包括模型训练与预测两部分。在MATLAB2022A上运行,通过计算状态转移和观测概率预测未来值,并绘制了预测值、真实值及预测误差的对比图。HMM模型适用于金融市场的时间序列分析,能够有效捕捉隐藏状态及其转换规律,为金融预测提供有力工具。