# 使用 **迭代器** 获取 Cifar 等常用数据集

CifarMNIST 等常用数据集的坑：

• 每次在一台新的机器上使用它们去训练模型都需要重新下载（国内网络往往都不给力，需要花费大量的时间，有时还下载不了）；
• 即使下载到本地，然而不同的模型对它们的处理方式各不相同，我们又需要花费一些时间去了解如何读取数据。

## 访问数据集

# 载入所需要的包
import tables as tb
import numpy as np
xpath = 'E:/xdata/X.h5'  # 文件所在路径
h5 = tb.open_file(xpath)

h5.root
/ (RootGroup) "Xinet's dataset"
children := ['cifar10' (Group), 'cifar100' (Group), 'fashion_mnist' (Group), 'mnist' (Group)]


cifar = h5.root.cifar100   # 获取 cifar100

class Loader:
"""
方法
========
L 为该类的实例
len(L)::返回 batch 的批数
iter(L)::即为数据迭代器

Return
========
可迭代对象（numpy 对象）
"""

def __init__(self, X, Y, batch_size, shuffle):
'''
X, Y 均为类 numpy
'''
self.X = X
self.Y = Y
self.batch_size = batch_size
self.shuffle = shuffle

def __iter__(self):
n = len(self.X)
idx = np.arange(n)

if self.shuffle:
np.random.shuffle(idx)

for k in range(0, n, self.batch_size):
K = idx[k:min(k + self.batch_size, n)].tolist()
yield np.take(self.X, K, 0), np.take(self.Y, K, 0)

def __len__(self):
return round(len(self.X) / self.batch_size)

batch_size = 512
train_cifar = Loader(cifar.trainX, cifar.train_fine_labels, batch_size, True)
test_cifar = Loader(cifar.testX, cifar.test_fine_labels, batch_size, False)

for imgs, labels in iter(train_cifar):
break
names = np.asanyarray([cifar.fine_label_names[label] for label in labels], dtype='U')
names[:7]
array(['orchid', 'spider', 'rabbit', 'shark', 'shrew', 'clock', 'bed'],
dtype='<U13')


## 可视化

imgs.shape
(512, 3, 32, 32)

names.shape
(512,)

from pylab import plt, mpl

mpl.rcParams['font.sans-serif'] = ['SimHei']  # 指定默认字体
mpl.rcParams['axes.unicode_minus'] = False  # 解决保存图像是负号 '-' 显示为方块的问题

def show_imgs(imgs, labels):
'''
展示 多张图片
'''
imgs = np.transpose(imgs, (0, 2, 3, 1))
n = imgs.shape[0]
h, w = 5, int(n / 5)
fig, ax = plt.subplots(h, w, figsize=(7, 7))
K = np.arange(n).reshape((h, w))
names = np.asanyarray([cifar.fine_label_names[label] for label in labels], dtype='U')
names = names.reshape((h, w))
for i in range(h):
for j in range(w):
img = imgs[K[i, j]]
ax[i][j].imshow(img)
ax[i][j].axes.get_yaxis().set_visible(False)
ax[i][j].axes.set_xlabel(names[i][j])
ax[i][j].set_xticks([])
plt.show()
show_imgs(imgs[:25], labels[:25])

## $2$ 个深度学习框架 & 数据集

### TensorFlow

import tensorflow as tf
for imgs, labels in iter(train_cifar):
imgs = tf.constant(imgs)
labels = tf.constant(labels)
break
imgs
<tf.Tensor 'Const:0' shape=(512, 3, 32, 32) dtype=uint8>

labels
<tf.Tensor 'Const_1:0' shape=(512,) dtype=int32>


### MXNet

from mxnet import nd, cpu, gpu
for imgs, labels in iter(train_cifar):
imgs = nd.array(imgs, ctx = gpu(0))
labels = nd.array(labels, ctx = cpu(0))
break
imgs.context
gpu(0)

labels.context
cpu(0)


## Matlab 读取 HDF

|
6天前
|
XML 机器学习/深度学习 数据格式
YOLOv8训练自己的数据集+常用传参说明
YOLOv8训练自己的数据集+常用传参说明
1404 0
|
6天前
|

42 0
|
8月前
|

213 0
|
9月前
|

160 0
|
10月前
|

YOLOV7详细解读（四）训练自己的数据集
YOLOV7详细解读（四）训练自己的数据集
601 0
|
11月前
|

149 0
PASCAL VOC数据集分割为小样本数据集代码
PASCAL VOC数据集分割为小样本数据集代码
124 0

519 0
|

【目标检测之数据集预处理】继承Dataset定义自己的数据集【附代码】（下）

163 0
|

【目标检测之数据集预处理】继承Dataset定义自己的数据集【附代码】（上）

268 0