TensorFlow加载cifar10数据集

简介: TensorFlow加载cifar10数据集

加载cifar10数据集

cifar10_dir = 'C:/Users/1/.keras/datasets/cifar-10-batches-py'
(train_images, train_labels), (test_images, test_labels) = load_data(cifar10_dir)

注意:在官网下好cifar10数据集后将其解压成下面形式

load_local_cifar10.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
from six.moves import cPickle
from tensorflow.keras import backend as K
def load_batch(fpath, label_key='labels'):
    """Internal utility for parsing CIFAR data.
    # Arguments
        fpath: path the file to parse.
        label_key: key for label data in the retrieve
            dictionary.
    # Returns
        A tuple `(data, labels)`.
    """
    with open(fpath, 'rb') as f:
        if sys.version_info < (3,):
            d = cPickle.load(f)
        else:
            d = cPickle.load(f, encoding='bytes')
            # decode utf8
            d_decoded = {}
            for k, v in d.items():
                d_decoded[k.decode('utf8')] = v
            d = d_decoded
    data = d['data']
    labels = d[label_key]
    data = data.reshape(data.shape[0], 3, 32, 32)
    return data, labels
def load_data(ROOT):
    """Loads CIFAR10 dataset.
    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    # dirname = 'cifar-10-batches-py'
    # origin = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
    # path = get_file(dirname, origin=origin, untar=True)
    path = ROOT
    num_train_samples = 50000
    x_train = np.empty((num_train_samples, 3, 32, 32), dtype='uint8')
    y_train = np.empty((num_train_samples,), dtype='uint8')
    for i in range(1, 6):
        fpath = os.path.join(path, 'data_batch_' + str(i))
        (x_train[(i - 1) * 10000: i * 10000, :, :, :],
         y_train[(i - 1) * 10000: i * 10000]) = load_batch(fpath)
    fpath = os.path.join(path, 'test_batch')
    x_test, y_test = load_batch(fpath)
    y_train = np.reshape(y_train, (len(y_train), 1))
    y_test = np.reshape(y_test, (len(y_test), 1))
    if K.image_data_format() == 'channels_last':
        x_train = x_train.transpose(0, 2, 3, 1)
        x_test = x_test.transpose(0, 2, 3, 1)
    return (x_train, y_train), (x_test, y_test)


目录
相关文章
|
9月前
|
机器学习/深度学习 算法 TensorFlow
树叶识别系统python+Django网页界面+TensorFlow+算法模型+数据集+图像识别分类
树叶识别系统python+Django网页界面+TensorFlow+算法模型+数据集+图像识别分类
158 1
|
9月前
|
机器学习/深度学习 移动开发 算法
动物识别系统python+Django网页界面+TensorFlow算法模型+数据集训练
动物识别系统python+Django网页界面+TensorFlow算法模型+数据集训练
115 0
动物识别系统python+Django网页界面+TensorFlow算法模型+数据集训练
|
2月前
|
机器学习/深度学习 自然语言处理 机器人
【Tensorflow+自然语言处理+LSTM】搭建智能聊天客服机器人实战(附源码、数据集和演示 超详细)
【Tensorflow+自然语言处理+LSTM】搭建智能聊天客服机器人实战(附源码、数据集和演示 超详细)
399 7
|
2月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
【Python深度学习】Tensorflow+CNN进行人脸识别实战(附源码和数据集)
【Python深度学习】Tensorflow+CNN进行人脸识别实战(附源码和数据集)
443 2
|
2月前
|
文字识别 算法 TensorFlow
【Keras+计算机视觉+Tensorflow】OCR文字识别实战(附源码和数据集 超详细必看)
【Keras+计算机视觉+Tensorflow】OCR文字识别实战(附源码和数据集 超详细必看)
106 2
|
2月前
|
机器学习/深度学习 算法 TensorFlow
【Keras+计算机视觉+Tensorflow】实现基于YOLO和Deep Sort的目标检测与跟踪实战(附源码和数据集)
【Keras+计算机视觉+Tensorflow】实现基于YOLO和Deep Sort的目标检测与跟踪实战(附源码和数据集)
64 1
|
2月前
|
机器学习/深度学习 自然语言处理 机器人
【Tensorflow+自然语言处理+RNN】实现中文译英文的智能聊天机器人实战(附源码和数据集 超详细)
【Tensorflow+自然语言处理+RNN】实现中文译英文的智能聊天机器人实战(附源码和数据集 超详细)
54 1
|
2月前
|
机器学习/深度学习 数据采集 TensorFlow
【Tensorflow深度学习】实现手写字体识别、预测实战(附源码和数据集 超详细)
【Tensorflow深度学习】实现手写字体识别、预测实战(附源码和数据集 超详细)
153 1
|
2月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
【Keras+计算机视觉+Tensorflow】DCGAN对抗生成网络在MNIST手写数据集上实战(附源码和数据集 超详细)
【Keras+计算机视觉+Tensorflow】DCGAN对抗生成网络在MNIST手写数据集上实战(附源码和数据集 超详细)
92 0
|
2月前
|
自然语言处理 算法 Shell
【Rasa+Pycharm+Tensorflow】控制台实现智能客服问答实战(附源码和数据集 超详细)
【Rasa+Pycharm+Tensorflow】控制台实现智能客服问答实战(附源码和数据集 超详细)
102 0