TensorFlow训练网络两种方式

简介: TensorFlow训练网络两种方式

TensorFlow训练网络有两种方式,一种是基于tensor(array),另外一种是迭代器

两种方式区别是:

第一种是要加载全部数据形成一个tensor,然后调用model.fit()然后指定参数batch_size进行将所有数据进行分批训练

第二种是自己先将数据分批形成一个迭代器,然后遍历这个迭代器,分别训练每个批次的数据

方式一:通过迭代器

IMAGE_SIZE = 1000
# step1:加载数据集
(train_images, train_labels), (val_images, val_labels) = tf.keras.datasets.mnist.load_data()
# step2:将图像归一化
train_images, val_images = train_images / 255.0, val_images / 255.0
# step3:设置训练集大小
train_images = train_images[:IMAGE_SIZE]
val_images = val_images[:IMAGE_SIZE]
train_labels = train_labels[:IMAGE_SIZE]
val_labels = val_labels[:IMAGE_SIZE]
# step4:将图像的维度变为(IMAGE_SIZE,28,28,1)
train_images = tf.expand_dims(train_images, axis=3)
val_images = tf.expand_dims(val_images, axis=3)
# step5:将图像的尺寸变为(32,32)
train_images = tf.image.resize(train_images, [32, 32])
val_images = tf.image.resize(val_images, [32, 32])
# step6:将数据变为迭代器
train_loader = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).batch(32)
val_loader = tf.data.Dataset.from_tensor_slices((val_images, val_labels)).batch(IMAGE_SIZE)
# step5:导入模型
model = LeNet5()
# 让模型知道输入数据的形式
model.build(input_shape=(1, 32, 32, 1))
# 结局Output Shape为 multiple
model.call(Input(shape=(32, 32, 1)))
# step6:编译模型
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
# 权重保存路径
checkpoint_path = "./weight/cp.ckpt"
# 回调函数,用户保存权重
save_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                   save_best_only=True,
                                                   save_weights_only=True,
                                                   monitor='val_loss',
                                                   verbose=0)
EPOCHS = 11
for epoch in range(1, EPOCHS):
    # 每个批次训练集误差
    train_epoch_loss_avg = tf.keras.metrics.Mean()
    # 每个批次训练集精度
    train_epoch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    # 每个批次验证集误差
    val_epoch_loss_avg = tf.keras.metrics.Mean()
    # 每个批次验证集精度
    val_epoch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    for x, y in train_loader:
        history = model.fit(x,
                            y,
                            validation_data=val_loader,
                            callbacks=[save_callback],
                            verbose=0)
        # 更新误差,保留上次
        train_epoch_loss_avg.update_state(history.history['loss'][0])
        # 更新精度,保留上次
        train_epoch_accuracy.update_state(y, model(x, training=True))
        val_epoch_loss_avg.update_state(history.history['val_loss'][0])
        val_epoch_accuracy.update_state(next(iter(val_loader))[1], model(next(iter(val_loader))[0], training=True))
    # 使用.result()计算每个批次的误差和精度结果
    print("Epoch {:d}: trainLoss: {:.3f}, trainAccuracy: {:.3%} valLoss: {:.3f}, valAccuracy: {:.3%}".format(epoch,

方式二:适用model.fit()进行分批训练

import model_sequential
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
# step2:将图像归一化
train_images, test_images = train_images / 255.0, test_images / 255.0
# step3:将图像的维度变为(60000,28,28,1)
train_images = tf.expand_dims(train_images, axis=3)
test_images = tf.expand_dims(test_images, axis=3)
# step4:将图像尺寸改为(60000,32,32,1)
train_images = tf.image.resize(train_images, [32, 32])
test_images = tf.image.resize(test_images, [32, 32])
# step5:导入模型
# history = LeNet5()
history = model_sequential.LeNet()
# 让模型知道输入数据的形式
history.build(input_shape=(1, 32, 32, 1))
# history(tf.zeros([1, 32, 32, 1]))
# 结局Output Shape为 multiple
history.call(Input(shape=(32, 32, 1)))
history.summary()
# step6:编译模型
history.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
# 权重保存路径
checkpoint_path = "./weight/cp.ckpt"
# 回调函数,用户保存权重
save_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                   save_best_only=True,
                                                   save_weights_only=True,
                                                   monitor='val_loss',
                                                   verbose=1)
# step7:训练模型
history = history.fit(train_images,
                      train_labels,
                      epochs=10,
                      batch_size=32,
                      validation_data=(test_images, test_labels),
                      callbacks=[save_callback])


目录
相关文章
|
3月前
|
机器学习/深度学习 人工智能 算法
AI 基础知识从 0.6 到 0.7—— 彻底拆解深度神经网络训练的五大核心步骤
本文以一个经典的PyTorch手写数字识别代码示例为引子,深入剖析了简洁代码背后隐藏的深度神经网络(DNN)训练全过程。
743 56
|
1月前
|
机器学习/深度学习 数据可视化 网络架构
PINN训练新思路:把初始条件和边界约束嵌入网络架构,解决多目标优化难题
PINNs训练难因多目标优化易失衡。通过设计硬约束网络架构,将初始与边界条件内嵌于模型输出,可自动满足约束,仅需优化方程残差,简化训练过程,提升稳定性与精度,适用于气候、生物医学等高要求仿真场景。
213 4
PINN训练新思路:把初始条件和边界约束嵌入网络架构,解决多目标优化难题
|
7月前
|
机器学习/深度学习 存储 算法
NoProp:无需反向传播,基于去噪原理的非全局梯度传播神经网络训练,可大幅降低内存消耗
反向传播算法虽是深度学习基石,但面临内存消耗大和并行扩展受限的问题。近期,牛津大学等机构提出NoProp方法,通过扩散模型概念,将训练重塑为分层去噪任务,无需全局前向或反向传播。NoProp包含三种变体(DT、CT、FM),具备低内存占用与高效训练优势,在CIFAR-10等数据集上达到与传统方法相当的性能。其层间解耦特性支持分布式并行训练,为无梯度深度学习提供了新方向。
263 1
NoProp:无需反向传播,基于去噪原理的非全局梯度传播神经网络训练,可大幅降低内存消耗
|
8月前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【害虫识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
害虫识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了12种常见的害虫种类数据集【"蚂蚁(ants)", "蜜蜂(bees)", "甲虫(beetle)", "毛虫(catterpillar)", "蚯蚓(earthworms)", "蜚蠊(earwig)", "蚱蜢(grasshopper)", "飞蛾(moth)", "鼻涕虫(slug)", "蜗牛(snail)", "黄蜂(wasp)", "象鼻虫(weevil)"】 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Djan
467 1
基于Python深度学习的【害虫识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
|
9月前
|
机器学习/深度学习 文件存储 异构计算
YOLOv11改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
YOLOv11改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
967 18
YOLOv11改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
|
9月前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【蘑菇识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
蘑菇识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了9种常见的蘑菇种类数据集【"香菇(Agaricus)", "毒鹅膏菌(Amanita)", "牛肝菌(Boletus)", "网状菌(Cortinarius)", "毒镰孢(Entoloma)", "湿孢菌(Hygrocybe)", "乳菇(Lactarius)", "红菇(Russula)", "松茸(Suillus)"】 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Django框架搭建了一个Web网页平台可视化操作界面,
879 11
基于Python深度学习的【蘑菇识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
|
9月前
|
机器学习/深度学习 数据可视化 API
DeepSeek生成对抗网络(GAN)的训练与应用
生成对抗网络(GANs)是深度学习的重要技术,能生成逼真的图像、音频和文本数据。通过生成器和判别器的对抗训练,GANs实现高质量数据生成。DeepSeek提供强大工具和API,简化GAN的训练与应用。本文介绍如何使用DeepSeek构建、训练GAN,并通过代码示例帮助掌握相关技巧,涵盖模型定义、训练过程及图像生成等环节。
|
9月前
|
机器学习/深度学习 文件存储 异构计算
RT-DETR改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
RT-DETR改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
199 1
|
11月前
|
SQL 安全 网络安全
网络安全与信息安全:知识分享####
【10月更文挑战第21天】 随着数字化时代的快速发展,网络安全和信息安全已成为个人和企业不可忽视的关键问题。本文将探讨网络安全漏洞、加密技术以及安全意识的重要性,并提供一些实用的建议,帮助读者提高自身的网络安全防护能力。 ####
252 17
|
11月前
|
SQL 安全 网络安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
随着互联网的普及,网络安全问题日益突出。本文将从网络安全漏洞、加密技术和安全意识三个方面进行探讨,旨在提高读者对网络安全的认识和防范能力。通过分析常见的网络安全漏洞,介绍加密技术的基本原理和应用,以及强调安全意识的重要性,帮助读者更好地保护自己的网络信息安全。
217 10

热门文章

最新文章