机器学习之朴素贝叶斯分类

简介: 贝叶斯概率在机器学习、自然语言处理中被广泛地应用,对于海量数据的文本分类问题(比如垃圾邮件的甄选和过滤),基于贝叶思的算法取得非常好的效果。

贝叶斯概率在机器学习、自然语言处理中被广泛地应用,对于海量数据的文本分类问题(比如垃圾邮件的甄选和过滤),基于贝叶思的算法取得非常好的效果。

一、概率基础

概率:

概率是某一事件或者预测行为的可信程度。取值在0-1之间。

比如,抛一枚硬币,正面朝上的可能性和反面朝上的肯能性是相等的,都是0.5.

条件概率:

条件概率是指在某些前提条件的概率问题。

比如,根据美国疾病控制中心美国每年大约有78.5万人罹患心脏病,美国约有3.11亿人,那么随机挑选一个美国人,预测其在明年的心脏病发病率大约为78.5万/3.11亿,大约时0.3%。但是就个体而言,年轻孩子和患有高胆固醇的中年人患病的风险不一样,年龄、个人健康状况就是前提条件。

联合概率:

表示两个事件同时发生的概率。事件A和事件B同时发生的概率记做P(AB)或者P(A,B),或者P(A∩B)

贝叶斯概率:

显然P(AB)=P(BA),AB同时发生的概率等于事件A发生的概率乘以事件A发生的前提下事件B发生的概率

P(AB)=P(A)(B|A)

AB同时发生的概率也等于事件B发生的概率乘以事件B发生的前提下事件A发生的概率
P(AB)=P(B)P(A|B)

所以:

P(B)P(A|B)=P(A)(B|A)

等式两边同时除以P(B)

P(A|B)=P(B|A)P(A)P(B)

通过公式可以通俗地理解贝叶斯概率的思想分母P(B)是指事件B所发生的概率,也就是总概率,分子是AB同时发生的概率,用AB同事发生的概率除以B发生的概率就可以得出B发生是来自事件A的概率.

二、朴素贝叶斯分类

2.1贝叶斯决策理论

贝叶斯决策理论的核心思想是选择高概率对的类别。比如有一个数据集由2类数据组成,数据分布如下图:
这里写图片描述
假设数据集的统计参数已知,用p1(x,y)表示点(x,y)属于类别1,用p2(x,y)表示点(x,y)属于类别2,那么对于一个新的数据点(x,y),如果p1(x,y)>p2(x,y)那么点属于类别1,反之属于类别2.

2.2实战题目

题目:
给定一个训练集和测试集,trainData.txttestData.txt,每一行的第一列代表分类C,第2、3、4、5、6列分别对应特征A1、A2、A3、A4、A5,特征A1到A5相互独立。根据训练集的分类训练分类器,验证测试集中分类的正确率。

2.3分析与思考

对于测试集中的每一条记录,先根据特征分别计算P(0|A)和P(1|A),A是一个5维向量,分别代表A1到A5的取值。那么根据贝叶斯公式:

P(C=0|A)=P(A|C=0)P(C=0)P(A),P(C=1|A)=P(A|C=1)P(C=1)P(A)

因为只需要比较P(0|A)和P(1|A)哪个概率大,也就是由向量A的特征决定的分类是0的概率大还是1的概率大,所以分母中的P(A)无需计算,只需要计算
P(0|A)=P(A|0)P(0),P(1|A)=P(A|1)P(1)
然后比较P(0|A)和P(1|A)的概率大小。计算P(0)和P(1)很容易计算,分别统计训练集中分类为0和分类为1的记录所占的百分比即可,也就是trainData.txt第一列中为0的总数和第一列为1的总数所占的比例。然后需要求P(A|0)和P(A|1),因为特征独立:
P(A|C=0)=P((A1,A2,A3,A4,A5)|C=0)=i=15P(Ai|C=0)

同样:

P(A|C=1)=P((A1,A2,A3,A4,A5)|C=1)=i=15P(Ai|C=1)

然后分别计算
P(Ai|C),C01Ai15,P(A0|C=0),trainData.txt00A0=0
为P(A_0=0|C=0),A_0=1所占的比例即为P(A_0=1|C=0)$$.

2.4参考代码

package big.data.ml;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;

public class Bayes {
    private static int Count = 0; // 总的训练样本数
    private static int Count0 = 0; // 训练样本中分类为0的个数
    private static int Count1 = 0; // 训练样本中分类为1的个数
    private static double[] pc = new double[2];
    private static int[][] matrixA = new int[32][6];
    private static File traningData = new File("data/trainingData.txt");
    private static File testingData = new File("data/testingData.txt");
    private static int[][] trainData, testData;

    public static int[][] getTrainData() {
        return trainData;
    }

    public static void setTrainData(int[][] trainData) {
        Bayes.trainData = trainData;
    }

    public static int[][] getTestData() {
        return testData;
    }

    public static void setTestData(int[][] testData) {
        Bayes.testData = testData;
    }

    public static void main(String[] args) {

        int[][] trainData, testData;
        trainData = getData(traningData, 0);
        testData = getData(testingData, 1);
        setTrainData(trainData);
        setTestData(testData);
        for (int i = 0; i < trainData.length; i++) {
            Count++;
            if (trainData[i][0] == 0) {
                Count0++;
            } else if (trainData[i][0] == 1) {
                Count1++;
            }

        }
        pc[0] = Count0 / (double) Count;
        pc[1] = Count1 / (double) Count;
        System.out.println("训练样本中分类为0的概率" + pc[0]);
        System.out.println("训练样本中分类为1的概率" + pc[1]);
        int[] Ai = new int[5];

        int rightClassify = 0;
        double pc0 = 1, pc1 = 1;
        for (int i = 0; i < testData.length; i++) {
            pc0 = pc[0];
            pc1 = pc[1];
            System.arraycopy(testData[i], 1, Ai, 0, 5);
            for (int j = 0; j < Ai.length; j++) {

                pc0 *= calPai(Ai[j], j + 1, 0);
                pc1 *= calPai(Ai[j], j + 1, 1);

            }

            if (pc0 > pc1 && testData[i][0] == 0) {
                rightClassify++;
            } else if (pc0 < pc1 && testData[i][0] == 1) {
                rightClassify++;
            }
        }

        for (int index = 0; index < 5; index++) {
            for (int i = 0; i < 2; i++) {
                for (int c = 0; c < 2; c++) {
                    System.out.print("P(A"+index+"=" + i + "|c=" + c + ")" + "="
                            + (double) Math.round(calPai(i, index + 1, c) * 100) / 100 + "    ");
                }

            }
            System.out.println("\n");
        }

        System.out.println("正确率:" + rightClassify / (double) testData.length);
        System.out.println("错误率:" + (1 - rightClassify / (double) testData.length));

    }

    // 训练集和测试集转化为二维数组
    public static int[][] getData(File file, int c) {
        int[][] dataArr;
        if (c == 0) {
            dataArr = new int[100][6];
        } else if (c == 1) {
            dataArr = new int[200][6];
        } else {
            System.out.println("参数错误!");
            return null;
        }

        int[] a;
        int m = 0;
        if (file.exists()) {
            System.out.println("成功加载数据!");
            try {
                FileInputStream fis = new FileInputStream(file);
                InputStreamReader isr = new InputStreamReader(fis);
                BufferedReader bfr = new BufferedReader(isr);
                String line = "";
                while ((line = bfr.readLine()) != null) {
                    String[] arr = line.split("\\s+");
                    a = new int[6];
                    for (int i = 0; i < a.length; i++) {
                        a[i] = Integer.parseInt(arr[i]);
                    }
                    dataArr[m++] = a;
                }

                return dataArr;
            } catch (FileNotFoundException e) {
                e.printStackTrace();
            } catch (IOException e) {
                e.printStackTrace();
            }
        } else {
            System.out.println("加载训练数据集失败!");

        }
        return null;
    }

    // 打印二维数组
    public static void printMatrix(int[][] A) {
        for (int i = 0; i < A.length; i++) {
            for (int j = 0; j < A[0].length; j++) {
                System.out.print(A[i][j] + " ");
            }
            System.out.println("\n");
        }
    }

    // 训练集上计算p(Ai|c)
    public static double calPai(int Ai, int i, int c) {
        int[][] trainData;
        trainData = getTrainData();

        int countI = 0, allCount = 0;
        for (int j = 0; j < trainData.length; j++) {
            if (trainData[j][0] == c) {
                allCount++;
            }
            if ((trainData[j][0] == c) && (trainData[j][i] == Ai)) {
                countI++;
            }
        }

        return countI / (double) allCount;
    }

}

输出:

成功加载数据!
成功加载数据!
训练样本中分类为0的概率0.34
训练样本中分类为1的概率0.66
P(A0=0|c=0)=0.71    P(A0=0|c=1)=0.29    P(A0=1|c=0)=0.29    P(A0=1|c=1)=0.71    

P(A1=0|c=0)=0.62    P(A1=0|c=1)=0.52    P(A1=1|c=0)=0.38    P(A1=1|c=1)=0.48    

P(A2=0|c=0)=0.65    P(A2=0|c=1)=0.56    P(A2=1|c=0)=0.35    P(A2=1|c=1)=0.44    

P(A3=0|c=0)=0.56    P(A3=0|c=1)=0.53    P(A3=1|c=0)=0.44    P(A3=1|c=1)=0.47    

P(A4=0|c=0)=0.59    P(A4=0|c=1)=0.55    P(A4=1|c=0)=0.41    P(A4=1|c=1)=0.45    

正确率:0.625
错误率:0.375

参考文献

  1. 贝叶斯分类
  2. 《机器学习实战》
  3. 《贝叶斯思维:统计建模的Python学习法》
目录
相关文章
|
2月前
|
机器学习/深度学习
如何用贝叶斯方法来解决机器学习中的分类问题?
【10月更文挑战第5天】如何用贝叶斯方法来解决机器学习中的分类问题?
|
2月前
|
机器学习/深度学习 存储 自然语言处理
【机器学习】基于逻辑回归的分类预测
【机器学习】基于逻辑回归的分类预测
|
2月前
|
机器学习/深度学习 传感器 算法
机器学习入门(一):机器学习分类 | 监督学习 强化学习概念
机器学习入门(一):机器学习分类 | 监督学习 强化学习概念
|
2月前
|
机器学习/深度学习 算法 数据可视化
机器学习的核心功能:分类、回归、聚类与降维
机器学习领域的基本功能类型通常按照学习模式、预测目标和算法适用性来分类。这些类型包括监督学习、无监督学习、半监督学习和强化学习。
47 0
|
2月前
|
机器学习/深度学习 程序员
【机器学习】朴素贝叶斯原理------迅速了解常见概率的计算
【机器学习】朴素贝叶斯原理------迅速了解常见概率的计算
|
4月前
|
机器学习/深度学习 人工智能 算法
【人工智能】机器学习、分类问题和逻辑回归的基本概念、步骤、特点以及多分类问题的处理方法
机器学习是人工智能的一个核心分支,它专注于开发算法,使计算机系统能够自动地从数据中学习并改进其性能,而无需进行明确的编程。这些算法能够识别数据中的模式,并利用这些模式来做出预测或决策。机器学习的主要应用领域包括自然语言处理、计算机视觉、推荐系统、金融预测、医疗诊断等。
83 1
|
4月前
|
机器学习/深度学习 算法
【机器学习】简单解释贝叶斯公式和朴素贝叶斯分类?(面试回答)
简要解释了贝叶斯公式及其在朴素贝叶斯分类算法中的应用,包括算法的基本原理和步骤。
82 1
|
4月前
|
机器学习/深度学习
如何用贝叶斯方法来解决机器学习中的分类问题?
如何用贝叶斯方法来解决机器学习中的分类问题?
|
4月前
|
机器学习/深度学习 数据采集 自然语言处理
【NLP】讯飞英文学术论文分类挑战赛Top10开源多方案–4 机器学习LGB 方案
在讯飞英文学术论文分类挑战赛中使用LightGBM模型进行文本分类的方案,包括数据预处理、特征提取、模型训练及多折交叉验证等步骤,并提供了相关的代码实现。
53 0
|
6月前
|
机器学习/深度学习 算法
机器学习方法分类
【6月更文挑战第14天】机器学习方法分类。
125 2
下一篇
DataWorks