PyTorch CIFAR10 Image Classification: Swin Transformer (Part 2)


PyTorch CIFAR10 Image Classification: Swin Transformer (Part 1): https://developer.aliyun.com/article/1410617

Shifted Window Attention

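The Window Attention described earlier computes attention only inside each window. To let information flow between neighbouring windows, Swin Transformer also introduces the shifted window operation.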



3773e1c8a9554ee19ee014f1e4cc45d3.png

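The left side shows non-overlapping Window Attention, and the right side shows Shifted Window Attention, where the window grid has been shifted. After the shift, each window contains elements that used to belong to adjacent windows. This introduces a new problem, however: the number of windows grows from 4 to 9.

In the actual code, this is implemented indirectly by cyclically shifting the feature map and giving the attention a mask. The number of windows stays the same as before, and the final result is equivalent.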
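The cyclic shift mentioned above can be tried on a tiny tensor. This is a minimal sketch rather than part of the model: the 4x4 feature map and `shift_size = 1` are made-up values for illustration, while `torch.roll` is the same call the block code uses.

```python
import torch

# A 4x4 single-channel feature map in (batch, height, width, channels) layout.
x = torch.arange(16).view(1, 4, 4, 1)
shift_size = 1  # illustrative value, e.g. window_size // 2 with window_size = 2

# Cyclic shift: rows move up and columns move left; whatever falls off wraps around.
shifted = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))

# Rolling by the opposite amount restores the original layout exactly.
restored = torch.roll(shifted, shifts=(shift_size, shift_size), dims=(1, 2))

print(shifted[0, :, :, 0])
print(torch.equal(restored, x))
```

Because the wrapped-around rows and columns are not true spatial neighbours of the cells they now sit next to, the attention mask is needed to keep them from attending to each other.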

class BasicLayer(nn.Module):
    def __init__(self,
                 dim,
                 depth,
                 n_heads,
                 window_size,
                 mlp_ratio=4,
                 qkv_bias=True,
                 proj_dropout=0.,
                 attn_dropout=0.,
                 dropout=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None):
        super(BasicLayer, self).__init__()
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        # The window shifts right and down by window_size // 2 (floor division)
        self.shift_size = window_size // 2
        # Stack `depth` blocks for this stage
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim, n_heads, window_size,
                                 0 if (i % 2 == 0) else self.shift_size,
                                 mlp_ratio, qkv_bias, proj_dropout, attn_dropout,
                                 dropout[i] if isinstance(
                                     dropout, list) else dropout,
                                 norm_layer)
            for i in range(depth)])
        self.downsample = downsample(dim=dim, norm_layer=norm_layer) if downsample else None
    def forward(self, x, h, w):
        attn_mask = self.create_mask(x, h, w)
        for blk in self.blocks:
            blk.h, blk.w = h, w
            x = blk(x, attn_mask)
        if self.downsample is not None:
            x = self.downsample(x, h, w)
            # If the size is odd, this equals padding by one and then halving
            # If the size is even, this is a plain division by 2
            h, w = (h+1) // 2, (w+1) // 2
        return x, h, w
    def create_mask(self, x, h, w):
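        # Make sure hp and wp are integer multiples of window_size
        hp = int(np.ceil(h/self.window_size)) * self.window_size
        wp = int(np.ceil(w/self.window_size)) * self.window_size
        # (1,hp,wp,1)
        img_mask = torch.zeros((1, hp, wp, 1), device=x.device)
        # Split the feature map into 9 regions
        # Example: a 9x9 image
        # with window_size=3 and shift_size=3//2=1
        # the slices are ([0:-3], [-3:-1], [-1:])
        # so h runs from index 0 to -4 (stopping before -3), and w likewise
        # i.e. the rectangle spanned by (0,0) and (-4,-4) is the 1st region
        # then h runs from 0 to -4 and w from -3 to -2
        # i.e. the rectangle spanned by (0,-3) and (-4,-2) is the 2nd region, and so on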
        # h\w 0 1 2 3 4 5 6 7 8
        # --+--------------------
        # 0 | 0 0 0 0 0 0 1 1 2
        # 1 | 0 0 0 0 0 0 1 1 2
        # 2 | 0 0 0 0 0 0 1 1 2
        # 3 | 0 0 0 0 0 0 1 1 2
        # 4 | 0 0 0 0 0 0 1 1 2
        # 5 | 0 0 0 0 0 0 1 1 2
        # 6 | 3 3 3 3 3 3 4 4 5
        # 7 | 3 3 3 3 3 3 4 4 5
        # 8 | 6 6 6 6 6 6 7 7 8
        # Within each window, cells that carry the same number form a contiguous region
        slices = (slice(0, -self.window_size),
                  slice(-self.window_size, -self.shift_size),
                  slice(-self.shift_size, None))
        cnt = 0
        for h in slices:
            for w in slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        # (1,hp,wp,1) -> (n_windows,m,m,1), where m is window_size
        mask_windows = window_partition(img_mask, self.window_size)
        # (n_windows,m,m,1) -> (n_windows,m*m)
        mask_windows = mask_windows.view(-1,
                                         self.window_size * self.window_size)
        # (n_windows,1,m*m) - (n_windows,m*m,1)
        # -> (n_windows,m*m,m*m)
        # Take the window
        # [[4 4 5]
        #  [4 4 5]
        #  [7 7 8]]
        # as an example
        # Flattened, it becomes [4,4,5,4,4,5,7,7,8]
        # [[4,4,5,4,4,5,7,7,8]] - [[4]
        #                          [4]
        #                          [5]
        #                          [4]
        #                          [4]
        #                          [5]
        #                          [7]
        #                          [7]
        #                          [8]]
        # -> [[0,0,-,0,0,-,-,-,-]
        #     [0,0,-,0,0,-,-,-,-]
        #     [...]]
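        # Positions with the same region number give 0, different numbers give non-zero
        # attn_mask[1,3] tells whether element 3 of the window, at (1,0),
        # and element 1, at (0,1), are in the same region: 0 if so, non-zero otherwise
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        # Set every non-zero entry to -100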
        attn_mask = (attn_mask
                     .masked_fill(attn_mask != 0, float(-100.))
                     .masked_fill(attn_mask == 0, float(0.)))
        return attn_mask
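To check the 9x9 example from the comments, the mask logic can be run on its own. This is a sketch under assumptions: `build_shift_mask` is a hypothetical standalone helper, and the `window_partition` here is a local stand-in assumed to behave like the one defined in Part 1.

```python
import torch

def window_partition(x, window_size):
    # Assumed behaviour: (b, h, w, c) -> (b * n_windows, window_size, window_size, c)
    b, h, w, c = x.shape
    x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)

def build_shift_mask(hp, wp, window_size, shift_size):
    # Label the 9 regions of the (already padded) feature map with 0..8.
    img_mask = torch.zeros((1, hp, wp, 1))
    slices = (slice(0, -window_size),
              slice(-window_size, -shift_size),
              slice(-shift_size, None))
    cnt = 0
    for hs in slices:
        for ws in slices:
            img_mask[:, hs, ws, :] = cnt
            cnt += 1
    # One row of region labels per window: (n_windows, m*m)
    windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
    # Pairwise label differences: zero means "same region"
    diff = windows.unsqueeze(1) - windows.unsqueeze(2)
    return torch.where(diff != 0, torch.full_like(diff, -100.0), torch.zeros_like(diff))

mask = build_shift_mask(9, 9, window_size=3, shift_size=1)
print(mask.shape)   # 9 windows, each with a 9x9 mask
print(mask[-1, 0])  # first row of the bottom-right window
```

The top-left window lies entirely in region 0, so its mask is all zeros, while the bottom-right window mixes four regions and gets -100 wherever two positions come from different regions.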

Swin Transformer Block

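The Swin Transformer Block is the core of the algorithm. It consists of a window multi-head self-attention layer (W-MSA) and a shifted-window multi-head self-attention layer (SW-MSA), as shown in the figure. For this reason, the number of blocks in each stage must be a multiple of 2: one block uses W-MSA and the next uses SW-MSA.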


202a9debeec8e09ea7552055aef54653.png

class SwinTransformerBlock(nn.Module):
    def __init__(self,
                 dim,
                 n_heads,
                 window_size=7,
                 shift_size=0,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 proj_dropout=0.,
                 attn_dropout=0.,
                 dropout=0.,
                 norm_layer=nn.LayerNorm):
        super(SwinTransformerBlock, self).__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(dim, window_size, n_heads,
                                    qkv_bias, attn_dropout, proj_dropout)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = MLP(in_features=dim,
                       hid_features=int(dim * mlp_ratio),
                       dropout=proj_dropout)
    def forward(self, x, attn_mask):
        h, w = self.h, self.w
        b, _, c = x.shape
        shortcut = x
        x = self.norm1(x)
        x = x.view(b, h, w, c)
        pad_r = (self.window_size - w % self.window_size) % self.window_size
        pad_b = (self.window_size - h % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
        _, hp, wp, _ = x.shape
        if self.shift_size > 0:
            shifted_x = torch.roll(
                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
            attn_mask = None
        # (n_windows*b,m,m,c)
        x_windows = window_partition(shifted_x, self.window_size)
        # (n_windows*b,m*m,c)
        x_windows = x_windows.view(-1, self.window_size*self.window_size, c)
        attn_windows = self.attn(x_windows, attn_mask)
        attn_windows = attn_windows.view(-1,
                                         self.window_size, self.window_size, c)
        shifted_x = window_reverse(attn_windows, self.window_size, hp, wp)
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(
                self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        if pad_r > 0 or pad_b > 0:
            x = x[:, :h, :w, :].contiguous()
        x = x.view(b, h*w, c)
        x = shortcut + self.dropout(x)
        x = x + self.dropout(self.mlp(self.norm2(x)))
        return x

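The overall flow is as follows:

Apply LayerNorm to the feature map.

Use self.shift_size to decide whether the feature map needs to be shifted.

Split the feature map into windows.

Compute attention; attn_mask distinguishes Window Attention (no mask) from Shifted Window Attention (with mask).

Merge the windows back together.

If a shift was applied earlier, apply the reverse shift to undo it.

Apply dropout and the residual connection.

Pass the result through another LayerNorm and the fully connected (MLP) layers, followed by dropout and a residual connection.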
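The padding, window split, and window merge steps above can be checked in isolation. The sketch below is not the Part 1 code: `window_partition` and `window_reverse` are local stand-ins with assumed semantics, and the 5x6 feature map and window size 4 are arbitrary illustrative values.

```python
import torch
import torch.nn.functional as F

def window_partition(x, ws):
    # (b, h, w, c) -> (b * n_windows, ws, ws, c)
    b, h, w, c = x.shape
    x = x.view(b, h // ws, ws, w // ws, ws, c)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, ws, ws, c)

def window_reverse(windows, ws, h, w):
    # (b * n_windows, ws, ws, c) -> (b, h, w, c)
    b = windows.shape[0] // ((h // ws) * (w // ws))
    x = windows.view(b, h // ws, w // ws, ws, ws, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)

ws = 4
x = torch.randn(2, 5, 6, 3)  # (b, h, w, c) with h, w not multiples of ws
_, h, w, _ = x.shape

# Pad on the right and bottom so both sides become multiples of the window size.
pad_r = (ws - w % ws) % ws
pad_b = (ws - h % ws) % ws
padded = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
_, hp, wp, _ = padded.shape

windows = window_partition(padded, ws)        # split into windows
merged = window_reverse(windows, ws, hp, wp)  # merge the windows back
restored = merged[:, :h, :w, :]               # drop the padding again

print(padded.shape, windows.shape, torch.equal(restored, x))
```

With attention left out, splitting and merging is an exact round trip, which is why the block can crop back to the original height and width afterwards.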

Swin Transformer

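Let's look at the overall network architecture.

b24182d8f2fadbea0ca908513f003ee5.png

The model uses a hierarchical design with 4 stages. Each stage reduces the resolution of its input feature map, enlarging the receptive field layer by layer, much like a CNN.

At the input, a Patch Embedding step splits the image into patches and projects each patch into an embedding.

Each stage consists of a Patch Merging module and several blocks.

The Patch Merging module lowers the feature-map resolution; in the paper's figure it sits at the start of a stage, while in the code below it is applied as the downsample step at the end of every stage except the last.

The block structure is shown on the right of the figure; it is built mainly from LayerNorm, MLP, Window Attention, and Shifted Window Attention.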

class SwinTransformer(nn.Module):
    def __init__(self,
                 patch_size=4,
                 in_c=3,
                 n_classes=1000,
                 embed_dim=96,
                 depths=(2, 2, 6, 2),
                 n_heads=(3, 6, 12, 24),
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 proj_dropout=0.,
                 attn_dropout=0.,
                 dropout=0.,
                 norm_layer=nn.LayerNorm,
                 patch_norm=True):
        super(SwinTransformer, self).__init__()
        self.n_classes = n_classes
        self.n_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        # Number of output channels of the last stage, i.e. embed_dim*8 with 4 stages
        self.n_features = int(embed_dim * 2**(self.n_layers-1))
        self.mlp_ratio = mlp_ratio
        self.patch_embed = PatchEmbed(patch_size, in_c, embed_dim,
                                      norm_layer if self.patch_norm else None)
        self.pos_drop = nn.Dropout(proj_dropout)
        # The dropout rate increases linearly with depth
        dpr = [x.item() for x in torch.linspace(0, dropout, sum(depths))]
        self.layers = nn.ModuleList()
        for i in range(self.n_layers):
            layers = BasicLayer(int(embed_dim*2**i), depths[i], n_heads[i],
                                window_size, mlp_ratio, qkv_bias,
                                proj_dropout, attn_dropout, 
                                dpr[sum(depths[:i]):sum(depths[:i+1])],
                                norm_layer, PatchMerging if i < self.n_layers-1 else None)
            self.layers.append(layers)
        self.norm = norm_layer(self.n_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(
            self.n_features, n_classes) if n_classes > 0 else nn.Identity()
        self.apply(self._init_weights)
    def forward(self, x):
        x, h, w = self.patch_embed(x)
        x = self.pos_drop(x)
        for layer in self.layers:
            x, h, w = layer(x, h, w)
        # (b,l,c)
        x = self.norm(x)
        # (b,l,c) -> (b,c,l) -> (b,c,1)
        x = self.avgpool(x.transpose(1, 2))
        # (b,c,1) -> (b,c)
        x = torch.flatten(x, 1)
        # (b,c) -> (b,n_classes)
        x = self.head(x)
        return x
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.)

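Swin Transformer comes in 4 model sizes. They differ in the hidden dimension, the number of blocks in each stage, and the number of heads in the multi-head self-attention.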

1fad13b5a6196daf83c37bd57ffc0a9c.png

def swin_t(hidden_dim=96, layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
    return SwinTransformer(embed_dim=hidden_dim, depths=layers, n_heads=heads, **kwargs)
def swin_s(hidden_dim=96, layers=(2, 2, 18, 2), heads=(3, 6, 12, 24), **kwargs):
    return SwinTransformer(embed_dim=hidden_dim, depths=layers, n_heads=heads, **kwargs)
def swin_b(hidden_dim=128, layers=(2, 2, 18, 2), heads=(4, 8, 16, 32), **kwargs):
    return SwinTransformer(embed_dim=hidden_dim, depths=layers, n_heads=heads, **kwargs)
def swin_l(hidden_dim=192, layers=(2, 2, 18, 2), heads=(6, 12, 24, 48), **kwargs):
    return SwinTransformer(embed_dim=hidden_dim, depths=layers, n_heads=heads, **kwargs)

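Because the training images have shape 3x32x32, the configuration here differs slightly from the original. Feel free to tune it yourself; I borrowed from Swin-T but dropped one stage for the CIFAR10 experiments.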

def swin_cifar(hidden_dim=192, layers=(2,6,2), heads=(3,6,12), **kwargs):
    return SwinTransformer(embed_dim=hidden_dim, depths=layers, n_heads=heads,  **kwargs)
net = swin_cifar(patch_size=2, n_classes=10,mlp_ratio=1).to(device)
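The feature-map sizes implied by this configuration can be traced with plain arithmetic. This is a rough sketch: `stage_shapes` is a hypothetical helper, and it assumes PatchEmbed reduces each side by `patch_size` and that each downsample step halves the side (rounding up) and doubles the channels, as the BasicLayer code does.

```python
def stage_shapes(img_size, patch_size, embed_dim, n_stages):
    """Return (height, width, channels) seen by the blocks of each stage."""
    side = img_size // patch_size  # assumed effect of PatchEmbed
    shapes = []
    for i in range(n_stages):
        shapes.append((side, side, embed_dim * 2 ** i))
        if i < n_stages - 1:
            side = (side + 1) // 2  # downsample halves the side, rounding up
    return shapes

cifar_shapes = stage_shapes(img_size=32, patch_size=2, embed_dim=192, n_stages=3)
for stage, shape in enumerate(cifar_shapes):
    print(stage, shape)
```

With the default `window_size=7`, none of these sides is a multiple of 7, so every block pads its input (16 to 21, 8 to 14, 4 to 7) before splitting it into windows.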

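Inspecting the network with summary

summary lets us see how the tensor dimensions change from layer to layer, check that they match the paper, and confirm that the final output has shape (batch, n_classes). It seems to hit a bug here, though, so the call is left commented out for now; this does not affect the rest of the article.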

# summary(net,(2,3,32,32))

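When it runs successfully, summary shows that the input is a tensor of shape (batch, 3, 32, 32) and how the output size changes after each layer. The network finally outputs 10 values per image, and passing them through softmax gives the probability of each class.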
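As a small illustration of that last step, softmax turns raw outputs (logits) into probabilities that sum to 1 per row. The logits below are invented numbers with 3 classes instead of 10, purely for readability.

```python
import torch

# Two made-up rows of logits (3 classes for brevity; the network outputs 10).
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.0, 0.0, 0.0]])

# Softmax over the class dimension gives one probability distribution per row.
probs = torch.softmax(logits, dim=1)

print(probs)
print(probs.sum(dim=1))
```

Note that `nn.CrossEntropyLoss` expects raw logits and applies log-softmax internally, so an explicit softmax is only needed when probabilities are wanted at inference time.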

We can also print the model to inspect it.

SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(2, 2), stride=(2, 2))
    (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0): BasicLayer(
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=192, out_features=576, bias=True)
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=192, out_features=192, bias=True)
            (proj_dropout): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (dropout): Identity()
          (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (fc1): Linear(in_features=192, out_features=192, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.0, inplace=False)
            (fc2): Linear(in_features=192, out_features=192, bias=True)
            (drop2): Dropout(p=0.0, inplace=False)
          )
        )
        (1): SwinTransformerBlock(
          (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=192, out_features=576, bias=True)
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=192, out_features=192, bias=True)
            (proj_dropout): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (dropout): Identity()
          (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (fc1): Linear(in_features=192, out_features=192, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.0, inplace=False)
            (fc2): Linear(in_features=192, out_features=192, bias=True)
            (drop2): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (downsample): PatchMerging(
        (reduction): Linear(in_features=768, out_features=384, bias=False)
        (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
    )
    (1): BasicLayer(
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_dropout): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (dropout): Identity()
          (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (fc1): Linear(in_features=384, out_features=384, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.0, inplace=False)
            (fc2): Linear(in_features=384, out_features=384, bias=True)
            (drop2): Dropout(p=0.0, inplace=False)
          )
        )
        (1): SwinTransformerBlock(
          (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_dropout): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (dropout): Identity()
          (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (fc1): Linear(in_features=384, out_features=384, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.0, inplace=False)
            (fc2): Linear(in_features=384, out_features=384, bias=True)
            (drop2): Dropout(p=0.0, inplace=False)
          )
        )
        (2): SwinTransformerBlock(
          (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_dropout): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (dropout): Identity()
          (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (fc1): Linear(in_features=384, out_features=384, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.0, inplace=False)
            (fc2): Linear(in_features=384, out_features=384, bias=True)
            (drop2): Dropout(p=0.0, inplace=False)
          )
        )
        (3): SwinTransformerBlock(
          (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_dropout): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (dropout): Identity()
          (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (fc1): Linear(in_features=384, out_features=384, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.0, inplace=False)
            (fc2): Linear(in_features=384, out_features=384, bias=True)
            (drop2): Dropout(p=0.0, inplace=False)
          )
        )
        (4): SwinTransformerBlock(
          (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_dropout): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (dropout): Identity()
          (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (fc1): Linear(in_features=384, out_features=384, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.0, inplace=False)
            (fc2): Linear(in_features=384, out_features=384, bias=True)
            (drop2): Dropout(p=0.0, inplace=False)
          )
        )
        (5): SwinTransformerBlock(
          (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_dropout): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (dropout): Identity()
          (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (fc1): Linear(in_features=384, out_features=384, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.0, inplace=False)
            (fc2): Linear(in_features=384, out_features=384, bias=True)
            (drop2): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (downsample): PatchMerging(
        (reduction): Linear(in_features=1536, out_features=768, bias=False)
        (norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
      )
    )
    (2): BasicLayer(
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_dropout): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (dropout): Identity()
          (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (fc1): Linear(in_features=768, out_features=768, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.0, inplace=False)
            (fc2): Linear(in_features=768, out_features=768, bias=True)
            (drop2): Dropout(p=0.0, inplace=False)
          )
        )
        (1): SwinTransformerBlock(
          (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_dropout): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (dropout): Identity()
          (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (fc1): Linear(in_features=768, out_features=768, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.0, inplace=False)
            (fc2): Linear(in_features=768, out_features=768, bias=True)
            (drop2): Dropout(p=0.0, inplace=False)
          )
        )
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (avgpool): AdaptiveAvgPool1d(output_size=1)
  (head): Linear(in_features=768, out_features=10, bias=True)
)

Testing and defining the network

Next, run a quick test to check that the network produces an output of the expected shape.

test_x = torch.randn(2,3,32,32).to(device)
test_y = net(test_x)
print(test_y.shape)
torch.Size([2, 10])

Define the network and set the number of classes

net = swin_cifar(patch_size=2, n_classes=10,mlp_ratio=1)

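5. Define the loss function and optimizer

PyTorch packages the optimization methods commonly used in deep learning in torch.optim; all of them inherit from the base class optim.Optimizer.

Loss functions are packaged in the neural network toolbox nn, which contains many of them.

Here I use the Adam optimizer (lr=1e-4, betas=(0.5, 0.999)) with cross-entropy as the loss function. The learning rate is updated dynamically with ReduceLROnPlateau: when the monitored loss fails to improve for more than 1 epoch in a row (patience=1), the learning rate is multiplied by 0.94, down to a minimum of 0.000001.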
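To see how a plateau scheduler reacts, it can be driven with an artificial loss that never improves. This is only a sketch: the single dummy parameter and the constant loss are invented, while the factor, patience, and minimum learning rate match the values used in the code in this section.

```python
import torch
import torch.optim as optim

# A single dummy parameter stands in for the model weights.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.Adam([param], lr=1e-4, betas=(0.5, 0.999))
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.94, patience=1, min_lr=1e-6)

lrs = []
for _ in range(3):
    scheduler.step(1.0)  # report a loss that stays flat
    lrs.append(optimizer.param_groups[0]['lr'])

print(lrs)
```

The first call only records the best value; the learning rate is reduced once the number of consecutive non-improving epochs exceeds `patience`.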

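To learn more about optimizers and learning-rate schedules, see the following notes:

PyTorch Note15: Optimization Algorithms 1, Gradient Descent Variants

PyTorch Note16: Optimization Algorithms 2, Momentum

PyTorch Note34: Learning Rate Decay

Here we train for 10 epochs.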

import torch.optim as optim
optimizer = optim.Adam(net.parameters(), lr=1e-4, betas=(.5, .999))
criterion = nn.CrossEntropyLoss()
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.94 ,patience = 1,min_lr = 0.000001) # update the learning rate dynamically
# scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[75, 150], gamma=0.5)
import time
epoch = 10

6. Training and visualization (with TensorBoard)

First define where the model will be saved.

import os
if not os.path.exists('./model'):
    os.makedirs('./model')
else:
    print('Directory already exists')
save_path = './model/Swintransformer.pth'

This version adds TensorBoard visualization, which produces nicer plots and makes the results easy to inspect.

# Use TensorBoard
from torch.utils.tensorboard import SummaryWriter
os.makedirs("./logs", exist_ok=True)
tbwriter = SummaryWriter(log_dir='./logs/Swintransformer', comment='Swintransformer')  # record intermediate outputs with TensorBoard
tbwriter.add_graph(model= net, input_to_model=torch.randn(size=(1, 3, 32, 32)))

If a GPU is available, the model can run on it, and data-parallel execution can be enabled.

if device == 'cuda':
    net.to(device)
    net = nn.DataParallel(net) # use data-parallel execution

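Start training

I wrote a train function that runs the training and saves the trained model. Note that the utils file used here is my own, so it needs to be downloaded.

Alternatively, see the utility-functions article in this series; I have also updated the results and methods there, and tqdm makes the progress easier to follow.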
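For readers without the utils file, the core of one training epoch can be sketched as follows. This is not the author's `train` function: `train_one_epoch`, the toy linear model, and the random data are all hypothetical stand-ins, and scheduler stepping, checkpoint saving, and TensorBoard logging are left out.

```python
import torch
import torch.nn as nn

def train_one_epoch(model, loader, optimizer, criterion, device="cpu"):
    """Run one pass over `loader`; return (mean loss, accuracy)."""
    model.train()
    total_loss, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()              # clear gradients from the previous step
        outputs = model(images)            # forward pass, shape (batch, n_classes)
        loss = criterion(outputs, labels)
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights
        total_loss += loss.item() * labels.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return total_loss / seen, correct / seen

# Toy stand-ins so the sketch runs on its own (not the real model or data).
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
toy_loader = [(torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))) for _ in range(4)]
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-4, betas=(0.5, 0.999))

avg_loss, avg_acc = train_one_epoch(toy_model, toy_loader, toy_optimizer, nn.CrossEntropyLoss())
print(avg_loss, avg_acc)
```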

from utils import plot_history
from utils import train
Acc, Loss, Lr = train(net, trainloader, testloader, epoch, optimizer, criterion, scheduler, save_path, tbwriter, verbose = True)
Train Epoch 1/20: 100%|██████████| 781/781 [01:34<00:00,  8.27it/s, Train Acc=0.431, Train Loss=1.55]
Test Epoch 1/20: 100%|██████████| 156/156 [00:04<00:00, 31.55it/s, Test Acc=0.533, Test Loss=1.28]
Epoch [  1/ 20]  Train Loss:1.554635  Train Acc:43.09% Test Loss:1.281645  Test Acc:53.30%  Learning Rate:0.000100
Train Epoch 2/20: 100%|██████████| 781/781 [01:35<00:00,  8.20it/s, Train Acc=0.583, Train Loss=1.15]
Test Epoch 2/20: 100%|██████████| 156/156 [00:04<00:00, 31.34it/s, Test Acc=0.586, Test Loss=1.13]
Epoch [  2/ 20]  Train Loss:1.153438  Train Acc:58.34% Test Loss:1.130587  Test Acc:58.58%  Learning Rate:0.000100
Train Epoch 3/20: 100%|██████████| 781/781 [01:37<00:00,  8.04it/s, Train Acc=0.657, Train Loss=0.959]
Test Epoch 3/20: 100%|██████████| 156/156 [00:04<00:00, 31.20it/s, Test Acc=0.627, Test Loss=1.06]
Epoch [  3/ 20]  Train Loss:0.958730  Train Acc:65.69% Test Loss:1.057654  Test Acc:62.70%  Learning Rate:0.000100
Train Epoch 4/20: 100%|██████████| 781/781 [01:36<00:00,  8.10it/s, Train Acc=0.714, Train Loss=0.807]
Test Epoch 4/20: 100%|██████████| 156/156 [00:05<00:00, 31.00it/s, Test Acc=0.668, Test Loss=0.939]
Epoch [  4/ 20]  Train Loss:0.806689  Train Acc:71.37% Test Loss:0.939146  Test Acc:66.78%  Learning Rate:0.000100
Train Epoch 5/20: 100%|██████████| 781/781 [01:35<00:00,  8.14it/s, Train Acc=0.758, Train Loss=0.67] 
Test Epoch 5/20: 100%|██████████| 156/156 [00:04<00:00, 31.92it/s, Test Acc=0.672, Test Loss=0.962]
Epoch [  5/ 20]  Train Loss:0.670009  Train Acc:75.84% Test Loss:0.962207  Test Acc:67.16%  Learning Rate:0.000100
Train Epoch 6/20: 100%|██████████| 781/781 [01:32<00:00,  8.47it/s, Train Acc=0.809, Train Loss=0.54] 
Test Epoch 6/20: 100%|██████████| 156/156 [00:04<00:00, 31.60it/s, Test Acc=0.674, Test Loss=0.965]
Epoch [  6/ 20]  Train Loss:0.539732  Train Acc:80.90% Test Loss:0.965498  Test Acc:67.38%  Learning Rate:0.000100
Train Epoch 7/20: 100%|██████████| 781/781 [01:35<00:00,  8.15it/s, Train Acc=0.856, Train Loss=0.406]
Test Epoch 7/20: 100%|██████████| 156/156 [00:04<00:00, 31.80it/s, Test Acc=0.674, Test Loss=1.05]
Epoch [  7/ 20]  Train Loss:0.406309  Train Acc:85.64% Test Loss:1.050459  Test Acc:67.44%  Learning Rate:0.000100
Train Epoch 8/20: 100%|██████████| 781/781 [01:35<00:00,  8.21it/s, Train Acc=0.894, Train Loss=0.298]
Test Epoch 8/20: 100%|██████████| 156/156 [00:04<00:00, 31.61it/s, Test Acc=0.675, Test Loss=1.09]
Epoch [  8/ 20]  Train Loss:0.298231  Train Acc:89.37% Test Loss:1.093418  Test Acc:67.48%  Learning Rate:0.000100
Train Epoch 9/20: 100%|██████████| 781/781 [01:36<00:00,  8.07it/s, Train Acc=0.925, Train Loss=0.213]
Test Epoch 9/20: 100%|██████████| 156/156 [00:04<00:00, 31.56it/s, Test Acc=0.67, Test Loss=1.22] 
Epoch [  9/ 20]  Train Loss:0.212923  Train Acc:92.47% Test Loss:1.221283  Test Acc:67.04%  Learning Rate:0.000100
Train Epoch 10/20: 100%|██████████| 781/781 [01:35<00:00,  8.14it/s, Train Acc=0.942, Train Loss=0.169]
Test Epoch 10/20: 100%|██████████| 156/156 [00:04<00:00, 31.58it/s, Test Acc=0.667, Test Loss=1.33]
Epoch [ 10/ 20]  Train Loss:0.168813  Train Acc:94.20% Test Loss:1.330179  Test Acc:66.72%  Learning Rate:0.000100
Train Epoch 11/20: 100%|██████████| 781/781 [01:35<00:00,  8.18it/s, Train Acc=0.952, Train Loss=0.136]
Test Epoch 11/20: 100%|██████████| 156/156 [00:04<00:00, 31.45it/s, Test Acc=0.671, Test Loss=1.4] 
Epoch [ 11/ 20]  Train Loss:0.136068  Train Acc:95.25% Test Loss:1.398269  Test Acc:67.11%  Learning Rate:0.000100
Train Epoch 12/20: 100%|██████████| 781/781 [01:40<00:00,  7.79it/s, Train Acc=0.957, Train Loss=0.124] 
Test Epoch 12/20: 100%|██████████| 156/156 [00:05<00:00, 31.09it/s, Test Acc=0.666, Test Loss=1.48]
Epoch [ 12/ 20]  Train Loss:0.123808  Train Acc:95.68% Test Loss:1.483130  Test Acc:66.62%  Learning Rate:0.000100
Train Epoch 13/20: 100%|██████████| 781/781 [01:40<00:00,  7.75it/s, Train Acc=0.964, Train Loss=0.104] 
Test Epoch 13/20: 100%|██████████| 156/156 [00:04<00:00, 31.29it/s, Test Acc=0.664, Test Loss=1.6] 
Epoch [ 13/ 20]  Train Loss:0.104265  Train Acc:96.37% Test Loss:1.601849  Test Acc:66.41%  Learning Rate:0.000100
Train Epoch 14/20: 100%|██████████| 781/781 [01:39<00:00,  7.87it/s, Train Acc=0.964, Train Loss=0.101] 
Test Epoch 14/20: 100%|██████████| 156/156 [00:04<00:00, 31.86it/s, Test Acc=0.665, Test Loss=1.54]
Epoch [ 14/ 20]  Train Loss:0.101227  Train Acc:96.40% Test Loss:1.542168  Test Acc:66.47%  Learning Rate:0.000100
Train Epoch 15/20: 100%|██████████| 781/781 [01:38<00:00,  7.96it/s, Train Acc=0.971, Train Loss=0.085] 
Test Epoch 15/20: 100%|██████████| 156/156 [00:04<00:00, 31.65it/s, Test Acc=0.669, Test Loss=1.61]
Epoch [ 15/ 20]  Train Loss:0.084963  Train Acc:97.11% Test Loss:1.613661  Test Acc:66.88%  Learning Rate:0.000100
Train Epoch 16/20: 100%|██████████| 781/781 [01:36<00:00,  8.12it/s, Train Acc=0.969, Train Loss=0.0868]
Test Epoch 16/20: 100%|██████████| 156/156 [00:05<00:00, 30.98it/s, Test Acc=0.671, Test Loss=1.62]
Epoch [ 16/ 20]  Train Loss:0.086780  Train Acc:96.94% Test Loss:1.622345  Test Acc:67.06%  Learning Rate:0.000100
Train Epoch 17/20: 100%|██████████| 781/781 [01:30<00:00,  8.66it/s, Train Acc=0.973, Train Loss=0.0796]
Test Epoch 17/20: 100%|██████████| 156/156 [00:05<00:00, 30.96it/s, Test Acc=0.666, Test Loss=1.66]
Epoch [ 17/ 20]  Train Loss:0.079616  Train Acc:97.33% Test Loss:1.660496  Test Acc:66.59%  Learning Rate:0.000100
Train Epoch 18/20: 100%|██████████| 781/781 [01:35<00:00,  8.17it/s, Train Acc=0.973, Train Loss=0.0775]
Test Epoch 18/20: 100%|██████████| 156/156 [00:04<00:00, 31.67it/s, Test Acc=0.666, Test Loss=1.69]
Epoch [ 18/ 20]  Train Loss:0.077545  Train Acc:97.28% Test Loss:1.690717  Test Acc:66.57%  Learning Rate:0.000100
Train Epoch 19/20: 100%|██████████| 781/781 [01:35<00:00,  8.17it/s, Train Acc=0.972, Train Loss=0.0794]
Test Epoch 19/20: 100%|██████████| 156/156 [00:04<00:00, 31.58it/s, Test Acc=0.677, Test Loss=1.61]
Epoch [ 19/ 20]  Train Loss:0.079376  Train Acc:97.17% Test Loss:1.608011  Test Acc:67.70%  Learning Rate:0.000100
Train Epoch 20/20: 100%|██████████| 781/781 [01:31<00:00,  8.57it/s, Train Acc=0.976, Train Loss=0.0645]
Test Epoch 20/20: 100%|██████████| 156/156 [00:04<00:00, 31.74it/s, Test Acc=0.677, Test Loss=1.65]
Epoch [ 20/ 20]  Train Loss:0.064538  Train Acc:97.65% Test Loss:1.654503  Test Acc:67.70%  Learning Rate:0.000100

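At this point some readers may ask why the Swin results are not as good as expected, and whether there is a bug. I think there are a few possible reasons:

First, Transformer-style models need somewhat more data; in our ViT experiments we also could not beat ResNet within a short training run. My suggestion is to use Swin-T on 224x224x3 images. Second, the logs show that the training set scores far better than the test set, which suggests overfitting, likely because there is not enough data. Without collecting more data, data augmentation is one way to get a noticeable improvement.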
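As one possible form of the data augmentation suggested above, a random crop with padding plus a random horizontal flip can be written in plain PyTorch. This is a sketch, not the pipeline used for the results in this article: `augment_batch` and the padding of 4 pixels are assumptions chosen for illustration.

```python
import torch
import torch.nn.functional as F

def augment_batch(images, padding=4):
    """Random crop (after zero padding) and random horizontal flip for a (b, c, h, w) batch."""
    b, _, h, w = images.shape
    padded = F.pad(images, (padding, padding, padding, padding))
    out = torch.empty_like(images)
    for i in range(b):
        # Pick a random top-left corner inside the padded image.
        top = torch.randint(0, 2 * padding + 1, (1,)).item()
        left = torch.randint(0, 2 * padding + 1, (1,)).item()
        crop = padded[i, :, top:top + h, left:left + w]
        # Flip along the width axis with probability 0.5.
        if torch.rand(1).item() < 0.5:
            crop = torch.flip(crop, dims=[2])
        out[i] = crop
    return out

batch = torch.randn(4, 3, 32, 32)
augmented = augment_batch(batch)
print(augmented.shape)
```

Augmentation like this is applied to training batches only; the test set should be left unchanged so that accuracy numbers stay comparable.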

Visualizing the training curves

Next we can plot the loss curve, the accuracy curve, and the learning-rate curve.

plot_history(epoch ,Acc, Loss, Lr)

Loss curve


f0fc1baa9bb0447b97bae97b1534a62d.png

Accuracy curve


b6295c0605c34ad3bf9483ebc08d4994.png

Learning-rate curve

210350179baa490e8b0b26b5e6044b6b.png

Run the following command to open the visualization:

tensorboard --logdir logs

7. Testing

Check the accuracy

correct = 0   # number of correctly predicted images, initialized to 0
total = 0     # number of images evaluated, also initialized to 0
# testloader = torch.utils.data.DataLoader(testset, batch_size=32,shuffle=True, num_workers=2)
net.eval()  # switch the model to evaluation mode once, before the loop
with torch.no_grad():  # gradients are not needed for evaluation
    for data in testloader:  # loop over every batch
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        outputs = net(images)  # forward pass
        # outputs has shape (batch_size, 10); torch.max over dim 1 returns the
        # largest value of each row and its index as two 1-D tensors
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)                        # update the number of test images
        correct += (predicted == labels).sum().item()  # update the number of correct predictions
print('Accuracy of the network on the 10000 test images: %.2f %%' % (100 * correct / total))
Accuracy of the network on the 10000 test images: 67.75 %

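The Swin Transformer model reaches an accuracy of about 67.75% on the test set.

In the code, torch.max(outputs.data, 1) returns a tuple.

The first element of the tuple holds the maximum values, and the second element holds the indices of those maxima, which are the predicted labels. We only need the labels (the indices), hence the assignment `_, predicted`: assigning the first return value to `_` simply means discarding it.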
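A two-row example makes the return value of `torch.max` concrete. The numbers are invented and use 3 classes instead of 10.

```python
import torch

# Made-up network outputs for a batch of 2 images and 3 classes.
outputs = torch.tensor([[0.1, 2.0, -1.0],
                        [3.0, 0.5, 0.2]])

# Along dim 1, torch.max returns (largest value per row, index of that value).
values, predicted = torch.max(outputs, 1)

print(values)     # the maxima themselves
print(predicted)  # the class indices used as predictions
```

When only the indices are needed, `outputs.argmax(dim=1)` gives the same `predicted` tensor.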

Check the accuracy of each class

# Two lists that store, per class, the number of correct predictions and the number of samples
class_correct = [0.0 for _ in range(10)]
class_total = [0.0 for _ in range(10)]
# testloader = torch.utils.data.DataLoader(testset, batch_size=64,shuffle=True, num_workers=2)
net.eval()
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        # For every sample in the batch: True if the prediction equals the label
        c = (predicted == labels)
        for i in range(len(images)):      # loop over the samples of this batch
            label = labels[i].item()      # accumulate separately for each class
            class_correct[label] += c[i].item()
            class_total[label] += 1
for i in range(10):
    print('Accuracy of %5s : %.2f %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
Accuracy of airplane : 71.11 %
Accuracy of automobile : 69.37 %
Accuracy of  bird : 56.46 %
Accuracy of   cat : 51.41 %
Accuracy of  deer : 62.76 %
Accuracy of   dog : 60.12 %
Accuracy of  frog : 72.57 %
Accuracy of horse : 73.60 %
Accuracy of  ship : 84.18 %
Accuracy of truck : 75.45 %
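
The per-sample Python loop above is easy to read but slow for large test sets. As an alternative, here is a small vectorized sketch based on torch.bincount. The helper name per_class_accuracy and the toy tensors are my own, for illustration only.

```python
import torch

def per_class_accuracy(preds, labels, num_classes):
    """Return per-class accuracy from 1-D integer prediction and label tensors."""
    total = torch.bincount(labels, minlength=num_classes)
    correct = torch.bincount(labels[preds == labels], minlength=num_classes)
    # clamp avoids a division by zero for classes that never appear
    return correct.float() / total.clamp(min=1).float()

# Toy data made up for illustration: 3 classes, 6 samples.
labels = torch.tensor([0, 0, 1, 1, 2, 2])
preds = torch.tensor([0, 1, 1, 1, 0, 2])
acc = per_class_accuracy(preds, labels, num_classes=3)
print(acc)  # tensor([0.5000, 1.0000, 0.5000])
```

In a real evaluation you would first concatenate the predictions and labels of all batches (for example with torch.cat) and then call the helper once.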

Sampling a batch and visualizing some of the results

dataiter = iter(testloader)
images, labels = next(dataiter)   # use next(); dataiter.next() was removed in newer PyTorch releases
images_ = images
#images_ = images_.view(images.shape[0], -1)
images_ = images_.to(device)
labels = labels.to(device)
val_output = net(images_)
_, val_preds = torch.max(val_output, 1)
correct = torch.sum(val_preds == labels.data).item()
val_preds = val_preds.cpu()
labels = labels.cpu()
print("Accuracy Rate = {}%".format(correct/len(images) * 100))
fig = plt.figure(figsize=(25,25))
for idx in np.arange(64):    # assumes the batch holds at least 64 images
    ax = fig.add_subplot(8, 8, idx+1, xticks=[], yticks=[])
    #fig.tight_layout()
#     plt.imshow(im_convert(images[idx]))
    imshow(images[idx])
    # title format: predicted class, (true class); green if correct, red otherwise
    ax.set_title("{}, ({})".format(classes[val_preds[idx].item()], classes[labels[idx].item()]), 
                 color = ("green" if val_preds[idx].item()==labels[idx].item() else "red"))
Accuracy Rate = 71.875%

352149df68da407592ec04d69152d6a1.png

8. Saving the model

torch.save(net,save_path[:-4]+'_'+str(epoch)+'.pth')
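
Note that torch.save(net, ...) pickles the whole module, which ties the file to the exact class definition and import path. A commonly used alternative is to save only the state_dict. The sketch below uses a tiny nn.Linear as a stand-in for net and a temporary file path, both of which are placeholders of mine so that the snippet runs on its own.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in model so the sketch is self-contained; in this article it would be net.
net_demo = nn.Linear(4, 2)

# Hypothetical path inside a temporary directory.
path = os.path.join(tempfile.mkdtemp(), "demo_state.pth")

# Save only the parameters and buffers, not the pickled class.
torch.save(net_demo.state_dict(), path)

# To load, rebuild the architecture first, then copy the weights in.
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()
```

With this approach, loading the Swin model would mean constructing it with the same arguments as during training and then calling load_state_dict on it.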

9. Prediction

Predicting on a local image

import torch
from PIL import Image
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = swin_cifar(patch_size=2, n_classes=10,mlp_ratio=1)
model = torch.load(save_path)  # load the saved model (this replaces the instance built on the line above)
# model = model.to('cuda')
model.eval()  # switch the model to evaluation mode
# Read the image to predict
img = Image.open("./airplane.jpg").convert('RGB') # open the image as RGB

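For convenience we also define a predict function. The idea is simple: resize the image to the shape the network expects, convert it to a tensor, and feed it to the model. One point to note is that the image must go through the same transforms as the training data, because every training image was transformed in this way, so the two have to stay consistent.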

def predict(img):
    trans = transforms.Compose([transforms.Resize((32,32)),
                            transforms.ToTensor(),
                            transforms.Normalize(mean=(0.5, 0.5, 0.5), 
                                                 std=(0.5, 0.5, 0.5)),
                           ])
    img = trans(img)
    img = img.to(device)
    # Add a batch dimension: the model expects 4-D input [batch_size, channels, height, width],
    # while a single image is 3-D [channels, height, width]
    img = img.unsqueeze(0)  # shape becomes [1,3,32,32]
    output = model(img)
    prob = F.softmax(output,dim=1) # prob holds the probabilities of the 10 classes
    print("概率",prob)                      # "概率" = probabilities
    value, predicted = torch.max(output.data, 1)
    print("类别",predicted.item())          # "类别" = class index
    print(value)                            # raw (pre-softmax) score of the predicted class
    pred_class = classes[predicted.item()]
    print("分类",pred_class)                # "分类" = class name
# Read the image to predict
img = Image.open("./airplane.jpg").convert('RGB') # open the image as RGB
img

a4baf7e595b54cc1b3221e57b62ec43d.png

predict(img)
概率 tensor([[1.0000e+00, 2.2290e-09, 2.1466e-06, 1.7411e-11, 2.3164e-08, 3.5220e-10,
         1.2074e-10, 1.3403e-10, 1.8461e-07, 7.2670e-12]], device='cuda:0',
       grad_fn=<SoftmaxBackward0>)
类别 0
tensor([17.9435], device='cuda:0')
分类 plane

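Here the final prediction is plane, with a softmax probability of essentially 100% (the printed 17.9435 is the raw output score, not a probability). The confidence is very high and the prediction is correct for this image.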
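
When reading these outputs, it can help to look at more than the single best class. Below is a small sketch that lists the k most likely classes with their softmax probabilities. The helper name top_k_classes and the logits are made up for illustration; they are not produced by the model in this article.

```python
import torch
import torch.nn.functional as F

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def top_k_classes(logits, k=3):
    """Return [(class_name, probability), ...] for the k most likely classes of one sample."""
    probs = F.softmax(logits, dim=1)            # shape [1, n_classes]
    top_p, top_i = torch.topk(probs, k, dim=1)  # both have shape [1, k]
    return [(classes[i], p) for p, i in zip(top_p[0].tolist(), top_i[0].tolist())]

# Made-up scores for a single image.
logits = torch.tensor([[5.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0]])
result = top_k_classes(logits)
for name, p in result:
    print(f"{name}: {p:.3f}")
```

Inside predict(), the same helper could be called on output to print a short ranking instead of the full probability tensor.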

Predicting on an image from a URL

We can also predict on an image fetched from a URL. Here I picked several different images to try.

import requests
from PIL import Image
url = 'https://dss2.bdstatic.com/70cFvnSh_Q1YnxGkpoWK1HF6hhy/it/u=947072664,3925280208&fm=26&gp=0.jpg'
url = 'https://ss0.bdstatic.com/70cFuHSh_Q1YnxGkpoWK1HF6hhy/it/u=2952045457,215279295&fm=26&gp=0.jpg'
url = 'https://ss0.bdstatic.com/70cFvHSh_Q1YnxGkpoWK1HF6hhy/it/u=2838383012,1815030248&fm=26&gp=0.jpg'
url = 'https://gimg2.baidu.com/image_search/src=http%3A%2F%2Fwww.goupuzi.com%2Fnewatt%2FMon_1809%2F1_179223_7463b117c8a2c76.jpg&refer=http%3A%2F%2Fwww.goupuzi.com&app=2002&size=f9999,10000&q=a80&n=0&g=0n&fmt=jpeg?sec=1624346733&t=36ba18326a1e010737f530976201326d'
url = 'https://ss3.bdstatic.com/70cFv8Sh_Q1YnxGkpoWK1HF6hhy/it/u=2799543344,3604342295&fm=224&gp=0.jpg'
# url = 'https://ss1.bdstatic.com/70cFuXSh_Q1YnxGkpoWK1HF6hhy/it/u=2032505694,2851387785&fm=26&gp=0.jpg'
response = requests.get(url, stream=True)
print (response)
img = Image.open(response.raw)
img

c22f662ef70c4539bad1ad21ee8b702d.png

The prediction step is the same as before.

predict(img)
概率 tensor([[9.8081e-01, 6.6089e-08, 3.0682e-03, 1.3831e-04, 3.3311e-04, 5.0691e-03,
         3.5965e-07, 1.0807e-03, 9.4966e-03, 5.1127e-06]], device='cuda:0',
       grad_fn=<SoftmaxBackward0>)
类别 0
tensor([7.8766], device='cuda:0')
分类 plane

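This time the prediction is wrong: the model predicts plane, but the image is actually a cat, and the plane probability is about 98%. So the model does not yet work very well on real-world images, and it may be somewhat biased toward the plane class. I only trained for 20 epochs, so more training might help. The image itself could also be part of the problem. Further improvement is still needed.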
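
One thing worth ruling out when a downloaded image is misclassified is the decoding step itself. Images from the web may be RGBA, grayscale, or palette-based, while predict() normalizes with three channels. The sketch below decodes from in-memory bytes and forces RGB. The helper name image_from_bytes is my own, and the demo encodes a small PNG locally instead of hitting the network.

```python
import io

from PIL import Image

def image_from_bytes(data):
    """Decode raw image bytes and force 3-channel RGB, as the 3-channel Normalize in predict() expects."""
    img = Image.open(io.BytesIO(data))
    return img.convert('RGB')

# Self-contained check: encode a small RGBA PNG in memory and decode it again.
buf = io.BytesIO()
Image.new('RGBA', (8, 8), (255, 0, 0, 128)).save(buf, format='PNG')
img_rgb = image_from_bytes(buf.getvalue())
print(img_rgb.mode, img_rgb.size)  # RGB (8, 8)
```

With requests, the bytes would come from response.content, ideally after calling response.raise_for_status() so that an error page is not decoded by mistake.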

10. Summary

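Why does Swin Transformer work so well? Before Swin Transformer, iGPT took heavily downsampled images as input, and ViT tokenized the image into coarse 16x16 patches at a single scale, so both lose a lot of fine-grained information. Swin Transformer instead starts from small patches on the image at its original size, for example 224*224 for ImageNet. Swin Transformer also adopts the hierarchical structure that is most common in CNNs. An important property of CNNs is that the receptive field of each unit grows as the network gets deeper, and Swin Transformer has this property too. This hierarchical structure also lets it serve, like FPN or U-Net style architectures, as a backbone for segmentation or detection tasks.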

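Swin Transformer can be seen as a model that combines ideas from convolutional neural networks (CNNs) and Transformers. It replaces the convolution operation with self-attention computed among the patches inside a local window. In its design, Swin Transformer imitates CNNs in several ways. The window mechanism resembles the locality of a convolution kernel, the downsampling between stages resembles the hierarchical structure of a CNN, and the non-overlapping window partition resembles a convolution whose stride equals its kernel size, with the window shift letting information cross window boundaries.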

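So although Transformers have made remarkable progress in computer vision, we should not overlook the importance of convolution. The core module of the Transformer is self-attention, while the core module of a CNN is convolution. Personally, I think self-attention will not replace convolution in computer vision. The two should be combined so that each plays to its strengths. Here are a couple of concrete examples: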

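Unlike text, images are high-dimensional. Applying self-attention directly makes the computation very expensive, which we want to avoid. Borrowing the locality idea from convolution may alleviate this, or one can simply use convolutions in the early part of the network. For example, Vision Transformer (ViT) splits the image into non-overlapping patches and maps each patch to a patch embedding with a linear projection, which is in effect a convolution whose kernel size and stride both equal the patch size.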
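
The claim that patch embedding is a convolution in disguise can be checked numerically. The sketch below compares a strided nn.Conv2d with the "cut into patches, flatten, linear projection" formulation after copying the weights across. The sizes (patch 2, 8 output channels) are arbitrary values picked for illustration, not this article's configuration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
patch, in_ch, dim = 2, 3, 8           # illustrative sizes
x = torch.randn(1, in_ch, 32, 32)

# (a) Patch embedding written as a convolution with kernel_size == stride == patch.
conv = nn.Conv2d(in_ch, dim, kernel_size=patch, stride=patch)
out_conv = conv(x).flatten(2).transpose(1, 2)            # [1, 256, dim]

# (b) The same operation as: cut into patches, flatten each one, apply a linear layer.
linear = nn.Linear(in_ch * patch * patch, dim)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(dim, -1))    # reuse the convolution weights
    linear.bias.copy_(conv.bias)
patches = F.unfold(x, kernel_size=patch, stride=patch)   # [1, in_ch*patch*patch, 256]
out_linear = linear(patches.transpose(1, 2))             # [1, 256, dim]

matches = torch.allclose(out_conv, out_linear, atol=1e-5)
print(out_conv.shape, matches)
```

The two outputs should agree up to floating-point error, which is why many implementations write the patch embedding directly as a Conv2d.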

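Convolution rests on two assumptions: local correlation and translation invariance. In computer vision, the common view is that convolution works so well because these two assumptions match the distribution of image data. If they really do match, we can combine them with self-attention: instead of attending globally, compute self-attention within each local window, the way convolution works locally. Come to think of it, isn't that exactly the idea behind Swin Transformer?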
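
To make the idea of local attention concrete, here is a stripped-down sketch: the feature map is cut into non-overlapping windows and plain single-head attention is computed inside each window. It deliberately leaves out the learned q/k/v projections, multiple heads, relative position bias, and the shifted-window mask, so it is an illustration of the principle rather than the Swin block used in this article.

```python
import torch

def window_partition(x, window_size):
    """Split [B, H, W, C] into non-overlapping windows: [B * num_windows, window_size**2, C]."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(-1, window_size * window_size, C)

def window_self_attention(x, window_size):
    """Single-head attention computed independently inside each window (no learned projections)."""
    windows = window_partition(x, window_size)            # [B*nW, N, C]
    scale = windows.shape[-1] ** -0.5
    attn = (windows @ windows.transpose(1, 2)) * scale    # [B*nW, N, N]
    attn = attn.softmax(dim=-1)
    return attn @ windows                                 # [B*nW, N, C]

x = torch.randn(2, 8, 8, 16)       # illustrative sizes: B=2, H=W=8, C=16
out = window_self_attention(x, window_size=4)
print(out.shape)  # torch.Size([8, 16, 16])
```

For this 8x8 map, global attention would need one 64x64 attention matrix per image, while the windowed version needs four 16x16 matrices, which is where the savings come from as the resolution grows.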

By the way, the data and code are described in my summary post. Feel free to grab them if you need them.

Here is the summary post again: summary post

