基于线性SVM的CIFAR-10图像集分类

简介: 基于线性SVM的CIFAR-10图像集分类

之前我用了六篇文章来详细介绍了支持向量机SVM的算法理论和模型,链接如下:


1. 线性支持向量机LSVM

2. 对偶支持向量机DSVM

3. 核支持向量机KSVM

4. 软间隔支持向量机

5. 核逻辑回归KLR

6. 支持向量回归SVR


实际上,支持向量机SVM确实是机器学习中一个非常重要也是非常复杂的模型。关于SVM的详细理论和推导,本文不再阐述,读者可以直接阅读上面的六篇文章。


学习完了复杂的理论知识,很多朋友可能非常想通过一个实际的例子,动手编写出一个SVM程序,应用到实际中。那么本文就将带领大家动手写出自己的SVM程序,并且应用到图像的分类问题中。我们将在经典的CIFAR10图像数据集上进行SVM程序验证。


话不多说,正式开始!


1   SVM的基本思想


简单来说,支持向量机SVM就是在特征空间中找到一条最佳的分类超平面,能够让正、负样本距离该超平面的间隔(margin)最大化。


以二维平面为例,确定一条直线对正负样本进行分类,如下图所示:



image.png


很明显,虽然分类线H1、H2、H3都能够将正负样本完全分开,但是毫无疑问H3更好一些。原因是正负样本距离H3都足够远,即间隔「margin」最大。这就是SVM的基本思想:尽量让所有样本距离分类超平面越远越好。

2   线性分类与得分函数


在线性分类器算法中,输入为x,输出为y,令权重系数为W,常数项系数为b。我们定义得分函数s为:


image.png


这是线性分类器的一般形式,得分函数s所属类别值越大,表示预测该类别的概率越大。


以图像识别为例,共有3个类别「cat,dog,ship」。令输入x的特征维度为4「即包含4个像素值」,W的维度是3x4,b的维度是3x1。在W和b确定后,得到各个类别的得分函数s为:

image.png


由上图可知,因为总有3个类别,得分函数s是3x1的向量。其中,cat score=-96.8,dog score=437.9,ship score=61.95。从s的值来说,dog score最高,cat score最低,则预测为狗的概率更大一些。而该图片真实标签是一只猫,显然,从得分函数s上来看,该线性分类器的预测结果是错误的。


通常为了简化计算,我们直接将W和b整合成一个矩阵,同时将x额外增加一个全为1的维度。这样,得分函数s的表达式得到了简化:


image.png


示例图如下:


image.png


3  优化策略与损失函数


通常来说,SVM的优化策略是样本到分类超平面的距离最大化。也就是说尽量让正负样本距离分类超平面有足够宽的间隔,这是基于距离的衡量优化方式。针对上文提到的例子,图片真实标签是一只猫,但是得到的s值却是最低的,显然这不是我们希望看到的。最好的情况应该是cat score最高。这样才能保证预测cat的概率更大。此时,利用SVM的间隔最大化的思想,就要求cat score不仅仅要大于其它类别的s值,而且要达到一定的程度,可以说有个最低阈值。


因此,这种新的SVM优化策略可以这样理解:正确类别对应的得分函数s应该比其它类别的得分函数s大一个阈值 Δ:

image.png

image.png


其中,y
i表示正确的类别,j表示错误类别。从Li的表达式可以看出,只有当syi比sj大超过阈值 Δ 时,Li才为零,否则Li大于零。这种策略类似于距离最大化策略。


举个例子来解释Li的计算过程:例如得分函数s=[-1, 5, 4],y1是真实样本,令Δ=3,则:



image.png


该损失函数由两部分组成:y1与y0,y1与y2。由于y1与y0的差值大于阈值 Δ,则其损失函数为0;虽然y1比y2大,但差值小于阈值 Δ,则计算得到其损失函数为2。总的损失函数即为2。


这类损失函数的表达式一般称作合页损失函数Hinge Loss Function


image.png


显然,只有当s
j - syi + Δ < 0 时,损失函数才为零。


这种合页损失函数的优点是体现了SVM距离最大化的思想;而且,损失函数大于零时,是线性函数,便于梯度下降算法求导。


除了这种线性hinge loss SVM之外,还有squared hinge loss SVM,即采用平方的形式:


image.png


这种squared hinge loss SVM与linear hinge loss SVM相比较,特点是对违背间隔阈值要求的点加重惩罚,违背的越大,惩罚越大。某些实际应用中,squared hinge loss SVM的效果更好一些。具体使用哪个,可以根据实际问题,进行交叉验证再确定。


对于超参数阈值 Δ,一般设置 Δ = 1。因为,权重系数W是可伸缩的,直接影响着得分函数s的大小。所以说,Δ = 1 或 Δ = 10,实际上没有差别,对W的伸缩完全可以抵消掉 Δ 的数值影响。因此,通常把 Δ 设置为1即可。此时的损失函数为:

image.png

image.png

其中,N是训练样本个数,λ 是正则化参数,可调。一般来说,λ 越大,对权重W的惩罚越大;λ 越小,对权重W的惩罚越小。λ 实际上是权衡损失函数第一项和第二项之间的关系:λ 越大,对W的惩罚更大,牺牲正负样本之间的间隔,可能造成欠拟合「underfit」;λ 越小,得到的正负样本间隔更大,但是W数值会变大,可能造成过拟合「overfit」。实际应用中,可通过交叉验证,选择合适的正则化参数λ。


常数项b是否需要正则化?其实一般b是否正则化对模型的影响很小。可以对b进行正则化,也可以选择不。实际应用中,通常只对权重系数W进行正则化。


4  线性SVM实战


首先,简单介绍一下我们将要用到的经典数据集:CIFAR-10。


CIFAR-10数据集由60000张3×32×32的 RGB 彩色图片构成,共10个分类。50000张训练,10000张测试(交叉验证)。这个数据集最大的特点在于将识别迁移到了普适物体,而且应用于多分类,是非常经典和常用的数据集。


image.png


这个数据集网上可以下载,我直接给大家下好了,放在云盘里,需要的自行领取。


链接:

https://pan.baidu.com/s/1iZPwt72j-EpVUbLKgEpYMQ

密码:vy1e


下面的代码是随机选择每种类别下的5张图片并显示:


# Visualize some examples from the dataset.
# We show a few examples of training images from each class.
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
num_classes = len(classes)
samples_per_class = 7
for y, cls in enumerate(classes):
   idxs = np.flatnonzero(y_train == y)
   idxs = np.random.choice(idxs, samples_per_class, replace=False)
   for i, idx in enumerate(idxs):
       plt_idx = i * num_classes + y + 1
       plt.subplot(samples_per_class, num_classes, plt_idx)
       plt.imshow(X_train[idx].astype('uint8'))
       plt.axis('off')
       if i == 0:
           plt.title(cls)
plt.show()

image.png

接下来,就是对SVM计算hinge loss,包含L2正则化,代码如下:


scores = X.dot(W)
correct_class_score = scores[range(num_train), list(y)].reshape(-1,1) # (N,1)
margin = np.maximum(0, scores - correct_class_score + 1)
margin[range(num_train), list(y)] = 0
loss = np.sum(margin) / num_train + 0.5 * reg * np.sum(W * W)


计算W梯度的代码如下:


num_classes = W.shape[1]
inter_mat = np.zeros((num_train, num_classes))
inter_mat[margin > 0] = 1
inter_mat[range(num_train), list(y)] = 0
inter_mat[range(num_train), list(y)] = -np.sum(inter_mat, axis=1)
dW = (X.T).dot(inter_mat)
dW = dW/num_train + reg*W


根据SGD算法,每次迭代后更新W:


W -=  learning_rate * dW


训练过程中,使用交叉验证的方法选择最佳的学习因子 learning_rate 和正则化参数 reg,代码如下:


learning_rates = [1.4e-7, 1.5e-7, 1.6e-7]
regularization_strengths = [8000.0, 9000.0, 10000.0, 11000.0, 18000.0, 19000.0, 20000.0, 21000.0]
results = {}
best_lr = None
best_reg = None
best_val = -1   # The highest validation accuracy that we have seen so far.
best_svm = None # The LinearSVM object that achieved the highest validation rate.
for lr in learning_rates:
   for reg in regularization_strengths:
       svm = LinearSVM()
       loss_history = svm.train(X_train, y_train, learning_rate = lr, reg = reg, num_iters = 2000)
       y_train_pred = svm.predict(X_train)
       accuracy_train = np.mean(y_train_pred == y_train)
       y_val_pred = svm.predict(X_val)
       accuracy_val = np.mean(y_val_pred == y_val)
       if accuracy_val > best_val:
           best_lr = lr
           best_reg = reg
           best_val = accuracy_val
           best_svm = svm
       results[(lr, reg)] = accuracy_train, accuracy_val
       print('lr: %e reg: %e train accuracy: %f val accuracy: %f' %
             (lr, reg, results[(lr, reg)][0], results[(lr, reg)][1]))
print('Best validation accuracy during cross-validation:\nlr = %e, reg = %e, best_val = %f' %
     (best_lr, best_reg, best_val))


训练结束后,选择最佳的学习因子 learning_rate 和正则化参数 reg,在测试图片集上进行验证,代码如下:


# Evaluate the best svm on test set
y_test_pred = best_svm.predict(X_test)
test_accuracy = np.mean(y_test == y_test_pred)
print('linear SVM on raw pixels final test set accuracy: %f' % test_accuracy)


>> linear SVM on raw pixels final test set accuracy: 0.384000


最后,有个比较好玩的操作,我们可以将训练好的权重W可视化:


# Visualize the learned weights for each class.
# Depending on your choice of learning rate and regularization strength, these may
# or may not be nice to look at.
w = best_svm.W[:-1,:] # strip out the bias
w = w.reshape(32, 32, 3, 10)
w_min, w_max = np.min(w), np.max(w)
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
for i in range(10):
   plt.subplot(2, 5, i + 1)
   # Rescale the weights to be between 0 and 255
   wimg = 255.0 * (w[:, :, :, i].squeeze() - w_min) / (w_max - w_min)
   plt.imshow(wimg.astype('uint8'))
   plt.axis('off')
   plt.title(classes[i])


image.png

5  总结


本文讲述的线性SVM利用距离间隔最大的思想,利用hinge loss的优化策略,来构建一个机器学习模型,并将这个简单模型应用到CIFAR-10图片集中进行训练和测试。实际测试的准确率在40%左右。准确率虽然不是很高,但是此SVM是线性模型,没有引入核函数构建非线性模型,也没有使用AlexNet,VGG,GoogLeNet,ResNet等卷积网络。测试结果比随机猜测10%要好很多,是一个不错的可实操的有趣模型。



相关文章
|
Web App开发 数据采集 JSON
Python实现urllib3和requests库使用 | python爬虫实战之五
本节介绍了urllib3库和requests库中的一些方法的使用。
Python实现urllib3和requests库使用 | python爬虫实战之五
|
数据可视化 前端开发 JavaScript
pyEcharts安装及详细使用指南(一)
pyEcharts安装及详细使用指南(一)
1923 0
pyEcharts安装及详细使用指南(一)
|
7月前
|
API
零门槛,体验DeepSeek-R1满血版
DeepSeek是当前性能最高、最受欢迎的大语言模型之一,但由于访问量大,官方服务响应较慢。阿里云百炼平台提供了高效响应的满血版DeepSeek R1,用户可通过ChatBox轻松接入体验。首先,需用阿里云账号登录并创建APIKEY,接着下载并安装ChatBox,按照指引操作即可畅享DeepSeek的强大功能。感谢阿里云的努力,期待更多优秀模型的接入。
369 9
|
机器学习/深度学习 数据采集 数据可视化
深度学习实践:构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类
本文详细介绍如何使用PyTorch构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行图像分类。从数据预处理、模型定义到训练过程及结果可视化,文章全面展示了深度学习项目的全流程。通过实际操作,读者可以深入了解CNN在图像分类任务中的应用,并掌握PyTorch的基本使用方法。希望本文为您的深度学习项目提供有价值的参考与启示。
|
自然语言处理 数据处理
情感分析的终极形态:全景式细粒度多模态对话情感分析基准PanoSent
【9月更文挑战第24天】PanoSent是一种全新的多模态对话情感分析框架,旨在全景式地提取和分析情感元素,包括情感六元组提取与情感翻转分析两大任务。此框架依托大规模、高质量的多模态数据集PanoSent,涵盖文本、图像、音频等多种模态及多种语言,适应不同应用场景。为解决这些任务,研究人员提出了Chain-of-Sentiment推理框架,结合多模态大语言模型Sentica,实现细粒度的情感分析。尽管PanoSent在情感分析任务上表现优异,但仍面临多模态数据处理和跨领域适用性的挑战。
344 2
|
机器学习/深度学习 算法 TensorFlow
【深度学习】深度学习语音识别算法的详细解析
深度学习语音识别算法是一种基于人工神经网络的语音识别技术,其核心在于利用深度神经网络(Deep Neural Network,DNN)自动从语音信号中学习有意义的特征,并生成高效的语音识别模型。以下是对深度学习语音识别算法的详细解析
585 5
|
11月前
|
人工智能 搜索推荐 安全
盘点几款AI 赋能的 CRM 系统
在数字化时代,客户关系管理系统(CRM)成为企业提升竞争力、优化销售及增强客户满意度的关键工具。尤其随着人工智能(AI)技术的发展,AI功能强大的CRM系统为企业带来了前所未有的机遇。未来CRM系统将更加智能化、个性化,深度融合大数据、物联网等技术,并加强数据安全;典型如销售易CRM、Salesforce、Zoho CRM和HubSpot CRM,它们在销售管理、客户服务及营销自动化等方面展现了巨大潜力,为企业创造了更多价值。
|
机器学习/深度学习 PyTorch 测试技术
PyTorch实战:图像分类任务的实现与优化
【4月更文挑战第17天】本文介绍了使用PyTorch实现图像分类任务的步骤,包括数据集准备(如使用CIFAR-10数据集)、构建简单的CNN模型、训练与优化模型以及测试模型性能。在训练过程中,使用了交叉熵损失和SGD优化器。此外,文章还讨论了提升模型性能的策略,如调整模型结构、数据增强、正则化和利用预训练模型。通过本文,读者可掌握基础的PyTorch图像分类实践。
|
并行计算 安全 Linux
如何设置环境变量KMP_DUPLICATE_LIB_OK=TRUE
【5月更文挑战第25天】如何设置环境变量KMP_DUPLICATE_LIB_OK=TRUE
1207 0
|
人工智能 Oracle 关系型数据库
【AI Agent系列】【LangGraph】0. 快速上手:协同LangChain,LangGraph帮你用图结构轻松构建多智能体应用
【AI Agent系列】【LangGraph】0. 快速上手:协同LangChain,LangGraph帮你用图结构轻松构建多智能体应用
2006 0