搭建小型ViT网络构架进行分类任务(Pytorch)

简介: 搭建小型ViT网络构架进行分类任务(Pytorch)

前言


  在这里我们不过多叙述原理,为了实验的便捷性,我们选择最为常见的MNIST数据集,Demo的构架我们采用我往期的动手撸个MNIST分类(CPU版本+GPU版本) 中GPU版本作为母版



Vision Transformers 大体框架


  为了能够使用自己的ViT模型应用到MNIST分类中去(替换class Net(nn.Module) 模块)可搭建如下框架:

class ViTNet(nn.Module)
    def __init__(self):
        super(ViTNet,self).__init__()
    def forward
        pass



Vision Transformers 细节架构


  在PyTorch中有大量的DL架构都提供Autograd计算,因此我们在Vision Transformers 模型中只需要着重在向前传递过程中花费精力即可;由于我在训练框架中定义了模型的优化器,PyTorch框架能够反向传播梯度并训练模型的参数。


我们将以下的五个重要步骤搭建出符合(MNIST)的网络结构:


第一步:Patchifying 和线性映射:


  Transformer 编码器一开始主要用于NLP这种序列化数据,将它用于CV领域的第一步要处理的是“序列化”图像,这里的处理方式是将一张图像分解成多个子图像,将每个子图像映射成一个向量。


  在MNIST数据集上,我们将每个(1x28x28)的图像分成7x7块,每块大小是4x4(如果不能完全整除分块,需要对图像padding填充),这样我们能从单个图像中获得49个子图像。将原图重塑成:


  (N, PxP, HxC/P x WxC/P) = (N, 7x7, 4x4) = (N, 49, 16)请注意,虽然每个子图大小为 1x4x4 ,但我们将其展平为 16 维向量。此外,MNIST只有一个颜色通道。如果有多个颜色通道,它们也会被展平到矢量中。


  我们得到展平后的patches即向量,通过一个线性映射来改变维度,线性映射可以映射到任意向量大小,我们向类构造函数添加一个hidden_d参数,用于“隐藏维度”。这里,使用隐藏维度为8,这样我们将每个 16 维patch映射到一个 8 维patch


第二步:添加分类标记


  在隐含层后,为了完成MNIST分类任务,必不可少的是添加分类的标记,在模型中添加一个参数将我们的Tensor(N,49,8)转变为Tensor(N,50,8);这里大家需要注意的一个地方是分类标记需要放在每个序列的第一个标记位。在完成MLP时,需要对应到相应的位置上。


第三步:添加位置编码


   紧接上一步,我们标记完成后需要进行添加位置编码,然而这块的理论性较强,强烈建议大家观摩transformer模型中的位置表明输出,这里我们就简化了,采用sin和cos替代。这里需要注意的地方是我们在第二部中转换完的Tensor(N,50,8),此时我们应该重复(50,8)的位置编码矩阵N次。



第四步:LN, MSA和残差连接


   这步较为复杂,我们在对tokens做归一化没然后采用多头注意力机制,最后添加一个残差连接输出。


LN:通过LN运行Tensor(N,50,8)后,每个50x8 矩阵的均值是0,标准差位1,维度保持不变。


   多头自注意力:对于每一张图像,都希望它能参与每个patch并在其中更新。在这里我不做过多注释,大家可参考MSA计算过程。


   残差连接:将添加一个残差连接,它将我们的原始Tensor(N,50,8)添加到在 LN 和 MSA 之后获得的 (N, 50, 8)。如果我们现在通过我们的模型运行MNIST的随机 (3, 1, 28, 28) 图像,我们仍然会得到形状为 (3, 50, 8) 的结果。



第五步:LN,MLP 和残差连接后进行MLP分类:


   这里就开始搭积木了。我们可以从 N 个序列中只提取分类标记(第一个标记),与添加分类标签的位置对应,并使用每个标记得到 N 个分类。


   由于我们决定每个标记是一个 8 维向量,并且由于我们有 10 个可能的数字,我们可以将分类 MLP 实现为一个简单的 8x10 矩阵,并使用 SoftMax 函数激活。




相关文章
|
2月前
|
监控 安全 Linux
在 Linux 系统中,网络管理是重要任务。本文介绍了常用的网络命令及其适用场景
在 Linux 系统中,网络管理是重要任务。本文介绍了常用的网络命令及其适用场景,包括 ping(测试连通性)、traceroute(跟踪路由路径)、netstat(显示网络连接信息)、nmap(网络扫描)、ifconfig 和 ip(网络接口配置)。掌握这些命令有助于高效诊断和解决网络问题,保障网络稳定运行。
106 2
|
12天前
|
机器学习/深度学习 算法 PyTorch
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
软演员-评论家算法(Soft Actor-Critic, SAC)是深度强化学习领域的重要进展,基于最大熵框架优化策略,在探索与利用之间实现动态平衡。SAC通过双Q网络设计和自适应温度参数,提升了训练稳定性和样本效率。本文详细解析了SAC的数学原理、网络架构及PyTorch实现,涵盖演员网络的动作采样与对数概率计算、评论家网络的Q值估计及其损失函数,并介绍了完整的SAC智能体实现流程。SAC在连续动作空间中表现出色,具有高样本效率和稳定的训练过程,适合实际应用场景。
58 7
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
|
23天前
|
人工智能 搜索推荐 决策智能
不靠更复杂的策略,仅凭和大模型训练对齐,零样本零经验单LLM调用,成为网络任务智能体新SOTA
近期研究通过调整网络智能体的观察和动作空间,使其与大型语言模型(LLM)的能力对齐,显著提升了基于LLM的网络智能体性能。AgentOccam智能体在WebArena基准上超越了先前方法,成功率提升26.6个点(+161%)。该研究强调了与LLM训练目标一致的重要性,为网络任务自动化提供了新思路,但也指出其性能受限于LLM能力及任务复杂度。论文链接:https://arxiv.org/abs/2410.13825。
51 12
|
28天前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch Gemotric在昇腾上实现GraphSage图神经网络
本文详细介绍了如何在昇腾平台上使用PyTorch实现GraphSage算法,在CiteSeer数据集上进行图神经网络的分类训练。内容涵盖GraphSage的创新点、算法原理、网络架构及实战代码分析,通过采样和聚合方法高效处理大规模图数据。实验结果显示,模型在CiteSeer数据集上的分类准确率达到66.5%。
|
3月前
|
网络协议
计算机网络的分类
【10月更文挑战第11天】 计算机网络可按覆盖范围(局域网、城域网、广域网)、传输技术(有线、无线)、拓扑结构(星型、总线型、环型、网状型)、使用者(公用、专用)、交换方式(电路交换、分组交换)和服务类型(面向连接、无连接)等多种方式进行分类,每种分类方式揭示了网络的不同特性和应用场景。
|
1月前
|
机器学习/深度学习 Serverless 索引
分类网络中one-hot编码的作用
在分类任务中,使用神经网络时,通常需要将类别标签转换为一种合适的输入格式。这时候,one-hot编码(one-hot encoding)是一种常见且有效的方法。one-hot编码将类别标签表示为向量形式,其中只有一个元素为1,其他元素为0。
47 2
|
2月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
利用Python和TensorFlow构建简单神经网络进行图像分类
利用Python和TensorFlow构建简单神经网络进行图像分类
72 3
|
3月前
|
机器学习/深度学习 Serverless 索引
分类网络中one-hot的作用
在分类任务中,使用神经网络时,通常需要将类别标签转换为一种合适的输入格式。这时候,one-hot编码(one-hot encoding)是一种常见且有效的方法。one-hot编码将类别标签表示为向量形式,其中只有一个元素为1,其他元素为0。
87 3
|
3月前
|
机器学习/深度学习 存储 自然语言处理
从理论到实践:如何使用长短期记忆网络(LSTM)改善自然语言处理任务
【10月更文挑战第7天】随着深度学习技术的发展,循环神经网络(RNNs)及其变体,特别是长短期记忆网络(LSTMs),已经成为处理序列数据的强大工具。在自然语言处理(NLP)领域,LSTM因其能够捕捉文本中的长期依赖关系而变得尤为重要。本文将介绍LSTM的基本原理,并通过具体的代码示例来展示如何在实际的NLP任务中应用LSTM。
281 4
|
3月前
|
机器学习/深度学习 数据采集 算法
目标分类笔记(一): 利用包含多个网络多种训练策略的框架来完成多目标分类任务(从数据准备到训练测试部署的完整流程)
这篇博客文章介绍了如何使用包含多个网络和多种训练策略的框架来完成多目标分类任务,涵盖了从数据准备到训练、测试和部署的完整流程,并提供了相关代码和配置文件。
80 0
目标分类笔记(一): 利用包含多个网络多种训练策略的框架来完成多目标分类任务(从数据准备到训练测试部署的完整流程)

热门文章

最新文章