计算机视觉PyTorch迁移学习 - (二)

本文涉及的产品
函数计算FC,每月15万CU 3个月
简介: 计算机视觉PyTorch迁移学习 - (二)

3.PyTorch实现迁移学习


文件目录


59ba9e08e90b4868a3618f8403acc9ea.png


3.1数据集预处理


这里实现一个蚂蚁与蜜蜂的图像分类,用到的数据集data下载


dataset.py


from torchvision import datasets, transforms
import torch
train=transforms.Compose([
    transforms.RandomResizedCrop(224),  # 随机裁剪一个area然后再resize
    transforms.RandomHorizontalFlip(),  # 随机水平翻转
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
val=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
trainset=datasets.ImageFolder(root='hymenoptera_data/train',transform=train)
valset=datasets.ImageFolder(root='hymenoptera_data/val',transform=val)
trainloader=torch.utils.data.DataLoader(trainset,batch_size=4,
                                        shuffle=True, num_workers=4)
valloader=torch.utils.data.DataLoader(valset,batch_size=4,
                                      shuffle=True, num_workers=4)

3.2构建模型


model.py


from torchvision import models
import torch.nn as nn
#初始化模型
#保证模型不改变的层的参数,不发生梯度变化
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
def initialize_model(model_name, num_classes, feature_extract):
    model_ft=None
    input_size=0
    if model_name =='resnet':
        #resnet18
        model_ft = models.resnet18(pretrained=True)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "alexnet":
        model_ft = models.alexnet(pretrained=True)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "vgg":
        #vgg11
        model_ft = models.vgg11_bn(pretrained=True)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "squeezenet":
        model_ft = models.squeezenet1_0(pretrained=True)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224
    elif model_name == "densenet":
        model_ft = models.densenet121(pretrained=True)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "inception":
        model_ft = models.inception_v3(pretrained=True)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxilary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299
    else:
        print("没有合适的模型...")
    return model_ft, input_size


3.3模型训练与验证


run.py


from __future__ import print_function
from __future__ import division
import torch.nn as nn
import torch.optim as optim
from model import initialize_model
from torch.optim import lr_scheduler
import time
import copy
from dataset import *
import argparse
parser=argparse.ArgumentParser()
#模型选择
parser.add_argument('-m','--model_name',type=str,choices=['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception'],help="input model_name",default='resnet')
#分类类别数
parser.add_argument('-n','--num_classes',type=int,help="input num_classes",default=2)
#定义一个批次的样本数
parser.add_argument('-b','--batch_size',type=int,help="input batch_size",default=8)
#定义迭代批次
parser.add_argument('-e','--num_epochs',type=int,help="input num_epochs",default=25)
args=parser.parse_args()
#用于特征提取的标志。如果为False,则对整个模型进行微调,
#如果为True,则仅更新重塑的图层参数
feature_extract = True
#定义数据字典
datasets={train:trainset,val:valset}
#定义数据集字典
dataloaders={train:trainloader,val:valloader}
model_ft, input_size = initialize_model(args.model_name, args.num_classes, feature_extract)
criterion = nn.CrossEntropyLoss()
# 观察所有参数都正在优化
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# 每7个epochs衰减LR通过设置gamma=0.1
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
def train_model(model,criterion,optimizer,scheduler,num_epochs):
    since=time.time()
    val_acc_history = []
    #获取模型初始参数
    best_model_wts=copy.deepcopy(model.state_dict())
    best_acc=0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch,num_epochs-1))
        print('-'*10)
        for data in ['train','val']:
            if data=='train':
                scheduler.step()
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0
            for inputs,labels in dataloaders[data]:
                optimizer.zero_grad()
                with torch.set_grad_enabled(data=='train'):
                    outputs=model(inputs)
                    _,preds=torch.max(outputs,1)
                    loss=criterion(outputs,labels)
                    if data=='train':
                        loss.backward()
                        optimizer.step()
                running_loss+=loss.item()*inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                epoch_loss = running_loss / len(datasets[data])
                epoch_acc = running_corrects.double() / len(datasets[data])
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                    data, epoch_loss, epoch_acc))
                # 深度复制mo
                if data=='val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
            print()
            time_elapsed = time.time() - since
            print('Training complete in {:.0f}m {:.0f}s'.format(
                time_elapsed // 60, time_elapsed % 60))
            print('Best val Acc: {:4f}'.format(best_acc))
            model.load_state_dict(best_model_wts)
            return model
train_model(model_ft,criterion, optimizer_ft, exp_lr_scheduler,args.num_epochs)

e152d2bf069445c4a3af1d6b835b2419.png



相关实践学习
【文生图】一键部署Stable Diffusion基于函数计算
本实验教你如何在函数计算FC上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。函数计算提供一定的免费额度供用户使用。本实验答疑钉钉群:29290019867
建立 Serverless 思维
本课程包括: Serverless 应用引擎的概念, 为开发者带来的实际价值, 以及让您了解常见的 Serverless 架构模式
相关文章
|
2月前
|
PyTorch Linux 算法框架/工具
pytorch学习一:Anaconda下载、安装、配置环境变量。anaconda创建多版本python环境。安装 pytorch。
这篇文章是关于如何使用Anaconda进行Python环境管理,包括下载、安装、配置环境变量、创建多版本Python环境、安装PyTorch以及使用Jupyter Notebook的详细指南。
329 1
pytorch学习一:Anaconda下载、安装、配置环境变量。anaconda创建多版本python环境。安装 pytorch。
|
6月前
|
机器学习/深度学习 自然语言处理 算法
【从零开始学习深度学习】49.Pytorch_NLP项目实战:文本情感分类---使用循环神经网络RNN
【从零开始学习深度学习】49.Pytorch_NLP项目实战:文本情感分类---使用循环神经网络RNN
|
2月前
|
机器学习/深度学习 缓存 PyTorch
pytorch学习一(扩展篇):miniconda下载、安装、配置环境变量。miniconda创建多版本python环境。整理常用命令(亲测ok)
这篇文章是关于如何下载、安装和配置Miniconda,以及如何使用Miniconda创建和管理Python环境的详细指南。
515 0
pytorch学习一(扩展篇):miniconda下载、安装、配置环境变量。miniconda创建多版本python环境。整理常用命令(亲测ok)
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
聊一聊计算机视觉中常用的注意力机制以及Pytorch代码实现
本文介绍了几种常用的计算机视觉注意力机制及其PyTorch实现,包括SENet、CBAM、BAM、ECA-Net、SA-Net、Polarized Self-Attention、Spatial Group-wise Enhance和Coordinate Attention等,每种方法都附有详细的网络结构说明和实验结果分析。通过这些注意力机制的应用,可以有效提升模型在目标检测任务上的性能。此外,作者还提供了实验数据集的基本情况及baseline模型的选择与实验结果,方便读者理解和复现。
67 0
聊一聊计算机视觉中常用的注意力机制以及Pytorch代码实现
|
6月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】43. 算法优化之Adam算法【RMSProp算法与动量法的结合】介绍及其Pytorch实现
【从零开始学习深度学习】43. 算法优化之Adam算法【RMSProp算法与动量法的结合】介绍及其Pytorch实现
|
2月前
|
机器学习/深度学习 人工智能 TensorFlow
浅谈计算机视觉新手的学习路径
浅谈计算机视觉新手的学习路径
24 0
|
4月前
|
存储 PyTorch API
Pytorch入门—Tensors张量的学习
Pytorch入门—Tensors张量的学习
34 0
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】47. Pytorch图片样式迁移实战:将一张图片样式迁移至另一张图片,创作自己喜欢风格的图片【含完整源码】
【从零开始学习深度学习】47. Pytorch图片样式迁移实战:将一张图片样式迁移至另一张图片,创作自己喜欢风格的图片【含完整源码】
|
6月前
|
机器学习/深度学习 资源调度 PyTorch
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
|
6月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】50.Pytorch_NLP项目实战:卷积神经网络textCNN在文本情感分类的运用
【从零开始学习深度学习】50.Pytorch_NLP项目实战:卷积神经网络textCNN在文本情感分类的运用
下一篇
DataWorks