图像分类(迁移学习/五分钟手把手教你搭建分类模型)上

本文涉及的产品
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
交互式建模 PAI-DSW,每月250计算时 3个月
模型训练 PAI-DLC,100CU*H 3个月
简介: 图像分类(迁移学习/五分钟手把手教你搭建分类模型)

目录


数据加载及处理

模型训练

网络搭建

分类结果

模型保存


前言


迁移学习(图像分类)


在本教程中,您将学习如何使用迁移学习训练卷积神经网络以进行图像分类。您可以在 cs231n 上阅读有关迁移学习的更多信息。

本文主要目的是教会你如何自己搭建分类模型,耐心看完,相信会有很大收获。废话不多说,直切主题…

首先们要知道深度学习大都包含了下面几个方面:

1.加载(处理)数据
2.网络搭建
3.损失函数(模型优化)
4 模型训练和保存

把握好这些主要内容和流程,基本上对分类模型就大致有了个概念。


正文


数据加载及处理


我们今天要解决的问题是训练一个模型来对蚂蚁和蜜蜂进行分类。我们有大约120张蚂蚁和蜜蜂的训练图像。每个类有 75 个验证图像。通常,这是一个非常小的数据集,如果从头开始训练,则可以对其进行概括。由于我们使用的是迁移学习,我们应该能够很好地泛化。 加载数据代码部分:

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


其中:

data_dir = 'data/hymenoptera_data'


在训练你自己的数据的时候,你需要改为自己的数据路径,并且你需要把你的分类数据按照下面示例进行放置。

假设你要进行猫和狗分类,你需要如下:

c65cd413e0b63afed71fd4134f143ea5_fa08ba90c56f42a4bf897d811d9d70c3.png


并且把每一个种类都分为训练集(train)和验证集(val)如下:

f0103cc3a0978108ce30aee107243333_8491bf9b791c472587b861407911deba.png


猫的数据也同上进行放置,一般train里面的数据要远多于val文件夹。

本文的数据集可视化图像和代码如下:

fbcfd607c65b26b4c188085134f5293e_f1dd227ce8364633941d7e9ecda93774.png


可视化代码

def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated
# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))
# Make a grid from batch
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes])


模型训练


训练模型的代码示例如下所示:

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode
            running_loss = 0.0
            running_corrects = 0
            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # statistics
                running_loss += loss.item() * inputs.size(0)
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            # deep copy the model
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model


需要你重点关注的是下面这部分代码,我们知道深度学习要学习的目标是缩小给的标签与模型训练值的差距,这里你要是训练狗和猫两类,那么这两类就是dog和cat。loss是重点,该函数通过预测与真实值的做差来使得模型学习我们希望他学习的参数。


特别注意的是.to(device)是GPU加速使用。在模型预测时则不需要使用GPU,就这样模型出来一个值与真实标签值做差的过程中训练了一个良好的分类效果。

 for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)


相关文章
|
UED
生命中的开关:深入探讨Deactivated和Activated生命周期
生命中的开关:深入探讨Deactivated和Activated生命周期
364 1
|
机器学习/深度学习 自然语言处理 PyTorch
【PyTorch实战演练】基于AlexNet的预训练模型介绍
【PyTorch实战演练】基于AlexNet的预训练模型介绍
653 0
|
19天前
|
人工智能 Cloud Native 数据可视化
PyCharm 2025.1 完整教程:下载安装 + 中文设置 + 激活,一步到位,附安装包
PyCharm 2025.1 发布,重磅升级AI代码补全、类型推断与ruff集成,提升开发效率。支持渐进式补全、智能提交信息生成、冲突可视化解决,优化启动速度与内存占用,全面增强云原生及现代Python开发体验。
403 5
|
19天前
|
Android开发
占用CPU和内存过大
使用的软件是Android Studio。占用CPU和内存很大,严重影响使用。
|
4月前
|
机器学习/深度学习 人工智能 PyTorch
AI 基础知识从 0.2 到 0.3——构建你的第一个深度学习模型
本文以 MNIST 手写数字识别为切入点,介绍了深度学习的基本原理与实现流程,帮助读者建立起对神经网络建模过程的系统性理解。
597 15
AI 基础知识从 0.2 到 0.3——构建你的第一个深度学习模型
|
机器学习/深度学习 自然语言处理 异构计算
预训练与微调
预训练与微调
953 5
|
人工智能 自动驾驶 数据安全/隐私保护
人工智能的伦理困境:我们如何确保AI的道德发展?
【10月更文挑战第21天】随着人工智能(AI)技术的飞速发展,其在各行各业的应用日益广泛,从而引发了关于AI伦理和道德问题的讨论。本文将探讨AI伦理的核心问题,分析当前面临的挑战,并提出确保AI道德发展的建议措施。
|
存储 机器学习/深度学习 SQL
【Prompt Engineering:自我反思(Reflexion)】
自我反思(Reflexion)是一种通过语言反馈强化基于语言的智能体的新范式,无需微调模型即可提升其在决策、推理和编程等任务中的表现。该框架包括参与者(生成动作)、评估者(评分)和自我反思(生成反馈)三个部分,利用大语言模型生成具体反馈,帮助智能体从错误中快速学习,显著提高了多种任务的性能。
1408 2
【Prompt Engineering:自我反思(Reflexion)】
|
JavaScript 前端开发 API
深入探索挖掘vue3 生命周期
【10月更文挑战第10天】
261 0
|
缓存 NoSQL 网络协议
【Azure Redis 缓存】Redisson 连接 Azure Redis出现间歇性 java.net.UnknownHostException 异常
【Azure Redis 缓存】Redisson 连接 Azure Redis出现间歇性 java.net.UnknownHostException 异常
419 1

热门文章

最新文章

下一篇
oss云网关配置