TensorFlow使用DataSets加载数据

简介: 在TensorFlow中fit()函数可以接收numpy类型数据,前提数据量不大可以全部加载到内存中,但是如果数据量过大我们就需要将其按批次读取,转化成迭代器的形式,也就是DataSets

在TensorFlow中fit()函数可以接收numpy类型数据,前提数据量不大可以全部加载到内存中,但是如果数据量过大我们就需要将其按批次读取,转化成迭代器的形式,也就是DataSets


可以将 Dataset 实例直接传递给方法 fit()evaluate()predict()


如果使用DataSet就不需要像numpy数据那种在fit中指定batch_size了


完整代码:


"""

* Created with PyCharm

* 作者: 阿光

* 日期: 2022/1/2

* 时间: 19:29

* 描述:

"""

import tensorflow as tf

import tensorflow.keras.datasets.mnist

from keras import Input, Model

from keras.layers import Dense

from tensorflow import keras


(train_images, train_labels), (val_images, val_labels) = tensorflow.keras.datasets.mnist.load_data()


train_images, val_images = train_images / 255.0, val_images / 255.0


train_images = train_images.reshape(60000, 784)

val_images = val_images.reshape(10000, 784)


train_datasets = tf.data.Dataset.from_tensor_slices((train_images, train_labels))

train_datasets = train_datasets.shuffle(buffer_size=1024).batch(64)


val_datasets = tf.data.Dataset.from_tensor_slices((val_images, val_labels))

val_datasets = val_datasets.batch(64)



def get_model():

   inputs = Input(shape=(784,))

   outputs = Dense(10, activation='softmax')(inputs)

   model = Model(inputs, outputs)

   model.compile(

       optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),

       loss=keras.losses.SparseCategoricalCrossentropy(),

       metrics=['accuracy']

   )

   return model



model = get_model()


model.fit(

   train_datasets,

   epochs=5,

   validation_data=val_datasets

)

目录
相关文章
|
4月前
|
TensorFlow 算法框架/工具
【Tensorflow+Keras】用Tensorflow.keras的方法替代keras.layers.merge
在TensorFlow 2.0和Keras中替代旧版keras.layers.merge函数的方法,使用了新的层如add, multiply, concatenate, average, 和 dot来实现常见的层合并操作。
34 1
|
7月前
|
数据采集 机器学习/深度学习 TensorFlow
TensorFlow中的数据加载与处理
【4月更文挑战第17天】本文介绍了在TensorFlow中进行数据加载与处理的方法。使用`tf.keras.datasets`模块可便捷加载MNIST等常见数据集,自定义数据集可通过`tf.data.Dataset`构建。利用`tf.data`模块构建输入管道,包括数据打乱、分批及重复,以优化训练效率。数据预处理涉及数据清洗、标准化/归一化以及使用`ImageDataGenerator`进行数据增强,这些步骤对模型性能和泛化至关重要。
|
7月前
|
存储 数据可视化 PyTorch
PyTorch中 Datasets & DataLoader 的介绍
PyTorch中 Datasets & DataLoader 的介绍
143 0
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch使用专题 | 2 :Pytorch中数据读取-Dataset、Dataloader 、TensorDataset 和 Sampler 的使用
介绍Pytorch中数据读取-Dataset、Dataloader 、TensorDataset 和 Sampler 的使用
|
机器学习/深度学习 数据采集 PyTorch
pytorch笔记:Dataset 和 DataLoader
pytorch笔记:Dataset 和 DataLoader
307 0
|
TensorFlow 算法框架/工具
TensorFlow加载cifar10数据集
TensorFlow加载cifar10数据集
133 0
TensorFlow加载cifar10数据集
|
数据采集 机器学习/深度学习 PyTorch
Pytorch中基于MNIST数据的torchvision工具包应用
Pytorch中基于MNIST数据的torchvision工具包应用
142 0
Pytorch中基于MNIST数据的torchvision工具包应用
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch教程[02]DataLoader与Dataset
Pytorch教程[02]DataLoader与Dataset
Pytorch教程[02]DataLoader与Dataset
|
机器学习/深度学习 PyTorch 算法框架/工具