数据集学习笔记(三):COCO创建dataloader用于训练

简介: 如何使用COCO数据集创建dataloader进行训练,包括安装环境、加载数据集代码、定义数据转换、创建数据集对象以及创建dataloader。

安装环境

# 安装Cython
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple Cython

# 安装pycocotools
pip install git+https://github.com/philferriere/cocoapi.git#subdirectory=PythonAPI

如果下面这个安装不起,就通过这个下载
然后pip install 即可

COCO训练数据加载代码

import torchvision.datasets as datasets
from torchvision import transforms
import torch,cv2
"""""""""""""""""""""""""""""""""COCO dataloader"""""""""""""""""""""""""""""
train_root = 'E:/dataset/Aquarium/train/'
val_root = 'E:/dataset/Aquarium/valid/'
font = cv2.FONT_HERSHEY_SIMPLEX
train_annFile = 'E:/dataset/Aquarium/annotations/train_annotations.coco.json'
val_annFile = 'E:/dataset/Aquarium/annotations/val_annotations.coco.json'
# 定义 coco collate_fn
def collate_fn_coco(batch):
    return tuple(zip(*batch))

# 创建 coco dataset
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),  # 随机裁剪,在缩放成224*224
                                 # transforms.RandomErasing(p=0.5), # 随机遮挡 概率0.5
                                 transforms.RandomHorizontalFlip(),  # 水平方向随机翻转,概率为0.5
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
    "val": transforms.Compose([transforms.Resize(256),
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
coco_train = datasets.CocoDetection(train_root, train_annFile, transform=data_transform["train"])
coco_val = datasets.CocoDetection(val_root, val_annFile, transform=data_transform["val"])

# 创建 dataloader
train_loader = torch.utils.data.DataLoader(coco_train, batch_size=16 ,shuffle=True,num_workers=0,pin_memory=True,collate_fn=collate_fn_coco,drop_last=True)
val_loader = torch.utils.data.DataLoader(coco_val, batch_size=16 ,shuffle=False,num_workers=0,pin_memory=True,collate_fn=collate_fn_coco,drop_last=False)

# 可视化
for imgs, labels in train_loader:
    for i in range(len(imgs)):
        bboxes = []
        ids = []
        img = imgs[i]
        labels_ = labels[i]
        for label in labels_:
            bboxes.append([label['bbox'][0],
                           label['bbox'][1],
                           label['bbox'][0] + label['bbox'][2],
                           label['bbox'][1] + label['bbox'][3]
                           ])
            ids.append(label['category_id'])

        img = img.permute(1, 2, 0).numpy()
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        for box, id_ in zip(bboxes, ids):
            x1 = int(box[0])
            y1 = int(box[1])
            x2 = int(box[2])
            y2 = int(box[3])
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), thickness=2)
            cv2.putText(img, text=str(id_), org=(x1 + 5, y1 + 5), fontFace=font, fontScale=1,
                        thickness=2, lineType=cv2.LINE_AA, color=(0, 255, 0))
        cv2.imshow('test', img)
        cv2.waitKey()
目录
相关文章
|
算法 数据库 计算机视觉
Dataset之COCO数据集:COCO数据集的简介、下载、使用方法之详细攻略
Dataset之COCO数据集:COCO数据集的简介、下载、使用方法之详细攻略
|
机器学习/深度学习 PyTorch 算法框架/工具
【单点知识】基于实例详解PyTorch中的DataLoader类
【单点知识】基于实例详解PyTorch中的DataLoader类
1886 2
|
并行计算 算法 计算机视觉
【MATLAB 】 CEEMD 信号分解+模糊熵(近似熵)算法
【MATLAB 】 CEEMD 信号分解+模糊熵(近似熵)算法
422 0
|
数据处理 计算机视觉 Python
【目标检测】指定划分COCO数据集训练(车类,行人类,狗类...)
【目标检测】指定划分COCO数据集训练(车类,行人类,狗类...)
5721 0
|
Shell Linux Python
基于远程服务器安装配置Anaconda环境及创建python虚拟环境详细方案(一)
基于远程服务器安装配置Anaconda环境及创建python虚拟环境详细方案
7318 0
基于远程服务器安装配置Anaconda环境及创建python虚拟环境详细方案(一)
|
XML JSON 数据可视化
数据集学习笔记(二): 转换不同类型的数据集用于模型训练(XML、VOC、YOLO、COCO、JSON、PNG)
本文详细介绍了不同数据集格式之间的转换方法,包括YOLO、VOC、COCO、JSON、TXT和PNG等格式,以及如何可视化验证数据集。
3322 1
数据集学习笔记(二): 转换不同类型的数据集用于模型训练(XML、VOC、YOLO、COCO、JSON、PNG)
|
Python
Python 3.5 RuntimeError: can't start new thread
/*********************************************************************** * Python 3.5 RuntimeError: can't start new thread * 说明: * 测试的时候线程开得太多了,导致软件开始,不再能够被处理,卡死。
6884 0
|
机器学习/深度学习 编解码 Java
YOLO11创新改进系列:卷积,主干 注意力,C3k2融合,检测头等创新机制(已更新100+)
《YOLO11目标检测创新改进与实战案例》专栏已更新100+篇文章,涵盖注意力机制、卷积优化、检测头创新、损失与IOU优化、轻量级网络设计等多方面内容。每周更新3-10篇,提供详细代码和实战案例,帮助您掌握最新研究和实用技巧。[专栏链接](https://blog.csdn.net/shangyanaf/category_12810477.html)
YOLO11创新改进系列:卷积,主干 注意力,C3k2融合,检测头等创新机制(已更新100+)
|
机器学习/深度学习 并行计算 PyTorch
从零开始下载torch+cu(无痛版)
这篇文章提供了一个详细的无痛版教程,指导如何从零开始下载并配置支持CUDA的PyTorch GPU版本,包括查看Cuda版本、在官网检索下载包名、下载指定的torch、torchvision、torchaudio库,并在深度学习环境中安装和测试是否成功。
从零开始下载torch+cu(无痛版)
|
机器学习/深度学习 人工智能 文字识别
ultralytics YOLO11 全新发布!(原理介绍+代码详见+结构框图)
本文详细介绍YOLO11,包括其全新特性、代码实现及结构框图,并提供如何使用NEU-DET数据集进行训练的指南。YOLO11在前代基础上引入了新功能和改进,如C3k2、C2PSA模块和更轻量级的分类检测头,显著提升了模型的性能和灵活性。文中还对比了YOLO11与YOLOv8的区别,并展示了训练过程和结果的可视化
19849 0