JavaScript机器学习之KNN算法

简介: 译者按 机器学习原来很简单啊,不妨动手试试!原文: Machine Learning with JavaScript : Part 2译者: Fundebug为了保证可读性,本文采用意译而非直译。

译者按 机器学习原来很简单啊,不妨动手试试!

原文: Machine Learning with JavaScript : Part 2

译者: Fundebug

为了保证可读性,本文采用意译而非直译。另外,本文版权归原作者所有,翻译仅用于学习。另外,我们修正了原文代码中的错误

img_d575a0346f2034f12b20c9daab561852.png

上图使用plot.ly所画。

上次我们用JavaScript实现了线性规划,这次我们来聊聊KNN算法。

KNN是****k-Nearest-Neighbours****的缩写,它是一种监督学习算法。KNN算法可以用来做分类,也可以用来解决回归问题。

GitHub仓库: machine-learning-with-js

KNN算法简介

简单地说,KNN算法由那离自己最近的K个点来投票决定待分类数据归为哪一类

如果待分类的数据有这些邻近数据,NY: ****7****, NJ: ****0****, IN: ****4****,即它有7个****NY****邻居,0个****NJ****邻居,4个****IN****邻居,则这个数据应该归类为****NY****。

假设你在邮局工作,你的任务是为邮递员分配信件,目标是最小化到各个社区的投递旅程。不妨假设一共有7个街区。这就是一个实际的分类问题。你需要将这些信件分类,决定它属于哪个社区,比如上东城曼哈顿下城等。

最坏的方案是随意分配信件分配给邮递员,这样每个邮递员会拿到各个社区的信件。

最佳的方案是根据信件地址进行分类,这样每个邮递员只需要负责邻近社区的信件。

也许你是这样想的:"将邻近3个街区的信件分配给同一个邮递员"。这时,邻近街区的个数就是****k****。你可以不断增加****k****,直到获得最佳的分配方案。这个****k****就是分类问题的最佳值。

KNN代码实现

上次一样,我们将使用mljs的****KNN****模块ml-knn来实现。

每一个机器学习算法都需要数据,这次我将使用****IRIS数据集****。其数据集包含了150个样本,都属于鸢尾属下的三个亚属,分别是山鸢尾变色鸢尾维吉尼亚鸢尾。四个特征被用作样本的定量分析,它们分别是花萼花瓣的长度和宽度。

1. 安装模块

$ npm install ml-knn@2.0.0 csvtojson prompt

ml-knn: ****k-Nearest-Neighbours****模块,不同版本的接口可能不同,这篇博客使用了2.0.0

csvtojson: 用于将CSV数据转换为JSON

prompt: 在控制台输入输出数据

2. 初始化并导入数据

IRIS数据集由加州大学欧文分校提供。

curl https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data > iris.csv

假设你已经初始化了一个NPM项目,请在****index.js****中输入以下内容:

const KNN = require('ml-knn');
const csv = require('csvtojson');
const prompt = require('prompt');

var knn;

const csvFilePath = 'iris.csv'; // 数据集
const names = ['sepalLength', 'sepalWidth', 'petalLength', 'petalWidth', 'type'];

let seperationSize; // 分割训练和测试数据

let data = [],
    X = [],
    y = [];

let trainingSetX = [],
    trainingSetY = [],
    testSetX = [],
    testSetY = [];
  • ****seperationSize****用于分割数据和测试数据

使用csvtojson模块的fromFile方法加载数据:

csv(
    {
        noheader: true,
        headers: names
    })
    .fromFile(csvFilePath)
    .on('json', (jsonObj) =>
    {
        data.push(jsonObj); // 将数据集转换为JS对象数组
    })
    .on('done', (error) =>
    {
        seperationSize = 0.7 * data.length;
        data = shuffleArray(data);
        dressData();
    });

我们将****seperationSize****设为样本数目的0.7倍。注意,如果训练数据集太小的话,分类效果将变差。

由于数据集是根据种类排序的,所以需要使用****shuffleArray****函数对数据进行混淆,这样才能方便分割出训练数据。这个函数的定义请参考StackOverflow的提问How to randomize (shuffle) a JavaScript array?:

function shuffleArray(array)
{
    for (var i = array.length - 1; i > 0; i--)
    {
        var j = Math.floor(Math.random() * (i + 1));
        var temp = array[i];
        array[i] = array[j];
        array[j] = temp;
    }
    return array;
}

3. 转换数据

数据集中每一条数据可以转换为一个JS对象:

{
 sepalLength: ‘5.1’,
 sepalWidth: ‘3.5’,
 petalLength: ‘1.4’,
 petalWidth: ‘0.2’,
 type: ‘Iris-setosa’ 
}

在使用****KNN****算法训练数据之前,需要对数据进行这些处理:

  1. 将属性(sepalLength, sepalWidth,petalLength,petalWidth)由字符串转换为浮点数. (****parseFloat****)
  2. 将分类 (type)用数字表示
function dressData()
{
    let types = new Set(); 
    data.forEach((row) =>
    {
        types.add(row.type);
    });
    let typesArray = [...types]; 

    data.forEach((row) =>
    {
        let rowArray, typeNumber;
        rowArray = Object.keys(row).map(key => parseFloat(row[key])).slice(0, 4);
        typeNumber = typesArray.indexOf(row.type); // Convert type(String) to type(Number)

        X.push(rowArray);
        y.push(typeNumber);
    });

    trainingSetX = X.slice(0, seperationSize);
    trainingSetY = y.slice(0, seperationSize);
    testSetX = X.slice(seperationSize);
    testSetY = y.slice(seperationSize);

    train();
}

4. 训练数据并测试

function train()
{
    knn = new KNN(trainingSetX, trainingSetY,
    {
        k: 7
    });
    test();
}

****train****方法需要2个必须的参数: 输入数据,即花萼花瓣的长度和宽度;实际分类,即山鸢尾变色鸢尾维吉尼亚鸢尾。另外,第三个参数是可选的,用于提供调整****KNN****算法的内部参数。我将****k****参数设为7,其默认值为5。

训练好模型之后,就可以使用测试数据来检查准确性了。我们主要对预测出错的个数比较感兴趣。

function test()
{
    const result = knn.predict(testSetX);
    const testSetLength = testSetX.length;
    const predictionError = error(result, testSetY);
    console.log(`Test Set Size = ${testSetLength} and number of Misclassifications = ${predictionError}`);
    predict();
}

比较预测值与真实值,就可以得到出错个数:

function error(predicted, expected)
{
    let misclassifications = 0;
    for (var index = 0; index < predicted.length; index++)
    {
        if (predicted[index] !== expected[index])
        {
            misclassifications++;
        }
    }
    return misclassifications;
}

5. 进行预测(可选)

任意输入属性值,就可以得到预测值

function predict()
{
    let temp = [];
    prompt.start();
    prompt.get(['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'], function(err, result)
    {
        if (!err)
        {
            for (var key in result)
            {
                temp.push(parseFloat(result[key]));
            }
            console.log(`With ${temp} -- type =  ${knn.predict(temp)}`);
        }
    });
}

6. 完整程序

完整的程序****index.js****是这样的:

const KNN = require('ml-knn');
const csv = require('csvtojson');
const prompt = require('prompt');

var knn;

const csvFilePath = 'iris.csv'; // 数据集
const names = ['sepalLength', 'sepalWidth', 'petalLength', 'petalWidth', 'type'];

let seperationSize; // 分割训练和测试数据

let data = [],
    X = [],
    y = [];

let trainingSetX = [],
    trainingSetY = [],
    testSetX = [],
    testSetY = [];

csv(
    {
        noheader: true,
        headers: names
    })
    .fromFile(csvFilePath)
    .on('json', (jsonObj) =>
    {
        data.push(jsonObj); // 将数据集转换为JS对象数组
    })
    .on('done', (error) =>
    {
        seperationSize = 0.7 * data.length;
        data = shuffleArray(data);
        dressData();
    });

function dressData()
{
    let types = new Set(); 
    data.forEach((row) =>
    {
        types.add(row.type);
    });
    let typesArray = [...types]; 

    data.forEach((row) =>
    {
        let rowArray, typeNumber;
        rowArray = Object.keys(row).map(key => parseFloat(row[key])).slice(0, 4);
        typeNumber = typesArray.indexOf(row.type); // Convert type(String) to type(Number)

        X.push(rowArray);
        y.push(typeNumber);
    });

    trainingSetX = X.slice(0, seperationSize);
    trainingSetY = y.slice(0, seperationSize);
    testSetX = X.slice(seperationSize);
    testSetY = y.slice(seperationSize);

    train();
}

// 使用KNN算法训练数据
function train()
{
    knn = new KNN(trainingSetX, trainingSetY,
    {
        k: 7
    });
    test();
}

// 测试训练的模型
function test()
{
    const result = knn.predict(testSetX);
    const testSetLength = testSetX.length;
    const predictionError = error(result, testSetY);
    console.log(`Test Set Size = ${testSetLength} and number of Misclassifications = ${predictionError}`);
    predict();
}

// 计算出错个数
function error(predicted, expected)
{
    let misclassifications = 0;
    for (var index = 0; index < predicted.length; index++)
    {
        if (predicted[index] !== expected[index])
        {
            misclassifications++;
        }
    }
    return misclassifications;
}

// 根据输入预测结果
function predict()
{
    let temp = [];
    prompt.start();
    prompt.get(['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'], function(err, result)
    {
        if (!err)
        {
            for (var key in result)
            {
                temp.push(parseFloat(result[key]));
            }
            console.log(`With ${temp} -- type =  ${knn.predict(temp)}`);
        }
    });
}

// 混淆数据集的顺序
function shuffleArray(array)
{
    for (var i = array.length - 1; i > 0; i--)
    {
        var j = Math.floor(Math.random() * (i + 1));
        var temp = array[i];
        array[i] = array[j];
        array[j] = temp;
    }
    return array;
}

在控制台执行****node index.js****

$ node index.js

输出如下:

Test Set Size = 45 and number of Misclassifications = 2
prompt: Sepal Length:  1.7
prompt: Sepal Width:  2.5
prompt: Petal Length:  0.5
prompt: Petal Width:  3.4
With 1.7,2.5,0.5,3.4 -- type =  2

参考链接

欢迎加入我们Fundebug全栈BUG监控交流群: 622902485

img_ea207c7ffd843f2e46ed4001365c97a8.jpe

版权声明:
转载时请注明作者Fundebug以及本文地址:**
https://blog.fundebug.com/2017/07/10/javascript-machine-learning-knn/**

目录
相关文章
|
22天前
|
机器学习/深度学习 算法 搜索推荐
Machine Learning机器学习之决策树算法 Decision Tree(附Python代码)
Machine Learning机器学习之决策树算法 Decision Tree(附Python代码)
|
15天前
|
机器学习/深度学习 自然语言处理 算法
|
2天前
|
机器学习/深度学习 算法 搜索推荐
Python用机器学习算法进行因果推断与增量、增益模型Uplift Modeling智能营销模型
Python用机器学习算法进行因果推断与增量、增益模型Uplift Modeling智能营销模型
30 12
|
1月前
|
机器学习/深度学习 分布式计算 算法
大模型开发:你如何确定使用哪种机器学习算法?
在大型机器学习模型开发中,选择算法是关键。首先,明确问题类型(如回归、分类、聚类等)。其次,考虑数据规模、特征数量和类型、分布和结构,以判断适合的算法。再者,评估性能要求(准确性、速度、可解释性)和资源限制(计算资源、内存)。同时,利用领域知识和正则化来选择模型。最后,通过实验验证和模型比较进行优化。此过程涉及迭代和业务需求的技术权衡。
|
1月前
|
机器学习/深度学习 数据采集 算法
构建高效机器学习模型:从数据处理到算法优化
【2月更文挑战第30天】 在数据驱动的时代,构建一个高效的机器学习模型是实现智能决策和预测的关键。本文将深入探讨如何通过有效的数据处理策略、合理的特征工程、选择适宜的学习算法以及进行细致的参数调优来提升模型性能。我们将剖析标准化与归一化的差异,探索主成分分析(PCA)的降维魔力,讨论支持向量机(SVM)和随机森林等算法的适用场景,并最终通过网格搜索(GridSearchCV)来实现参数的最优化。本文旨在为读者提供一条清晰的路径,以应对机器学习项目中的挑战,从而在实际应用中取得更精准的预测结果和更强的泛化能力。
|
1月前
|
机器学习/深度学习 自然语言处理 算法
【机器学习】包裹式特征选择之拉斯维加斯包装器(LVW)算法
【机器学习】包裹式特征选择之拉斯维加斯包装器(LVW)算法
57 0
|
1月前
|
机器学习/深度学习 存储 搜索推荐
利用机器学习算法改善电商推荐系统的效率
电商行业日益竞争激烈,提升用户体验成为关键。本文将探讨如何利用机器学习算法优化电商推荐系统,通过分析用户行为数据和商品信息,实现个性化推荐,从而提高推荐效率和准确性。
|
1月前
|
机器学习/深度学习 算法 数据可视化
实现机器学习算法时,特征选择是非常重要的一步,你有哪些推荐的方法?
实现机器学习算法时,特征选择是非常重要的一步,你有哪些推荐的方法?
27 1
|
1月前
|
机器学习/深度学习 数据采集 算法
解码癌症预测的密码:可解释性机器学习算法SHAP揭示XGBoost模型的预测机制
解码癌症预测的密码:可解释性机器学习算法SHAP揭示XGBoost模型的预测机制
112 0
|
1月前
|
机器学习/深度学习 数据采集 监控
机器学习-特征选择:如何使用递归特征消除算法自动筛选出最优特征?
机器学习-特征选择:如何使用递归特征消除算法自动筛选出最优特征?
71 0