【Pytorch神经网络实战案例】12 利用注意力机制的神经网络实现对FashionMNIST数据集图片的分类

简介: 掩码模式:是相对于变长的循环序列而言的,如果输入的样本序列长度不同,那么会先对其进行对齐处理(对短序列补0,对长序列截断),再输入模型。这样,模型中的部分样本中就会有大量的零值。为了提升运算性能,需要以掩码的方式将不需要的零值去掉,并保留非零值进行计算,这就是掩码的作用

909c14a62daf432886fd9cb7803c09f2.png


1、掩码模式:是相对于变长的循环序列而言的,如果输入的样本序列长度不同,那么会先对其进行对齐处理(对短序列补0,对长序列截断),再输入模型。这样,模型中的部分样本中就会有大量的零值。为了提升运算性能,需要以掩码的方式将不需要的零值去掉,并保留非零值进行计算,这就是掩码的作用


2、均值模式:正常模式对每个维度的所有序列计算注意力分数,而均值模式对每个维度注意力分数计算平均值。均值模式会平滑处理同一序列不同维度之间的差异,认为所有维度都是平等的,将注意力用在序列之间。这种方式更能体现出序列的重要性。


2018121821364581.png


代码 Attention_cclassification.py


import torchvision
import torchvision.transforms as tranforms
import pylab
import torch
from matplotlib import pyplot as plt
import numpy as np
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # 可能是由于是MacOS系统的原因
data_dir = './fashion_mnist'
tranform = tranforms.Compose([tranforms.ToTensor()])
train_dataset = torchvision.datasets.FashionMNIST(root=data_dir,train=True,transform=tranform,download=True)
print("训练数据集条数",len(train_dataset))
val_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=False, transform=tranform)
print("测试数据集条数",len(val_dataset))
im = train_dataset[0][0]
im = im.reshape(-1,28)
pylab.imshow(im)
pylab.show()
print("该图片的标签为:",train_dataset[0][1])
## 数据集的制造
batch_size = 10
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
def imshow(img):
    print("图片形状:",np.shape(img))
    npimg = img.numpy()
    plt.axis('off')
    plt.imshow(np.transpose(npimg,(1,2,0)))
classes = ('T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle_Boot')
sample = iter(train_loader)
images,labels = sample.next()
print("样本形状:",np.shape(images))
print("样本标签",labels)
imshow(torchvision.utils.make_grid(images,nrow=batch_size))
print(','.join('%5s' % classes[labels[j]] for j in range(len(images))))
class myLSTMNet(torch.nn.Module): #定义myLSTMNet模型类,该模型包括 2个RNN层和1个全连接层
    def __init__(self,in_dim, hidden_dim, n_layer, n_class):
        super(myLSTMNet, self).__init__()
        # 定义循环神经网络层
        self.lstm = torch.nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)
        self.Linear = torch.nn.Linear(hidden_dim * 28, n_class)  # 定义全连接层
        self.attention = AttentionSeq(hidden_dim, hard=0.03) # 定义注意力层,使用硬模式的注意力机制
    def forward(self, t):  # 搭建正向结构
        t, _ = self.lstm(t)  # 使用LSTM对象进行RNN数据处理
        t = self.attention(t)   # 对循环神经网络结果进行注意力机制的处理,将处理后的结果变形为二维数据,传入全连接输出层。1
        t = t.reshape(t.shape[0], -1) # 对循环神经网络结果进行注意力机制的处理,将处理后的结果变形为二维数据,传入全连接输出层。2
        out = self.Linear(t)  # 进行全连接处理
        return out
class AttentionSeq(torch.nn.Module):
    def __init__(self, hidden_dim, hard=0.0): # 初始化
        super(AttentionSeq, self).__init__()
        self.hidden_dim = hidden_dim
        self.dense = torch.nn.Linear(hidden_dim, hidden_dim)
        self.hard = hard
    def forward(self, features, mean=False): # 类的处理方法
        # [batch,seq,dim]
        batch_size, time_step, hidden_dim = features.size()
        weight = torch.nn.Tanh()(self.dense(features)) # 全连接计算
        # 计算掩码,mask给负无穷使得权重为0
        mask_idx = torch.sign(torch.abs(features).sum(dim=-1))
        # mask_idx = mask_idx.unsqueeze(-1).expand(batch_size, time_step, hidden_dim)
        mask_idx = mask_idx.unsqueeze(-1).repeat(1, 1, hidden_dim)
        # 将掩码作用在注意力结果上
        # torch.where函数的意思是按照第一参数的条件对每个元素进行检查,如果满足,那么使用第二个参数里对应元素的值进行填充,如果不满足,那么使用第三个参数里对应元素的值进行填充。
        # torch.ful_likeO函数是按照张量的形状进行指定值的填充,其第一个参数是参考形状的张量,第二个参数是填充值。
        weight = torch.where(mask_idx == 1, weight,torch.full_like(mask_idx, (-2 ** 32 + 1))) # 利用掩码对注意力结果补0序列填充一个极小数,会在Softmax中被忽略为0
        weight = weight.transpose(2, 1)
        # 必须对注意力结果补0序列填充一个极小数,千万不能填充0,因为注意力结果是经过激活函数tanh()计算出来的,其值域是 - 1~1, 在这个区间内,零值是一个有效值。如果填充0,那么会对后面的Softmax结果产生影响。填充的值只有远离这个有效区间才可以保证被Softmax的结果忽略。
        weight = torch.nn.Softmax(dim=2)(weight) # 计算注意力分数
        if self.hard != 0:  # hard mode
            weight = torch.where(weight > self.hard, weight, torch.full_like(weight, 0))
        if mean: # 支持注意力分数平均值模式
            weight = weight.mean(dim=1)
            weight = weight.unsqueeze(1)
            weight = weight.repeat(1, hidden_dim, 1)
        weight = weight.transpose(2, 1)
        features_attention = weight * features # 将注意力分数作用于特征向量上
        return features_attention # 返回结果
#实例化模型对象
network = myLSTMNet(28, 128, 2, 10)  # 图片大小是28x28,28:输入数据的序列长度为28。128:每层放置128个LSTM Cell。2:构建两层由LSTM Cell所组成的网络。10:最终结果分为10类。
#指定设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
network.to(device)
print(network)  #打印网络
criterion = torch.nn.CrossEntropyLoss()  # 实例化损失函数类
optimizer = torch.optim.Adam(network.parameters(), lr=0.01)
for epoch in range(2): # 数据集迭代2次
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0): # 循环取出批次数据
        inputs, labels = data
        inputs = inputs.squeeze(1) # 由于输入数据是序列形式,不再是图片,因此将通道设为1
        inputs, labels = inputs.to(device), labels.to(device) # 指定设备
        optimizer.zero_grad() # 清空之前的梯度
        outputs = network(inputs)
        loss = criterion(outputs, labels) # 计算损失
        loss.backward()  #反向传播
        optimizer.step() #更新参数
        running_loss += loss.item()
        if i % 1000 == 999:
            print('[%d, %5d] loss: %.3f' %
                (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
print('Finished Training')
#使用模型
dataiter = iter(test_loader)
images, labels = dataiter.next()
inputs, labels = images.to(device), labels.to(device)
imshow(torchvision.utils.make_grid(images,nrow=batch_size))
print('真实标签: ', ' '.join('%5s' % classes[labels[j]] for j in range(len(images))))
inputs = inputs.squeeze(1)
outputs = network(inputs)
_, predicted = torch.max(outputs, 1)
print('预测结果: ', ' '.join('%5s' % classes[predicted[j]]
                              for j in range(len(images))))
#测试模型
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images = images.squeeze(1)
        inputs, labels = images.to(device), labels.to(device)
        outputs = network(inputs)
        _, predicted = torch.max(outputs, 1)
        predicted = predicted.to(device)
        c = (predicted == labels).squeeze()
        for i in range(10):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1
sumacc = 0
for i in range(10):
    Accuracy = 100 * class_correct[i] / class_total[i]
    print('Accuracy of %5s : %2d %%' % (classes[i], Accuracy ))
    sumacc =sumacc+Accuracy
print('Accuracy of all : %2d %%' % ( sumacc/10. ))


目录
相关文章
|
4月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 中的动态计算图:实现灵活的神经网络架构
【8月更文第27天】PyTorch 是一款流行的深度学习框架,它以其灵活性和易用性而闻名。与 TensorFlow 等其他框架相比,PyTorch 最大的特点之一是支持动态计算图。这意味着开发者可以在运行时定义网络结构,这为构建复杂的模型提供了极大的便利。本文将深入探讨 PyTorch 中动态计算图的工作原理,并通过一些示例代码展示如何利用这一特性来构建灵活的神经网络架构。
331 1
|
4月前
|
机器学习/深度学习 人工智能 PyTorch
【深度学习】使用PyTorch构建神经网络:深度学习实战指南
PyTorch是一个开源的Python机器学习库,特别专注于深度学习领域。它由Facebook的AI研究团队开发并维护,因其灵活的架构、动态计算图以及在科研和工业界的广泛支持而受到青睐。PyTorch提供了强大的GPU加速能力,使得在处理大规模数据集和复杂模型时效率极高。
199 59
|
3月前
|
机器学习/深度学习
小土堆-pytorch-神经网络-损失函数与反向传播_笔记
在使用损失函数时,关键在于匹配输入和输出形状。例如,在L1Loss中,输入形状中的N代表批量大小。以下是具体示例:对于相同形状的输入和目标张量,L1Loss默认计算差值并求平均;此外,均方误差(MSE)也是常用损失函数。实战中,损失函数用于计算模型输出与真实标签间的差距,并通过反向传播更新模型参数。
|
3月前
|
机器学习/深度学习 数据挖掘 TensorFlow
解锁Python数据分析新技能,TensorFlow&PyTorch双引擎驱动深度学习实战盛宴
在数据驱动时代,Python凭借简洁的语法和强大的库支持,成为数据分析与机器学习的首选语言。Pandas和NumPy是Python数据分析的基础,前者提供高效的数据处理工具,后者则支持科学计算。TensorFlow与PyTorch作为深度学习领域的两大框架,助力数据科学家构建复杂神经网络,挖掘数据深层价值。通过Python打下的坚实基础,结合TensorFlow和PyTorch的强大功能,我们能在数据科学领域探索无限可能,解决复杂问题并推动科研进步。
69 0
|
4月前
|
机器学习/深度学习 PyTorch TensorFlow
【PyTorch】PyTorch深度学习框架实战(一):实现你的第一个DNN网络
【PyTorch】PyTorch深度学习框架实战(一):实现你的第一个DNN网络
188 1
|
4月前
|
机器学习/深度学习 PyTorch 测试技术
深度学习入门:使用 PyTorch 构建和训练你的第一个神经网络
【8月更文第29天】深度学习是机器学习的一个分支,它利用多层非线性处理单元(即神经网络)来解决复杂的模式识别问题。PyTorch 是一个强大的深度学习框架,它提供了灵活的 API 和动态计算图,非常适合初学者和研究者使用。
56 0
|
11天前
|
存储 SQL 安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
随着互联网的普及,网络安全问题日益突出。本文将介绍网络安全的重要性,分析常见的网络安全漏洞及其危害,探讨加密技术在保障网络安全中的作用,并强调提高安全意识的必要性。通过本文的学习,读者将了解网络安全的基本概念和应对策略,提升个人和组织的网络安全防护能力。
|
12天前
|
SQL 安全 网络安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
随着互联网的普及,网络安全问题日益突出。本文将从网络安全漏洞、加密技术和安全意识三个方面进行探讨,旨在提高读者对网络安全的认识和防范能力。通过分析常见的网络安全漏洞,介绍加密技术的基本原理和应用,以及强调安全意识的重要性,帮助读者更好地保护自己的网络信息安全。
35 10
|
13天前
|
SQL 安全 网络安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
在数字化时代,网络安全和信息安全已成为我们生活中不可或缺的一部分。本文将介绍网络安全漏洞、加密技术和安全意识等方面的内容,并提供一些实用的代码示例。通过阅读本文,您将了解到如何保护自己的网络安全,以及如何提高自己的信息安全意识。
43 10
|
13天前
|
存储 监控 安全
云计算与网络安全:云服务、网络安全、信息安全等技术领域的融合与挑战
本文将探讨云计算与网络安全之间的关系,以及它们在云服务、网络安全和信息安全等技术领域中的融合与挑战。我们将分析云计算的优势和风险,以及如何通过网络安全措施来保护数据和应用程序。我们还将讨论如何确保云服务的可用性和可靠性,以及如何处理网络攻击和数据泄露等问题。最后,我们将提供一些关于如何在云计算环境中实现网络安全的建议和最佳实践。

热门文章

最新文章