计算机视觉PyTorch - 数据处理(库数据和训练自己的数据)

简介: 计算机视觉PyTorch - 数据处理(库数据和训练自己的数据)

1. pytorch库自带数据


为了更好的理解,这里以CIFAR10数据集作为训练和测试数据集。

我们将使用CIFAR10数据集,它包含十个类别:

[‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’]。

CIFAR-10 中的图像尺寸为3x32x32,也就是RGB的3层颜色

通道,每层通道内的尺寸为32x32。


74e21ccf600342c5af116c6dcd7c4613.png


数据预处理


😃CIFAR10数据集的输出是范围在[0,1]之间的 PILImage,即对每个类别的概率分布情况。所以我们需要通过ToTensor()把图像灰度范围从(0-255)变换到(0-1)之间,并通过transform.Normalize()把(0-1)变换到(-1,1)


import torch
import torchvision
import torchvision.transforms as transforms
#定义三个通道的像素值 均值(mean)为0.5,方差(std)为0.5
transform = transforms.Compose(
            [transforms.ToTensor(),
            transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])


数据生成


torchvision.datasets中包含了以下数据集

  • MNIST
  • COCO(用于图像标注和目标检测)(Captioning and Detection)
  • LSUN Classification
  • ImageFolder
  • Imagenet-12
  • CIFAR10 and CIFAR100
  • STL10


数据生成函数


class torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)


参数说明


  • root:保存数据集的目录
  • train:True= 训练集, False = 测试集
  • download:True = 从互联网上下载数据集,并把数据集放在root目录下. 如果数据集之前下载过,就不用再重复下载。
  • transform:对数据集预处理的函数


trainset = torchvision.datasets.CIFAR10(root='./data',train=True,download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data',train=False,download=True, transform=transform)


数据加载


数据加载函数


class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)


参数说明


dataset (Dataset):加载数据的数据集。

batch_size (int, optional):每个batch加载多少个样本(默认: 1)。

shuffle (bool, optional):设置为True时会在每个epoch重新打乱数据(默认: False).

sampler (Sampler, optional):定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。

num_workers (int, optional):用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)


trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False, num_workers=2)


2. 训练自己的数据


由于pytorch库中的数据集包含的种类比较匮乏,我们在实际的应用中往往还会对其他的事物做图像分类,因此需要自己的数据集图像来训练,实现图像分类。


生成数据集


要想用自己的数据集进行图像分类或者其他计算机视觉应用,不是之前下载好图片,进行训练就行了🤣

首先第一步需要自己的图像数据集进行标注


标注图像需要用到标注工具,这里介绍一种最方便的:labelimg


安装labelimg,只需要在终端运行


pip install labelimg


之后在终端运行如下代码,即可开始对图像进行标注


(base) MacBook-Air ~ % labelimg
• 1

7a0f234a3c6241aeb955c763ca8a71f8.png


进行批量标注

点击打开文件按钮可以打开需要被标注的图片的文件夹。

点击改变存放目录按钮可以打开标注文件存放的文件夹。

点击w快捷键可以开始标注,标注完后需要保存


71eb637c16ad45d28d6110fd3224299a.png


最后标注完成的图像,会生成一个标注文件xml格式。


数据预处理


接下来就是对标注后的图像进行预处理。

首先创建一个文件夹(这里按照官方的文件夹名字命名😂)

19ac003d590044aeaa888b692b7d59ce.png


Annotations:存放标注xml文件

JPEGImages:存放图片

ImageSets:存放一个名为Main文件夹,Main文件夹用来存放后续生成的train.txt,val.txt,test.txt、trainval.txt(也可以只有train.txt和test.txt,根据个人需求看是否需要验证集),这些文件保存的内容为图片的名字(没有后缀格式)

src:存放后续生成的train.txt,val.txt,test.txt、trainval.txt,但这里的的文件内容是,对应每个图片的绝对路径+类别

label:存放不同图像的标注文件(感觉这个文件没有用😂)


生成Main里的文件


import os
import random 
random.seed(0)
xmlfilepath='Annotations'
saveBasePath="ImageSets/Main/"
trainval_percent=1
train_percent=1
temp_xml = os.listdir(xmlfilepath)
total_xml = []
for xml in temp_xml:
    if xml.endswith(".xml"):
        total_xml.append(xml)
num=len(total_xml)  
list=range(num)  
tv=int(num*trainval_percent)  
tr=int(tv*train_percent)  
trainval= random.sample(list,tv)  
train=random.sample(trainval,tr)  
print("train and val size",tv)
print("traub suze",tr)
ftrainval = open(os.path.join(saveBasePath,'trainval.txt'), 'w')  
ftest = open(os.path.join(saveBasePath,'test.txt'), 'w')  
ftrain = open(os.path.join(saveBasePath,'train.txt'), 'w')  
fval = open(os.path.join(saveBasePath,'val.txt'), 'w')  
for i  in list:  
    name=total_xml[i][:-4]+'\n'  
    if i in trainval:  
        ftrainval.write(name)  
        if i in train:  
            ftrain.write(name)  
        else:  
            fval.write(name)  
    else:  
        ftest.write(name)  
ftrainval.close()  
ftrain.close()  
fval.close()  
ftest .close()


生成src里的文件


import xml.etree.ElementTree as ET
from os import getcwd
sets=['train','val','test','trainval']
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
def convert_annotation(image_id, list_file):
    in_file = open('Annotations/%s.xml'%(image_id), encoding='utf-8')
    tree=ET.parse(in_file)
    root = tree.getroot()
    for obj in root.iter('object'):
        difficult = 0 
        if obj.find('difficult')!=None:
            difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult)==1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)), int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
        list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
wd = getcwd()
for image_set in sets:
    image_ids = open('ImageSets/Main/%s.txt'%(image_set), encoding='utf-8').read().strip().split()
    list_file = open('src/%s.txt'%(image_set), 'w', encoding='utf-8')
    for image_id in image_ids:
        list_file.write('JPEGImages/%s.jpg'%(image_id))
        #这里写入的是图片的绝对路径
        convert_annotation(image_id, list_file)
        list_file.write('\n')
    list_file.close()


数据加载


from PIL import Image
import torch
import torchvision.transforms as transforms
class MyDataset(torch.utils.data.Dataset):  # 创类:MyDataset,继承torch.utils.data.Dataset
    def __init__(self, datatxt, transform=None):
        super(MyDataset, self).__init__()
        fh = open(datatxt, 'r')  # 打开src中的txt文件,读取内容
        imgs = []
        for line in fh:  # 按行循环txt文本中的内容
            line = line.rstrip()  # 删除本行string字符串末尾的指定字符
            words = line.split()  # 通过指定分隔符对字符串进行切片,默认为所有的空字符,包括空格、换行、制表符等
            imgs.append((words[0], int(words[1])))  # 把txt里的内容读入imgs列表保存,words[0]是图片信息,words[1]是label
        self.imgs = imgs
        self.transform = transform
    def __getitem__(self, index):  # 按照索引读取每个元素的具体内容
        fn, label = self.imgs[index]  # fn是图片path
        img = Image.open(fn).convert('RGB')  # from PIL import Image
        if self.transform is not None:  # 是否进行transform
            img = self.transform(img)
        return img, label  # return回哪些内容,在训练时循环读取每个batch,就能获得哪些内容
    def __len__(self):  # 它返回的是数据集的长度,必须有
        return len(self.imgs)
'''标准化、图片变换'''
mean = [0.5071, 0.4867, 0.4408]
stdv = [0.2675, 0.2565, 0.2761]
train_transforms = transforms.Compose([
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=stdv)])
train_data = MyDataset(datatxt='train.txt', transform=train_transforms)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)



相关文章
|
1天前
|
数据采集 PyTorch 数据处理
PyTorch的数据处理
PyTorch中,`Dataset`封装自定义数据集,`DataLoader`负责批量加载和多线程读取。例如,定义一个简单的`Dataset`类,包含数据和标签,然后使用`DataLoader`指定批大小和工作线程数。数据预处理包括导入如Excel的数据,图像数据集可通过`torchvision.datasets`加载。示例展示了如何从Excel文件创建`Dataset`,并用`DataLoader`读取。
|
10天前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】44. 图像增广的几种常用方式并使用图像增广训练模型【Pytorch】
【从零开始学习深度学习】44. 图像增广的几种常用方式并使用图像增广训练模型【Pytorch】
|
10天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】36. 门控循环神经网络之长短期记忆网络(LSTM)介绍、Pytorch实现LSTM并进行训练预测
【从零开始学习深度学习】36. 门控循环神经网络之长短期记忆网络(LSTM)介绍、Pytorch实现LSTM并进行训练预测
|
10天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】35. 门控循环神经网络之门控循环单元(gated recurrent unit,GRU)介绍、Pytorch实现GRU并进行训练预测
【从零开始学习深度学习】35. 门控循环神经网络之门控循环单元(gated recurrent unit,GRU)介绍、Pytorch实现GRU并进行训练预测
|
11天前
|
机器学习/深度学习 并行计算 PyTorch
使用PyTorch Profiler进行模型性能分析,改善并加速PyTorch训练
加速机器学习模型训练是工程师的关键需求。PyTorch Profiler提供了一种分析工具,用于测量CPU和CUDA时间,以及内存使用情况。通过在训练代码中嵌入分析器并使用tensorboard查看结果,工程师可以识别性能瓶颈。Profiler的`record_function`功能允许为特定操作命名,便于跟踪。优化策略包括使用FlashAttention或FSDP减少内存使用,以及通过torch.compile提升速度。监控CUDA内核执行和内存分配,尤其是避免频繁的cudaMalloc,能有效提升GPU效率。内存历史记录分析有助于检测内存泄漏和优化批处理大小。
29 1
|
27天前
|
机器学习/深度学习 监控 算法
开源计算机视觉库OpenCV详解
开源计算机视觉库OpenCV详解
|
30天前
|
机器学习/深度学习 PyTorch 算法框架/工具
使用FP8加速PyTorch训练的两种方法总结
在PyTorch中,FP8数据类型用于高效训练和推理,旨在减少内存占用和加快计算速度。虽然官方尚未全面支持,但在2.2版本中引入了`torch.float8_e4m3fn`和`torch.float8_e5m2`。文章通过示例展示了如何利用FP8优化Vision Transformer模型,使用Transformer Engine库提升性能,并探讨了PyTorch原生FP8支持的初步使用方法。实验表明,结合TE和FP8,训练速度可提升3倍,性能有显著增强,特别是在NVIDIA GPU上。然而,PyTorch的FP8支持仍处于试验阶段,可能带来不稳定性。
58 0
|
1月前
|
PyTorch 数据处理 算法框架/工具
pytorch 数据处理备忘
pytorch 数据处理备忘
19 1
|
1月前
|
机器学习/深度学习 负载均衡 PyTorch
PyTorch分布式训练:加速大规模数据集的处理
【4月更文挑战第18天】PyTorch分布式训练加速大规模数据集处理,通过数据并行和模型并行提升训练效率。`torch.distributed`提供底层IPC与同步,适合定制化需求;`DistributedDataParallel`则简化并行过程。实际应用注意数据划分、通信开销、负载均衡及错误处理。借助PyTorch分布式工具,可高效应对深度学习的计算挑战,未来潜力无限。
|
16小时前
|
编解码 机器人 测试技术
2024年6月计算机视觉论文推荐:扩散模型、视觉语言模型、视频生成等
6月还有一周就要结束了,我们今天来总结2024年6月上半月发表的最重要的论文,重点介绍了计算机视觉领域的最新研究和进展。
19 8