使用 **迭代器** 获取 Cifar 等常用数据集

简介: 一个很方便使用的数据库。

CifarMNIST 等常用数据集的坑:

  • 每次在一台新的机器上使用它们去训练模型都需要重新下载(国内网络往往都不给力,需要花费大量的时间,有时还下载不了);
  • 即使下载到本地,然而不同的模型对它们的处理方式各不相同,我们又需要花费一些时间去了解如何读取数据。

为了解决上述的坑,我在Bunch 转换为 HDF5 文件:高效存储 Cifar 等数据集中将一些常用的数据集封装为 HDF5 文件。

下面的 X.h5c 可以参考Bunch 转换为 HDF5 文件:高效存储 Cifar 等数据集自己制作,也可以直接下载使用(链接:https://pan.baidu.com/s/1hsbMhv3MDlOES3UDDmOQiw 密码:qlb7)。

使用方法很简单:

访问数据集

# 载入所需要的包
import tables as tb
import numpy as np
xpath = 'E:/xdata/X.h5'  # 文件所在路径
h5 = tb.open_file(xpath)

下面我们来看看此文件中有那些数据集:

h5.root
/ (RootGroup) "Xinet's dataset"
  children := ['cifar10' (Group), 'cifar100' (Group), 'fashion_mnist' (Group), 'mnist' (Group)]

下面我们以 Cifar 为例,来详细说明该文件的使用:

cifar = h5.root.cifar100   # 获取 cifar100

为了高效使用数据集,我们使用迭代器的方式来获取它:

class Loader:
    """
    方法
    ========
    L 为该类的实例
    len(L)::返回 batch 的批数
    iter(L)::即为数据迭代器

    Return
    ========
    可迭代对象(numpy 对象)
    """

    def __init__(self, X, Y, batch_size, shuffle):
        '''
        X, Y 均为类 numpy 
        '''
        self.X = X
        self.Y = Y
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        n = len(self.X)
        idx = np.arange(n)

        if self.shuffle:
            np.random.shuffle(idx)

        for k in range(0, n, self.batch_size):
            K = idx[k:min(k + self.batch_size, n)].tolist()
            yield np.take(self.X, K, 0), np.take(self.Y, K, 0)

    def __len__(self):
        return round(len(self.X) / self.batch_size)

下面我们可以使用 Loader 来实例化我们的数据集:

batch_size = 512
train_cifar = Loader(cifar.trainX, cifar.train_fine_labels, batch_size, True)
test_cifar = Loader(cifar.testX, cifar.test_fine_labels, batch_size, False)

读取一个 Batch 的数据:

for imgs, labels in iter(train_cifar):
    break
names = np.asanyarray([cifar.fine_label_names[label] for label in labels], dtype='U')
names[:7]
array(['orchid', 'spider', 'rabbit', 'shark', 'shrew', 'clock', 'bed'],
      dtype='<U13')

可视化

需要注意,这里的 Cifarfirst channel 的,即:

imgs.shape
(512, 3, 32, 32)
names.shape
(512,)
from pylab import plt, mpl


mpl.rcParams['font.sans-serif'] = ['SimHei']  # 指定默认字体
mpl.rcParams['axes.unicode_minus'] = False  # 解决保存图像是负号 '-' 显示为方块的问题


def show_imgs(imgs, labels):
    '''
    展示 多张图片
    '''
    imgs = np.transpose(imgs, (0, 2, 3, 1))
    n = imgs.shape[0]
    h, w = 5, int(n / 5)
    fig, ax = plt.subplots(h, w, figsize=(7, 7))
    K = np.arange(n).reshape((h, w))
    names = np.asanyarray([cifar.fine_label_names[label] for label in labels], dtype='U')
    names = names.reshape((h, w))
    for i in range(h):
        for j in range(w):
            img = imgs[K[i, j]]
            ax[i][j].imshow(img)
            ax[i][j].axes.get_yaxis().set_visible(False)
            ax[i][j].axes.set_xlabel(names[i][j])
            ax[i][j].set_xticks([])
    plt.show()
show_imgs(imgs[:25], labels[:25])

output_19_0.png-89.9kB

$2$ 个深度学习框架 & 数据集

因为,上面的数据集是 NumPyarray 形式,故而:

TensorFlow

import tensorflow as tf
for imgs, labels in iter(train_cifar):
    imgs = tf.constant(imgs)
    labels = tf.constant(labels)
    break
imgs
<tf.Tensor 'Const:0' shape=(512, 3, 32, 32) dtype=uint8>
labels
<tf.Tensor 'Const_1:0' shape=(512,) dtype=int32>

MXNet

from mxnet import nd, cpu, gpu
for imgs, labels in iter(train_cifar):
    imgs = nd.array(imgs, ctx = gpu(0))
    labels = nd.array(labels, ctx = cpu(0))
    break
imgs.context
gpu(0)
labels.context
cpu(0)

Matlab 读取 HDF

参考:h5read
捕获.PNG-65.5kB

目录
相关文章
|
5月前
|
数据可视化 Python Windows
Matplotlib输出中文显示的2种解决方案
Matplotlib输出中文显示的2种解决方案
335 0
|
5月前
|
存储 人工智能 数据可视化
AI计算机视觉笔记二十一:PaddleOCR训练自定义数据集
在完成PaddleOCR环境搭建与测试后,本文档详细介绍如何训练自定义的车牌检测模型。首先,在`PaddleOCR`目录下创建`train_data`文件夹存放数据集,并下载并解压缩车牌数据集。接着,复制并修改配置文件`ch_det_mv3_db_v2.0.yml`以适应训练需求,包括设置模型存储目录、训练可视化选项及数据集路径。随后,下载预训练权重文件并放置于`pretrain_models`目录下,以便进行预测与训练。最后,通过指定命令行参数执行训练、断点续训、测试及导出推理模型等操作。
|
9月前
|
Python
Fastapi进阶用法,路径参数,路由分发,查询参数等详解
Fastapi进阶用法,路径参数,路由分发,查询参数等详解
569 1
|
7月前
|
机器学习/深度学习 Python
【Python】已解决:ModuleNotFoundError: No module named ‘paddle’
【Python】已解决:ModuleNotFoundError: No module named ‘paddle’
987 1
|
6月前
|
API Docker Windows
2024 Ollama 一站式解决在Windows系统安装、使用、定制服务与实战案例
这篇文章是一份关于Ollama工具的一站式使用指南,涵盖了在Windows系统上安装、使用和定制服务,以及实战案例。
2024 Ollama 一站式解决在Windows系统安装、使用、定制服务与实战案例
|
3天前
|
人工智能 自然语言处理 Shell
深度评测 | 仅用3分钟,百炼调用满血版 Deepseek-r1 API,百万Token免费用,简直不要太爽。
仅用3分钟,百炼调用满血版Deepseek-r1 API,享受百万免费Token。阿里云提供零门槛、快速部署的解决方案,支持云控制台和Cloud Shell两种方式,操作简便。Deepseek-r1满血版在推理能力上表现出色,尤其擅长数学、代码和自然语言处理任务,使用过程中无卡顿,体验丝滑。结合Chatbox工具,用户可轻松掌控模型,提升工作效率。阿里云大模型服务平台百炼不仅速度快,还确保数据安全,值得信赖。
157353 24
深度评测 | 仅用3分钟,百炼调用满血版 Deepseek-r1 API,百万Token免费用,简直不要太爽。
|
5天前
|
人工智能 API 网络安全
用DeepSeek,就在阿里云!四种方式助您快速使用 DeepSeek-R1 满血版!更有内部实战指导!
DeepSeek自发布以来,凭借卓越的技术性能和开源策略迅速吸引了全球关注。DeepSeek-R1作为系列中的佼佼者,在多个基准测试中超越现有顶尖模型,展现了强大的推理能力。然而,由于其爆火及受到黑客攻击,官网使用受限,影响用户体验。为解决这一问题,阿里云提供了多种解决方案。
16978 37
|
13天前
|
机器学习/深度学习 人工智能 自然语言处理
PAI Model Gallery 支持云上一键部署 DeepSeek-V3、DeepSeek-R1 系列模型
DeepSeek 系列模型以其卓越性能在全球范围内备受瞩目,多次评测中表现优异,性能接近甚至超越国际顶尖闭源模型(如OpenAI的GPT-4、Claude-3.5-Sonnet等)。企业用户和开发者可使用 PAI 平台一键部署 DeepSeek 系列模型,实现 DeepSeek 系列模型与现有业务的高效融合。
|
5天前
|
并行计算 PyTorch 算法框架/工具
本地部署DeepSeek模型
要在本地部署DeepSeek模型,需准备Linux(推荐Ubuntu 20.04+)或兼容的Windows/macOS环境,配备NVIDIA GPU(建议RTX 3060+)。安装Python 3.8+、PyTorch/TensorFlow等依赖,并通过官方渠道下载模型文件。配置模型后,编写推理脚本进行测试,可选使用FastAPI服务化部署或Docker容器化。注意资源监控和许可协议。
1310 8
|
13天前
|
人工智能 搜索推荐 Docker
手把手教你使用 Ollama 和 LobeChat 快速本地部署 DeepSeek R1 模型,创建个性化 AI 助手
DeepSeek R1 + LobeChat + Ollama:快速本地部署模型,创建个性化 AI 助手
3416 117
手把手教你使用 Ollama 和 LobeChat 快速本地部署 DeepSeek R1 模型,创建个性化 AI 助手

热门文章

最新文章