搭建小型ViT网络构架进行分类任务(Pytorch)

简介: 搭建小型ViT网络构架进行分类任务(Pytorch)

前言


  在这里我们不过多叙述原理,为了实验的便捷性,我们选择最为常见的MNIST数据集,Demo的构架我们采用我往期的动手撸个MNIST分类(CPU版本+GPU版本) 中GPU版本作为母版



Vision Transformers 大体框架


  为了能够使用自己的ViT模型应用到MNIST分类中去(替换class Net(nn.Module) 模块)可搭建如下框架:

class ViTNet(nn.Module)
    def __init__(self):
        super(ViTNet,self).__init__()
    def forward
        pass



Vision Transformers 细节架构


  在PyTorch中有大量的DL架构都提供Autograd计算,因此我们在Vision Transformers 模型中只需要着重在向前传递过程中花费精力即可;由于我在训练框架中定义了模型的优化器,PyTorch框架能够反向传播梯度并训练模型的参数。


我们将以下的五个重要步骤搭建出符合(MNIST)的网络结构:


第一步:Patchifying 和线性映射:


  Transformer 编码器一开始主要用于NLP这种序列化数据,将它用于CV领域的第一步要处理的是“序列化”图像,这里的处理方式是将一张图像分解成多个子图像,将每个子图像映射成一个向量。


  在MNIST数据集上,我们将每个(1x28x28)的图像分成7x7块,每块大小是4x4(如果不能完全整除分块,需要对图像padding填充),这样我们能从单个图像中获得49个子图像。将原图重塑成:


  (N, PxP, HxC/P x WxC/P) = (N, 7x7, 4x4) = (N, 49, 16)请注意,虽然每个子图大小为 1x4x4 ,但我们将其展平为 16 维向量。此外,MNIST只有一个颜色通道。如果有多个颜色通道,它们也会被展平到矢量中。


  我们得到展平后的patches即向量,通过一个线性映射来改变维度,线性映射可以映射到任意向量大小,我们向类构造函数添加一个hidden_d参数,用于“隐藏维度”。这里,使用隐藏维度为8,这样我们将每个 16 维patch映射到一个 8 维patch


第二步:添加分类标记


  在隐含层后,为了完成MNIST分类任务,必不可少的是添加分类的标记,在模型中添加一个参数将我们的Tensor(N,49,8)转变为Tensor(N,50,8);这里大家需要注意的一个地方是分类标记需要放在每个序列的第一个标记位。在完成MLP时,需要对应到相应的位置上。


第三步:添加位置编码


   紧接上一步,我们标记完成后需要进行添加位置编码,然而这块的理论性较强,强烈建议大家观摩transformer模型中的位置表明输出,这里我们就简化了,采用sin和cos替代。这里需要注意的地方是我们在第二部中转换完的Tensor(N,50,8),此时我们应该重复(50,8)的位置编码矩阵N次。



第四步:LN, MSA和残差连接


   这步较为复杂,我们在对tokens做归一化没然后采用多头注意力机制,最后添加一个残差连接输出。


LN:通过LN运行Tensor(N,50,8)后,每个50x8 矩阵的均值是0,标准差位1,维度保持不变。


   多头自注意力:对于每一张图像,都希望它能参与每个patch并在其中更新。在这里我不做过多注释,大家可参考MSA计算过程。


   残差连接:将添加一个残差连接,它将我们的原始Tensor(N,50,8)添加到在 LN 和 MSA 之后获得的 (N, 50, 8)。如果我们现在通过我们的模型运行MNIST的随机 (3, 1, 28, 28) 图像,我们仍然会得到形状为 (3, 50, 8) 的结果。



第五步:LN,MLP 和残差连接后进行MLP分类:


   这里就开始搭积木了。我们可以从 N 个序列中只提取分类标记(第一个标记),与添加分类标签的位置对应,并使用每个标记得到 N 个分类。


   由于我们决定每个标记是一个 8 维向量,并且由于我们有 10 个可能的数字,我们可以将分类 MLP 实现为一个简单的 8x10 矩阵,并使用 SoftMax 函数激活。




相关文章
|
2月前
|
机器学习/深度学习 算法 调度
14种智能算法优化BP神经网络(14种方法)实现数据预测分类研究(Matlab代码实现)
14种智能算法优化BP神经网络(14种方法)实现数据预测分类研究(Matlab代码实现)
294 0
|
2月前
|
机器学习/深度学习 传感器 算法
【无人车路径跟踪】基于神经网络的数据驱动迭代学习控制(ILC)算法,用于具有未知模型和重复任务的非线性单输入单输出(SISO)离散时间系统的无人车的路径跟踪(Matlab代码实现)
【无人车路径跟踪】基于神经网络的数据驱动迭代学习控制(ILC)算法,用于具有未知模型和重复任务的非线性单输入单输出(SISO)离散时间系统的无人车的路径跟踪(Matlab代码实现)
170 2
|
25天前
|
机器学习/深度学习 数据采集 存储
概率神经网络的分类预测--基于PNN的变压器故障诊断(Matlab代码实现)
概率神经网络的分类预测--基于PNN的变压器故障诊断(Matlab代码实现)
189 0
|
3月前
|
机器学习/深度学习 数据采集 运维
匹配网络处理不平衡数据集的6种优化策略:有效提升分类准确率
匹配网络是一种基于度量的元学习方法,通过计算查询样本与支持集样本的相似性实现分类。其核心依赖距离度量函数(如余弦相似度),并引入注意力机制对特征维度加权,提升对关键特征的关注能力,尤其在处理复杂或噪声数据时表现出更强的泛化性。
193 6
匹配网络处理不平衡数据集的6种优化策略:有效提升分类准确率
|
2月前
|
机器学习/深度学习 算法 PyTorch
【Pytorch框架搭建神经网络】基于DQN算法、优先级采样的DQN算法、DQN + 人工势场的避障控制研究(Python代码实现)
【Pytorch框架搭建神经网络】基于DQN算法、优先级采样的DQN算法、DQN + 人工势场的避障控制研究(Python代码实现)
|
2月前
|
安全 网络性能优化 网络虚拟化
网络交换机分类与功能解析
接入交换机(ASW)连接终端设备,提供高密度端口与基础安全策略;二层交换机(LSW)基于MAC地址转发数据,构成局域网基础;汇聚交换机(DSW)聚合流量并实施VLAN路由、QoS等高级策略;核心交换机(CSW)作为网络骨干,具备高性能、高可靠性的高速转发能力;中间交换机(ISW)可指汇聚层设备或刀片服务器内交换模块。典型流量路径为:终端→ASW→DSW/ISW→CSW,分层架构提升网络扩展性与管理效率。(238字)
700 0
|
2月前
|
机器学习/深度学习 算法 PyTorch
【DQN实现避障控制】使用Pytorch框架搭建神经网络,基于DQN算法、优先级采样的DQN算法、DQN + 人工势场实现避障控制研究(Matlab、Python实现)
【DQN实现避障控制】使用Pytorch框架搭建神经网络,基于DQN算法、优先级采样的DQN算法、DQN + 人工势场实现避障控制研究(Matlab、Python实现)
122 0
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
基于Pytorch 在昇腾上实现GCN图神经网络
本文详细讲解了如何在昇腾平台上使用PyTorch实现图神经网络(GCN)对Cora数据集进行分类训练。内容涵盖GCN背景、模型特点、网络架构剖析及实战分析。GCN通过聚合邻居节点信息实现“卷积”操作,适用于非欧氏结构数据。文章以两层GCN模型为例,结合Cora数据集(2708篇科学出版物,1433个特征,7种类别),展示了从数据加载到模型训练的完整流程。实验在NPU上运行,设置200个epoch,最终测试准确率达0.8040,内存占用约167M。
基于Pytorch 在昇腾上实现GCN图神经网络
|
6月前
|
PyTorch 调度 算法框架/工具
阿里云PAI-DLC任务Pytorch launch_agent Socket Timeout问题源码分析
DLC任务Pytorch launch_agent Socket Timeout问题源码分析与解决方案
323 18
阿里云PAI-DLC任务Pytorch launch_agent Socket Timeout问题源码分析
|
6月前
|
机器学习/深度学习 搜索推荐 PyTorch
基于昇腾用PyTorch实现CTR模型DIN(Deep interest Netwok)网络
本文详细讲解了如何在昇腾平台上使用PyTorch训练推荐系统中的经典模型DIN(Deep Interest Network)。主要内容包括:DIN网络的创新点与架构剖析、Activation Unit和Attention模块的实现、Amazon-book数据集的介绍与预处理、模型训练过程定义及性能评估。通过实战演示,利用Amazon-book数据集训练DIN模型,最终评估其点击率预测性能。文中还提供了代码示例,帮助读者更好地理解每个步骤的实现细节。

热门文章

最新文章

推荐镜像

更多