TensorFlow训练网络两种方式

简介: TensorFlow训练网络有两种方式,一种是基于tensor(array),另外一种是迭代器两种方式区别是:第一种是要加载全部数据形成一个tensor,然后调用model.fit()然后指定参数batch_size进行将所有数据进行分批训练第二种是自己先将数据分批形成一个迭代器,然后遍历这个迭代器,分别训练每个批次的数据

TensorFlow训练网络有两种方式,一种是基于tensor(array),另外一种是迭代器
两种方式区别是:
第一种是要加载全部数据形成一个tensor,然后调用model.fit()然后指定参数batch_size进行将所有数据进行分批训练
第二种是自己先将数据分批形成一个迭代器,然后遍历这个迭代器,分别训练每个批次的数据


方式一:通过迭代器


IMAGE_SIZE = 1000


# step1:加载数据集

(train_images, train_labels), (val_images, val_labels) = tf.keras.datasets.mnist.load_data()


# step2:将图像归一化

train_images, val_images = train_images / 255.0, val_images / 255.0


# step3:设置训练集大小

train_images = train_images[:IMAGE_SIZE]

val_images = val_images[:IMAGE_SIZE]

train_labels = train_labels[:IMAGE_SIZE]

val_labels = val_labels[:IMAGE_SIZE]


# step4:将图像的维度变为(IMAGE_SIZE,28,28,1)

train_images = tf.expand_dims(train_images, axis=3)

val_images = tf.expand_dims(val_images, axis=3)


# step5:将图像的尺寸变为(32,32)

train_images = tf.image.resize(train_images, [32, 32])

val_images = tf.image.resize(val_images, [32, 32])


# step6:将数据变为迭代器

train_loader = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).batch(32)

val_loader = tf.data.Dataset.from_tensor_slices((val_images, val_labels)).batch(IMAGE_SIZE)


# step5:导入模型

model = LeNet5()


# 让模型知道输入数据的形式

model.build(input_shape=(1, 32, 32, 1))


# 结局Output Shape为 multiple

model.call(Input(shape=(32, 32, 1)))


# step6:编译模型

model.compile(optimizer='adam',

             loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),

             metrics=['accuracy'])


# 权重保存路径

checkpoint_path = "./weight/cp.ckpt"


# 回调函数,用户保存权重

save_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,

                                                  save_best_only=True,

                                                  save_weights_only=True,

                                                  monitor='val_loss',

                                                  verbose=0)


EPOCHS = 11


for epoch in range(1, EPOCHS):

   # 每个批次训练集误差

   train_epoch_loss_avg = tf.keras.metrics.Mean()

   # 每个批次训练集精度

   train_epoch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

   # 每个批次验证集误差

   val_epoch_loss_avg = tf.keras.metrics.Mean()

   # 每个批次验证集精度

   val_epoch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()


   for x, y in train_loader:

       history = model.fit(x,

                           y,

                           validation_data=val_loader,

                           callbacks=[save_callback],

                           verbose=0)


       # 更新误差,保留上次

       train_epoch_loss_avg.update_state(history.history['loss'][0])

       # 更新精度,保留上次

       train_epoch_accuracy.update_state(y, model(x, training=True))


       val_epoch_loss_avg.update_state(history.history['val_loss'][0])

       val_epoch_accuracy.update_state(next(iter(val_loader))[1], model(next(iter(val_loader))[0], training=True))


   # 使用.result()计算每个批次的误差和精度结果

   print("Epoch {:d}: trainLoss: {:.3f}, trainAccuracy: {:.3%} valLoss: {:.3f}, valAccuracy: {:.3%}".format(epoch,

                                                                                                            train_epoch_loss_avg.result(),

                                                                                                            train_epoch_accuracy.result(),

                                                                                                            val_epoch_loss_avg.result(),

                                                                                                            val_epoch_accuracy.result()))


方式二:适用model.fit()进行分批训练


import model_sequential


(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()


# step2:将图像归一化

train_images, test_images = train_images / 255.0, test_images / 255.0


# step3:将图像的维度变为(60000,28,28,1)

train_images = tf.expand_dims(train_images, axis=3)

test_images = tf.expand_dims(test_images, axis=3)


# step4:将图像尺寸改为(60000,32,32,1)

train_images = tf.image.resize(train_images, [32, 32])

test_images = tf.image.resize(test_images, [32, 32])


# step5:导入模型

# history = LeNet5()

history = model_sequential.LeNet()


# 让模型知道输入数据的形式

history.build(input_shape=(1, 32, 32, 1))

# history(tf.zeros([1, 32, 32, 1]))


# 结局Output Shape为 multiple

history.call(Input(shape=(32, 32, 1)))

history.summary()


# step6:编译模型

history.compile(optimizer='adam',

               loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),

               metrics=['accuracy'])


# 权重保存路径

checkpoint_path = "./weight/cp.ckpt"


# 回调函数,用户保存权重

save_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,

                                                  save_best_only=True,

                                                  save_weights_only=True,

                                                  monitor='val_loss',

                                                  verbose=1)

# step7:训练模型

history = history.fit(train_images,

                     train_labels,

                     epochs=10,

                     batch_size=32,

                     validation_data=(test_images, test_labels),

                     callbacks=[save_callback])

目录
相关文章
|
消息中间件 XML 网络协议
『NLog』.Net使用NLog使用方式及详细配置(输出至文件/RabbitMQ/远程网络Tcp)
📣读完这篇文章里你能收获到 - Nlog输出至文件/RabbitMQ/远程网络Tcp配置文档 - Nlog配置参数详解 - .NET CORE项目接入
4150 0
『NLog』.Net使用NLog使用方式及详细配置(输出至文件/RabbitMQ/远程网络Tcp)
|
算法 Ubuntu 物联网
ESP32-C3入门教程 网络 篇(二、 Wi-Fi 配网 — Smart_config方式 和 BlueIF方式)
经过上一篇的WiFI入门篇,我们知道了WiFi初始化方式 和学会了WiFi的几种工作方式, 在实际应用中,环境复杂多变,在固件中输入SSID 的方式太不通用了, 所以肯定是需要学习一下如何在不同的环境中联网,就是所谓的配网。 ESP32-C3的配网方式有多种,本文主要说明测试 Smart方式 和 BlueIF方式。
1191 0
ESP32-C3入门教程 网络 篇(二、 Wi-Fi 配网 — Smart_config方式 和 BlueIF方式)
|
Linux 编译器 开发工具
Linux网络环境配置:(内含:随机ip和固定ip设置方式)
Linux网络环境配置:(内含:随机ip和固定ip设置方式)
295 0
Linux网络环境配置:(内含:随机ip和固定ip设置方式)
|
安全 网络安全 网络虚拟化
混合云网络构建方式|学习笔记
快速学习混合云网络构建方式
混合云网络构建方式|学习笔记
|
存储 Android开发 文件存储
WebView加载页面的两种方式——网络页面和本地页面
WebView加载页面的两种方式 一、加载网络页面   加载网络页面,是最简单的一种方式,只需要传入http的URL就可以,实现WebView加载网络页面 代码如下图: 二、加载本地页面   1、加载assets目录下的HTML页面: 加载assets目录的页面,大多数可以用来做页面数据的存储打包...
2412 0
|
视频直播
为什么说移动端网络视频直播系统逐渐成为了一种主流方式
移动端的直播系统已逐渐成为了网络视频直播系统的主要形式之一,这归功于手机的便携式,也突破了时间和地域的束缚,对于用户来说,学习和操作的成本也很低。
为什么说移动端网络视频直播系统逐渐成为了一种主流方式