利用Pytorch实现一个完整的基于深度学习的人脸表情识别项目

本文涉及的产品
视觉智能开放平台,图像资源包5000点
视觉智能开放平台,分割抠图1万点
视觉智能开放平台,视频资源包5000点
简介: 利用Pytorch实现一个完整的基于深度学习的人脸表情识别项目

9c6c95759f9de96d3bc0754b71cef88f.jpg该任务基于图像分类网络Alex实现。


✨1 train脚本

从设备,数据集,模型,优化器,损失函数,进度条,模型评估和参数保存等方面进行总结说明。

🌭1.1 设备

cpu或者单卡gpu

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

多卡gpu待补充…

🍕1.2 数据集

这部分包含数据增强,重载DataSet类,DataLoader打包三项操作,下面一一介绍:

🎆 1.2.1 图像增强

表情识别属于分类任务,数据预处理比较简单:

  1. ToTensor将数据转化为Tensor数据。
  2. RandomResizedCrop将图像裁剪到224的大小(这是网络要求的)。
  3. RandomHorizontalFlip增强图像的泛化性。
  4. Normalize归一化使得数据的分布更加均匀,减少模型学到数据分布的可能性。
data_transform = {
    "train": transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    "val": transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

需要注意的是,cv2打开的图像数据类型是numpy,不能进行ToTensor之外的图像增强操作。因此,该步必须放在第一个。

🍔1.2.2 MMAFEDB表情识别数据集介绍

百度网盘(5pi5)

22ec56274cc948818acd20892e938541.png

下载数据集并解压后,内容如下:

  1. labels.txt:标签种类
  2. train_list.txtval_list.txtval_list.txt:分别是训练数据,验证数据和测试数据,每条内容未图像路径 标签
  3. 各文件夹名称即分类标签,内部是该分类的图像数据。

🌭1.2.3 重载DataSet

点击此处进入之前总结过的自定义数据集的总结

处理MMAFEDB的详细代码见第二节。导入数据集代码为:

train_dataset = MMAFEDB(root_path, is_type="train", transform=data_transform["train"])
val_dataset = MMAFEDB(root_path, is_type="eval", transform=data_transform["val"])

其中root_pathtxt文件的父路径。is_type可选参数,为"train"或"eval"或"test",即导入的是什么数据。transform是图像增强操作。

🎃1.2.4 DataLoader打包

假设已经创建DataSet的重载类,即数据集导入完成。打包操作为:

# 打包
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=512,
    shuffle=True,
    # num_workers=nw,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=512,
    shuffle=True
)

更具体的参数见1.2.3链接第二节

🎄1.3 模型

分类模型有很多:Alex,GooleNet,ResNet,MobileNet…这里选用AlexNet,其它后续也会进行尝试,待补充…

🎈1.4 优化器和损失函数

损失函数的一些总结

优化器待补充…梯度下降算法推荐看刘建平老师的博客

这里是使用了Pytorch包装好的Adam优化器和交叉熵损失函数

optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)  # 总结
loss_function = torch.nn.CrossEntropyLoss()

✨1.5 模型及参数的加载和保存

🍕1.5.1 模型的加载和保存

保存模型用到torch.save(model, save_dir)函数,其中model是自定义的模型对象,save_dir是保存路径:

save_dir = ""  # 保存路径,自定义
torch.save(model, save_dir)

而加载该模型应该是:

torch.load(save_path)  # save_path是保存的模型的路径

🎆1.5.2 权重的加载和保存

保存权重分为两步:获取权重和保存参数:

  1. model.state_dict()获取参数:
paramters = model.state_dict()
  1. torch.save保存
torch.save(paramters, save_path)  # save_path是保存路径,自定义

加载模型,仍然用torch.load

paramters = torch.load(save_path)  # save_path是权重文件的保存路径

只是后面,我们需要load_state_dict将参数赋予模型

model.load_state_dict(paramters)

🍔1.6 模型评估

这里先简单采用正确率

    # 验证部分
    model.eval()
    acc = 0.0
    best_acc = 0.0
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            img, label = data
            output = model(img.to(device))
            pred = torch.max(output, dim=1)[1]
            acc += torch.eq(pred, label.to(device)).sum().item()  # TODO 1 累加batch个中预测和标签一致的数量
    acc = acc / len(val_dataset)  # TODO 2 所有数据acc累加除所有数据的数量
    print("acc: {}".format(acc))
    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), "./weights/best.pth")

代码中两行注释即正确率的计算方法。

需要注意的是,len(val_dataset)即可得到所有数据的数量,在其它任务中肯定会用到。

✨2 重载DataSet(代码)

from torch.utils.data import Dataset
import os
import cv2
from torchvision import transforms
import torch
class MMAFEDB(Dataset):
    def __init__(self, path: str, is_type: str, transform=None):
        """
        :param path: Parent path of the dataset
        :param transform:
        """
        assert os.path.exists(path), "no path:{}".format(path)
        self.path = path
        self.type = is_type
        self.transform = transform
        self.img_path = []
        self.label = []
        self.load_path()
    def __len__(self):
        return len(self.img_path)
    def __getitem__(self, index):
        img_path = self.img_path[index]
        label = self.label[index]
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            img = self.transform(img)
        label = torch.as_tensor(int(label))
        return img, label
    def load_path(self):
        try:
            if self.type == "train":
                with open(os.path.join(self.path, "train_list.txt"), "r", encoding="utf-8") as file:
                    lines = file.readlines()
            elif self.type == "eval":
                with open(os.path.join(self.path, "val_list.txt"), "r", encoding="utf-8") as file:
                    lines = file.readlines()
            elif self.type == "test":
                with open(os.path.join(self.path, "test_list.txt"), "r", encoding="utf-8") as file:
                    lines = file.readlines()
        except FileExistsError as e:
            print(e)
        for line in lines:
            img_path, cl = line.split()
            img_path = os.path.join(self.path, img_path)
            self.img_path.append(img_path)
            self.label.append(cl)
if __name__ == "__main__":
    data_transform = {
        "train": transforms.Compose([
            transforms.ToTensor(),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        "val": "",
    }
    dataset = MMAFEDB("E:/DataSet/MMAFEDB", is_type="train", transform=data_transform["train"])
    for data in dataset:
        print(2)

if __name__ == "__main__"中是用以测试的代码,首要关注MMAFEDB类:

  1. load_path函数将txt文件中的图片数据路径和label提取出来,存入self.img_pathself.label

5cb75b015cfb43a195b5d1d9afe68c74.png__getitem__函数通过index,提取一张图像和对应label。然后进行图像增强操作(img = self.transform(img)),通过label = torch.as_tensor(int(label))将label的数据类型转化为Tensor。这里主要用as_tensor而不是to_tensor,主要是因为to_tensor会将数据除255,as_tensor数据保持不变。因此to_Tensor用于image,给image进行归一化。as_Tensor用于label,保持原有标签。

✨ 4 一些注意点

  1. 模型to设备时不用多赋予一次值,即不用model = model.to(device),只需要执行model.to(device)即可。
  2. 训练前优化器权重要清零optimizer.zero_grad()
  3. 利用model和loss_functional计算式,设备必须统一,数据类型也必须是张量。
  4. model.train()model.eval()作用是开启/关闭dropout和BN操作,如果没有使用与否没有区别。
  5. with torch.no_grad()关闭自动求导,验证时必须开启。


相关文章
|
2天前
|
机器学习/深度学习 人工智能 PyTorch
【深度学习】使用PyTorch构建神经网络:深度学习实战指南
PyTorch是一个开源的Python机器学习库,特别专注于深度学习领域。它由Facebook的AI研究团队开发并维护,因其灵活的架构、动态计算图以及在科研和工业界的广泛支持而受到青睐。PyTorch提供了强大的GPU加速能力,使得在处理大规模数据集和复杂模型时效率极高。
112 59
|
2天前
|
机器学习/深度学习 人工智能 自然语言处理
【深度学习】AudioLM音频生成模型概述及应用场景,项目实践及案例分析
AudioLM(Audio Language Model)是一种基于深度学习的音频生成模型,它使用自回归或变分自回归的方法来生成连续的音频信号。这类模型通常建立在Transformer架构或者类似的序列到序列(Seq2Seq)框架上,通过学习大量音频数据中的统计规律,能够生成具有高保真度和创造性的音频片段。AudioLM模型不仅能够合成音乐、语音,还能生成自然界的声音、环境噪声等,其应用广泛,涵盖了娱乐、教育、辅助技术、内容创作等多个领域。
9 1
|
6天前
|
机器学习/深度学习 PyTorch TensorFlow
【PyTorch】PyTorch深度学习框架实战(一):实现你的第一个DNN网络
【PyTorch】PyTorch深度学习框架实战(一):实现你的第一个DNN网络
29 1
|
17天前
|
机器学习/深度学习 人工智能 PyTorch
【Deepin 20深度探索】一键解锁Linux深度学习潜能:从零开始安装Pytorch,驾驭AI未来从Deepin出发!
【8月更文挑战第2天】随着人工智能的迅猛发展,深度学习框架Pytorch已成为科研与工业界的必备工具。Deepin 20作为优秀的国产Linux发行版,凭借其流畅的用户体验和丰富的软件生态,为深度学习爱好者提供理想开发平台。本文引导您在Deepin 20上安装Pytorch,享受Linux下的深度学习之旅。
39 12
|
13天前
|
机器学习/深度学习 存储 PyTorch
【深度学习】Pytorch面试题:什么是 PyTorch?PyTorch 的基本要素是什么?Conv1d、Conv2d 和 Conv3d 有什么区别?
关于PyTorch面试题的总结,包括PyTorch的定义、基本要素、张量概念、抽象级别、张量与矩阵的区别、不同损失函数的作用以及Conv1d、Conv2d和Conv3d的区别和反向传播的解释。
36 2
|
13天前
|
机器学习/深度学习 算法 PyTorch
【深度学习】TensorFlow面试题:什么是TensorFlow?你对张量了解多少?TensorFlow有什么优势?TensorFlow比PyTorch有什么不同?该如何选择?
关于TensorFlow面试题的总结,涵盖了TensorFlow的基本概念、张量的理解、TensorFlow的优势、数据加载方式、算法通用步骤、过拟合解决方法,以及TensorFlow与PyTorch的区别和选择建议。
33 2
|
20天前
|
机器学习/深度学习 数据挖掘 TensorFlow
解锁Python数据分析新技能,TensorFlow&PyTorch双引擎驱动深度学习实战盛宴
【7月更文挑战第31天】在数据驱动时代,Python凭借其简洁性与强大的库支持,成为数据分析与机器学习的首选语言。**数据分析基础**从Pandas和NumPy开始,Pandas简化了数据处理和清洗,NumPy支持高效的数学运算。例如,加载并清洗CSV数据、计算总销售额等。
33 2
|
20天前
|
机器学习/深度学习 数据挖掘 TensorFlow
|
6天前
|
机器学习/深度学习 人工智能 PyTorch
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
20 1
|
1月前
|
机器学习/深度学习 算法 PyTorch
使用Pytorch中从头实现去噪扩散概率模型(DDPM)
在本文中,我们将构建基础的无条件扩散模型,即去噪扩散概率模型(DDPM)。从探究算法的直观工作原理开始,然后在PyTorch中从头构建它。本文主要关注算法背后的思想和具体实现细节。
8620 3

热门文章

最新文章