Python3读取深度学习CIFAR-10数据集出现的若干问题解决

简介:   今天在看网上的视频学习深度学习的时候,用到了CIFAR-10数据集。当我兴高采烈的运行代码时,却发现了一些错误:# -*- coding: utf-8 -*-import pickl...

  今天在看网上的视频学习深度学习的时候,用到了CIFAR-10数据集。当我兴高采烈的运行代码时,却发现了一些错误:

# -*- coding: utf-8 -*-
import pickle as p
import numpy as np
import os


def load_CIFAR_batch(filename):
    """ 载入cifar数据集的一个batch """
    with open(filename, 'r') as f:
        datadict = p.load(f)
        X = datadict['data']
        Y = datadict['labels']
        X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
        Y = np.array(Y)
        return X, Y


def load_CIFAR10(ROOT):
    """ 载入cifar全部数据 """
    xs = []
    ys = []
    for b in range(1, 6):
        f = os.path.join(ROOT, 'data_batch_%d' % (b,))
        X, Y = load_CIFAR_batch(f)
        xs.append(X)
        ys.append(Y)
    Xtr = np.concatenate(xs)
    Ytr = np.concatenate(ys)
    del X, Y
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte

  错误代码如下:

'gbk' codec can't decode byte 0x80 in position 0: illegal multibyte sequence

  于是乎开始各种搜索问题,问大佬,网上的答案都是类似:

这里写图片描述

  然而并没有解决问题!还是错误的!(我大概搜索了一下午吧,都是上面的答案)

  哇,就当我很绝望的时候,我终于发现了一个新奇的答案,抱着试一试的态度,尝试了一下:


def load_CIFAR_batch(filename):
    """ 载入cifar数据集的一个batch """
    with open(filename, 'rb') as f:
        datadict = p.load(f, encoding='latin1')
        X = datadict['data']
        Y = datadict['labels']
        X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
        Y = np.array(Y)
        return X, Y

  竟然成功了,这里没有报错了!欣喜之余,我就很好奇,encoding=’latin1’到底是啥玩意呢,以前没有见过啊?于是,我搜索了一下,了解到:

Latin1是ISO-8859-1的别名,有些环境下写作Latin-1。ISO-8859-1编码是单字节编码,向下兼容ASCII,其编码范围是0x00-0xFF,0x00-0x7F之间完全和ASCII一致,0x80-0x9F之间是控制字符,0xA0-0xFF之间是文字符号。

因为ISO-8859-1编码范围使用了单字节内的所有空间,在支持ISO-8859-1的系统中传输和存储其他任何编码的字节流都不会被抛弃。换言之,把其他任何编码的字节流当作ISO-8859-1编码看待都没有问题。这是个很重要的特性,MySQL数据库默认编码是Latin1就是利用了这个特性。ASCII编码是一个7位的容器,ISO-8859-1编码是一个8位的容器。

  还没等我高兴起来,运行后,又发现了一个问题:

memory error

  什么鬼?内存错误!哇,原来是数据大小的问题。

X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")

  这告诉我们每批数据都是10000 * 3 * 32 * 32,相当于超过3000万个浮点数。 float数据类型实际上与float64相同,意味着每个数字大小占8个字节。这意味着每个批次占用至少240 MB。你加载6这些(5训练+ 1测试)在总产量接近1.4 GB的数据。

 for b in range(1, 2):
        f = os.path.join(ROOT, 'data_batch_%d' % (b,))
        X, Y = load_CIFAR_batch(f)
        xs.append(X)
        ys.append(Y)

  所以如有可能,如上代码所示只能一次运行一批。

  到此为止,错误基本搞定,下面贴出正确代码:

# -*- coding: utf-8 -*-
import pickle as p
import numpy as np
import os


def load_CIFAR_batch(filename):
    """ 载入cifar数据集的一个batch """
    with open(filename, 'rb') as f:
        datadict = p.load(f, encoding='latin1')
        X = datadict['data']
        Y = datadict['labels']
        X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
        Y = np.array(Y)
        return X, Y


def load_CIFAR10(ROOT):
    """ 载入cifar全部数据 """
    xs = []
    ys = []
    for b in range(1, 2):
        f = os.path.join(ROOT, 'data_batch_%d' % (b,))
        X, Y = load_CIFAR_batch(f)
        xs.append(X)         #将所有batch整合起来
        ys.append(Y)
    Xtr = np.concatenate(xs) #使变成行向量,最终Xtr的尺寸为(50000,32,32,3)
    Ytr = np.concatenate(ys)
    del X, Y
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte
import numpy as np
from julyedu.data_utils import load_CIFAR10
import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = (10.0, 8.0)
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# 载入CIFAR-10数据集
cifar10_dir = 'julyedu/datasets/cifar-10-batches-py'
X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)

# 看看数据集中的一些样本:每个类别展示一些
print('Training data shape: ', X_train.shape)
print('Training labels shape: ', y_train.shape)
print('Test data shape: ', X_test.shape)
print('Test labels shape: ', y_test.shape)

  顺便看一下CIFAR-10数据组成:

CIFAR-10数据组成


附件:CIFAR-10 python version下载地址

CIFAR-10官网

相关文章
|
11天前
|
机器学习/深度学习 算法框架/工具 数据库
使用Python实现深度学习模型:智能城市噪音监测与控制
使用Python实现深度学习模型:智能城市噪音监测与控制
37 1
|
7天前
|
机器学习/深度学习 数据采集 传感器
使用Python实现深度学习模型:智能土壤质量监测与管理
使用Python实现深度学习模型:智能土壤质量监测与管理
130 69
|
2天前
|
机器学习/深度学习 数据采集 存储
使用Python实现智能农业灌溉系统的深度学习模型
使用Python实现智能农业灌溉系统的深度学习模型
25 6
|
2天前
|
机器学习/深度学习 算法 编译器
Python程序到计算图一键转化,详解清华开源深度学习编译器MagPy
【10月更文挑战第26天】MagPy是一款由清华大学研发的开源深度学习编译器,可将Python程序一键转化为计算图,简化模型构建和优化过程。它支持多种深度学习框架,具备自动化、灵活性、优化性能好和易于扩展等特点,适用于模型构建、迁移、部署及教学研究。尽管MagPy具有诸多优势,但在算子支持、优化策略等方面仍面临挑战。
8 3
|
4天前
|
机器学习/深度学习 数据采集 算法框架/工具
使用Python实现深度学习模型:智能野生动物保护与监测
使用Python实现深度学习模型:智能野生动物保护与监测
19 5
|
6天前
|
机器学习/深度学习 数据采集 算法框架/工具
使用Python实现智能生态系统监测与保护的深度学习模型
使用Python实现智能生态系统监测与保护的深度学习模型
29 4
|
7天前
|
机器学习/深度学习 人工智能 算法
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
车辆车型识别,使用Python作为主要编程语言,通过收集多种车辆车型图像数据集,然后基于TensorFlow搭建卷积网络算法模型,并对数据集进行训练,最后得到一个识别精度较高的模型文件。再基于Django搭建web网页端操作界面,实现用户上传一张车辆图片识别其类型。
20 0
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
|
12天前
|
机器学习/深度学习 数据采集 传感器
使用Python实现深度学习模型:智能极端天气事件预测
使用Python实现深度学习模型:智能极端天气事件预测
44 3
|
15天前
|
机器学习/深度学习 数据可视化 TensorFlow
使用Python实现深度学习模型:智能海洋监测与保护
使用Python实现深度学习模型:智能海洋监测与保护
58 1
|
3天前
|
机器学习/深度学习 数据采集 数据可视化
使用Python实现深度学习模型:智能植物生长监测与优化
使用Python实现深度学习模型:智能植物生长监测与优化
21 0