TensorFlow进行不同模型和数据集之间的迁移学习和模型微调

本文涉及的产品
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
模型训练 PAI-DLC,5000CU*H 3个月
视觉智能开放平台,分割抠图1万点
简介: TensorFlow进行不同模型和数据集之间的迁移学习和模型微调

迁移学习包括获取从一个问题中学习到的特征,然后将这些特征用于新的类似问题。例如,来自已学会识别浣熊的模型的特征可能对建立旨在识别狸猫的模型十分有用。

对于数据集中的数据太少而无法从头开始训练完整模型的任务,通常会执行迁移学习。

在深度学习情境中,迁移学习最常见的形式是以下工作流

  • 从之前训练的模型中获取层。
  • 冻结这些层,以避免在后续训练轮次中破坏它们包含的任何信息。
  • 在已冻结层的顶部添加一些新的可训练层。这些层会学习将旧特征转换为对新数据集的预测。
  • 在您的数据集上训练新层。
  • 最后一个可选步骤是微调,包括解冻上面获得的整个模型(或模型的一部分),然后在新数据上以极低的学习率对该模型进行重新训练。以增量方式使预训练特征适应新数据,有可能实现有意义的改进。

典型的迁移学习工作流

下面将介绍如何在 Keras 中实现典型的迁移学习工作流:

  • 实例化一个基础模型并加载预训练权重
  • 通过设置 trainable = False 冻结基础模型中的所有层。
  • 根据基础模型中一个(或多个)层的输出创建一个新模型。
  • 在您的新数据集上训练新模型。
    请注意,另一种更轻量的工作流如下:
  • 实例化一个基础模型并加载预训练权重。
  • 通过该模型运行新的数据集,并记录基础模型中一个(或多个)层的输出。这一过程称为特征提取。
  • 使用该输出作为新的较小模型的输入数据。

第二种工作流有一个关键优势,即您只需在自己的数据上运行一次基础模型,而不是每个训练周期都运行一次。因此,它的速度更快,开销也更低。

但是,第二种工作流存在一个问题,即它不允许您在训练期间动态修改新模型的输入数据,在进行数据扩充时,这种修改必不可少。当新数据集的数据太少而无法从头开始训练完整模型时,任务通常会使用迁移学习,在这种情况下,数据扩充非常重要。因此,在接下来的篇幅中,我们将专注于第一种工作流。

完整代码

"""
 * Created with PyCharm
 * 作者: 阿光
 * 日期: 2022/1/4
 * 时间: 13:25
 * 描述:
"""
import numpy as np
import tensorflow as tf
from keras import Model
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,
)
print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
size = (150, 150)
train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))
batch_size = 32
train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False
)
base_model.trainable = False
data_augmentation = keras.Sequential([
    layers.experimental.preprocessing.RandomFlip('horizontal'),
    layers.experimental.preprocessing.RandomRotation(0.1)
])
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)
norm_layer = keras.layers.experimental.preprocessing.Normalization()
mean = np.array([127.5] * 3)
var = mean ** 2
x = norm_layer(x)
# norm_layer.set_weights([mean, var])
x = base_model(x, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.2)(x)
outputs = layers.Dense(1)(x)
model = Model(inputs, outputs)
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)
epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
base_model.trainable = True
model.summary()
model.compile(
    optimizer=keras.optimizers.Adam(1e-5),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)
epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)


目录
相关文章
|
30天前
|
机器学习/深度学习 TensorFlow 算法框架/工具
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
将Keras训练好的.hdf5模型转换为TensorFlow的.pb模型,然后再转换为TensorRT支持的.uff格式,并提供了转换代码和测试步骤。
68 3
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
|
28天前
|
数据采集 TensorFlow 算法框架/工具
【大作业-03】手把手教你用tensorflow2.3训练自己的分类数据集
本教程详细介绍了如何使用TensorFlow 2.3训练自定义图像分类数据集,涵盖数据集收集、整理、划分及模型训练与测试全过程。提供完整代码示例及图形界面应用开发指导,适合初学者快速上手。[教程链接](https://www.bilibili.com/video/BV1rX4y1A7N8/),配套视频更易理解。
34 0
【大作业-03】手把手教你用tensorflow2.3训练自己的分类数据集
|
14天前
|
机器学习/深度学习 人工智能 算法
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
车辆车型识别,使用Python作为主要编程语言,通过收集多种车辆车型图像数据集,然后基于TensorFlow搭建卷积网络算法模型,并对数据集进行训练,最后得到一个识别精度较高的模型文件。再基于Django搭建web网页端操作界面,实现用户上传一张车辆图片识别其类型。
55 0
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
|
2月前
|
机器学习/深度学习 人工智能 算法
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
鸟类识别系统。本系统采用Python作为主要开发语言,通过使用加利福利亚大学开源的200种鸟类图像作为数据集。使用TensorFlow搭建ResNet50卷积神经网络算法模型,然后进行模型的迭代训练,得到一个识别精度较高的模型,然后在保存为本地的H5格式文件。在使用Django开发Web网页端操作界面,实现用户上传一张鸟类图像,识别其名称。
104 12
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
|
30天前
|
机器学习/深度学习 移动开发 TensorFlow
深度学习之格式转换笔记(四):Keras(.h5)模型转化为TensorFlow(.pb)模型
本文介绍了如何使用Python脚本将Keras模型转换为TensorFlow的.pb格式模型,包括加载模型、重命名输出节点和量化等步骤,以便在TensorFlow中进行部署和推理。
69 0
|
3月前
|
API UED 开发者
如何在Uno Platform中轻松实现流畅动画效果——从基础到优化,全方位打造用户友好的动态交互体验!
【8月更文挑战第31天】在开发跨平台应用时,确保用户界面流畅且具吸引力至关重要。Uno Platform 作为多端统一的开发框架,不仅支持跨系统应用开发,还能通过优化实现流畅动画,增强用户体验。本文探讨了Uno Platform中实现流畅动画的多个方面,包括动画基础、性能优化、实践技巧及问题排查,帮助开发者掌握具体优化策略,提升应用质量与用户满意度。通过合理利用故事板、减少布局复杂性、使用硬件加速等技术,结合异步方法与预设缓存技巧,开发者能够创建美观且流畅的动画效果。
77 0
|
3月前
|
C# 开发者 前端开发
揭秘混合开发新趋势:Uno Platform携手Blazor,教你一步到位实现跨平台应用,代码复用不再是梦!
【8月更文挑战第31天】随着前端技术的发展,混合开发日益受到开发者青睐。本文详述了如何结合.NET生态下的两大框架——Uno Platform与Blazor,进行高效混合开发。Uno Platform基于WebAssembly和WebGL技术,支持跨平台应用构建;Blazor则让C#成为可能的前端开发语言,实现了客户端与服务器端逻辑共享。二者结合不仅提升了代码复用率与跨平台能力,还简化了项目维护并增强了Web应用性能。文中提供了从环境搭建到示例代码的具体步骤,并展示了如何创建一个简单的计数器应用,帮助读者快速上手混合开发。
76 0
|
2月前
|
机器学习/深度学习 数据挖掘 TensorFlow
解锁Python数据分析新技能,TensorFlow&PyTorch双引擎驱动深度学习实战盛宴
在数据驱动时代,Python凭借简洁的语法和强大的库支持,成为数据分析与机器学习的首选语言。Pandas和NumPy是Python数据分析的基础,前者提供高效的数据处理工具,后者则支持科学计算。TensorFlow与PyTorch作为深度学习领域的两大框架,助力数据科学家构建复杂神经网络,挖掘数据深层价值。通过Python打下的坚实基础,结合TensorFlow和PyTorch的强大功能,我们能在数据科学领域探索无限可能,解决复杂问题并推动科研进步。
58 0
|
2月前
|
机器学习/深度学习 数据挖掘 TensorFlow
从数据小白到AI专家:Python数据分析与TensorFlow/PyTorch深度学习的蜕变之路
【9月更文挑战第10天】从数据新手成长为AI专家,需先掌握Python基础语法,并学会使用NumPy和Pandas进行数据分析。接着,通过Matplotlib和Seaborn实现数据可视化,最后利用TensorFlow或PyTorch探索深度学习。这一过程涉及从数据清洗、可视化到构建神经网络的多个步骤,每一步都需不断实践与学习。借助Python的强大功能及各类库的支持,你能逐步解锁数据的深层价值。
62 0
|
3月前
|
持续交付 测试技术 jenkins
JSF 邂逅持续集成,紧跟技术热点潮流,开启高效开发之旅,引发开发者强烈情感共鸣
【8月更文挑战第31天】在快速发展的软件开发领域,JavaServer Faces(JSF)这一强大的Java Web应用框架与持续集成(CI)结合,可显著提升开发效率及软件质量。持续集成通过频繁的代码集成及自动化构建测试,实现快速反馈、高质量代码、加强团队协作及简化部署流程。以Jenkins为例,配合Maven或Gradle,可轻松搭建JSF项目的CI环境,通过JUnit和Selenium编写自动化测试,确保每次构建的稳定性和正确性。
60 0
下一篇
无影云桌面