搭建小型ViT网络构架进行分类任务(Pytorch)

简介: 搭建小型ViT网络构架进行分类任务(Pytorch)

前言


  在这里我们不过多叙述原理,为了实验的便捷性,我们选择最为常见的MNIST数据集,Demo的构架我们采用我往期的动手撸个MNIST分类(CPU版本+GPU版本) 中GPU版本作为母版



Vision Transformers 大体框架


  为了能够使用自己的ViT模型应用到MNIST分类中去(替换class Net(nn.Module) 模块)可搭建如下框架:

class ViTNet(nn.Module)
    def __init__(self):
        super(ViTNet,self).__init__()
    def forward
        pass



Vision Transformers 细节架构


  在PyTorch中有大量的DL架构都提供Autograd计算,因此我们在Vision Transformers 模型中只需要着重在向前传递过程中花费精力即可;由于我在训练框架中定义了模型的优化器,PyTorch框架能够反向传播梯度并训练模型的参数。


我们将以下的五个重要步骤搭建出符合(MNIST)的网络结构:


第一步:Patchifying 和线性映射:


  Transformer 编码器一开始主要用于NLP这种序列化数据,将它用于CV领域的第一步要处理的是“序列化”图像,这里的处理方式是将一张图像分解成多个子图像,将每个子图像映射成一个向量。


  在MNIST数据集上,我们将每个(1x28x28)的图像分成7x7块,每块大小是4x4(如果不能完全整除分块,需要对图像padding填充),这样我们能从单个图像中获得49个子图像。将原图重塑成:


  (N, PxP, HxC/P x WxC/P) = (N, 7x7, 4x4) = (N, 49, 16)请注意,虽然每个子图大小为 1x4x4 ,但我们将其展平为 16 维向量。此外,MNIST只有一个颜色通道。如果有多个颜色通道,它们也会被展平到矢量中。


  我们得到展平后的patches即向量,通过一个线性映射来改变维度,线性映射可以映射到任意向量大小,我们向类构造函数添加一个hidden_d参数,用于“隐藏维度”。这里,使用隐藏维度为8,这样我们将每个 16 维patch映射到一个 8 维patch


第二步:添加分类标记


  在隐含层后,为了完成MNIST分类任务,必不可少的是添加分类的标记,在模型中添加一个参数将我们的Tensor(N,49,8)转变为Tensor(N,50,8);这里大家需要注意的一个地方是分类标记需要放在每个序列的第一个标记位。在完成MLP时,需要对应到相应的位置上。


第三步:添加位置编码


   紧接上一步,我们标记完成后需要进行添加位置编码,然而这块的理论性较强,强烈建议大家观摩transformer模型中的位置表明输出,这里我们就简化了,采用sin和cos替代。这里需要注意的地方是我们在第二部中转换完的Tensor(N,50,8),此时我们应该重复(50,8)的位置编码矩阵N次。



第四步:LN, MSA和残差连接


   这步较为复杂,我们在对tokens做归一化没然后采用多头注意力机制,最后添加一个残差连接输出。


LN:通过LN运行Tensor(N,50,8)后,每个50x8 矩阵的均值是0,标准差位1,维度保持不变。


   多头自注意力:对于每一张图像,都希望它能参与每个patch并在其中更新。在这里我不做过多注释,大家可参考MSA计算过程。


   残差连接:将添加一个残差连接,它将我们的原始Tensor(N,50,8)添加到在 LN 和 MSA 之后获得的 (N, 50, 8)。如果我们现在通过我们的模型运行MNIST的随机 (3, 1, 28, 28) 图像,我们仍然会得到形状为 (3, 50, 8) 的结果。



第五步:LN,MLP 和残差连接后进行MLP分类:


   这里就开始搭积木了。我们可以从 N 个序列中只提取分类标记(第一个标记),与添加分类标签的位置对应,并使用每个标记得到 N 个分类。


   由于我们决定每个标记是一个 8 维向量,并且由于我们有 10 个可能的数字,我们可以将分类 MLP 实现为一个简单的 8x10 矩阵,并使用 SoftMax 函数激活。




相关文章
|
6天前
|
存储 人工智能 应用服务中间件
Web应用是一种通过互联网浏览器和网络技术在互联网上执行任务的计算机程序
【5月更文挑战第30天】Web应用是一种通过互联网浏览器和网络技术在互联网上执行任务的计算机程序
22 2
|
19天前
|
机器学习/深度学习 JSON PyTorch
图神经网络入门示例:使用PyTorch Geometric 进行节点分类
本文介绍了如何使用PyTorch处理同构图数据进行节点分类。首先,数据集来自Facebook Large Page-Page Network,包含22,470个页面,分为四类,具有不同大小的特征向量。为训练神经网络,需创建PyTorch Data对象,涉及读取CSV和JSON文件,处理不一致的特征向量大小并进行归一化。接着,加载边数据以构建图。通过`Data`对象创建同构图,之后数据被分为70%训练集和30%测试集。训练了两种模型:MLP和GCN。GCN在测试集上实现了80%的准确率,优于MLP的46%,展示了利用图信息的优势。
27 1
|
20天前
|
机器学习/深度学习 PyTorch 算法框架/工具
神经网络基本概念以及Pytorch实现,多线程编程面试题
神经网络基本概念以及Pytorch实现,多线程编程面试题
|
21天前
|
机器学习/深度学习 并行计算 算法
MATLAB|【免费】概率神经网络的分类预测--基于PNN的变压器故障诊断
MATLAB|【免费】概率神经网络的分类预测--基于PNN的变压器故障诊断
|
21天前
|
机器学习/深度学习 PyTorch 算法框架/工具
Python用GAN生成对抗性神经网络判别模型拟合多维数组、分类识别手写数字图像可视化
Python用GAN生成对抗性神经网络判别模型拟合多维数组、分类识别手写数字图像可视化
|
21天前
|
机器学习/深度学习 数据采集 算法
Python对中国电信消费者特征预测:随机森林、朴素贝叶斯、神经网络、最近邻分类、逻辑回归、支持向量回归(SVR)
Python对中国电信消费者特征预测:随机森林、朴素贝叶斯、神经网络、最近邻分类、逻辑回归、支持向量回归(SVR)
|
21天前
|
人工智能 数据可视化
【数据分享】维基百科Wiki负面有害评论(网络暴力)文本数据多标签分类挖掘可视化
【数据分享】维基百科Wiki负面有害评论(网络暴力)文本数据多标签分类挖掘可视化
|
21天前
|
机器学习/深度学习 算法 TensorFlow
【视频】神经网络正则化方法防过拟合和R语言CNN分类手写数字图像数据MNIST|数据分享
【视频】神经网络正则化方法防过拟合和R语言CNN分类手写数字图像数据MNIST|数据分享
|
21天前
|
SQL 安全 测试技术
2021年职业院校技能大赛“网络安全”项目 江西省比赛任务书—B模块
B模块涵盖安全事件响应和网络数据取证,涉及多项应用安全挑战。任务包括使用nmap扫描靶机、弱口令登录、生成反弹木马、权限验证、系统内核版本检查、漏洞源码利用、文件名和内容提取等。此外,还有Linux渗透测试,要求访问特定目录下的文件并提取内容。应用服务漏洞扫描涉及服务版本探测、敏感文件发现、私钥解密、权限提升等。SQL注入测试需利用Nmap扫描端口,进行SQL注入并获取敏感信息。应急响应任务包括处理木马、删除恶意用户、修复启动项和清除服务器上的木马。流量分析涉及Wireshark数据包分析,查找黑客IP、枚举测试、服务破解等。渗透测试任务涵盖系统服务扫描、数据库管理、漏洞利用模块搜索等。
37 0
|
21天前
|
监控 安全 网络安全
2021年职业院校技能大赛“网络安全”项目 江西省比赛任务书—A模块
该文档是关于企业服务器系统安全加固的任务描述,包括A模块的六个部分:登录安全、Web安全、流量保护与事件监控、防火墙策略、Windows和Linux操作系统安全配置。任务涉及设置密码和登录策略、启用安全日志、限制非法访问、调整防火墙规则、加强操作系统安全和优化服务配置等,以提升网络安全防御能力。每个部分都有具体的配置截图要求,并需按照指定格式保存提交。
17 0