TensorFlow进行不同模型和数据集之间的迁移学习和模型微调

简介: 迁移学习包括获取从一个问题中学习到的特征,然后将这些特征用于新的类似问题。例如,来自已学会识别浣熊的模型的特征可能对建立旨在识别狸猫的模型十分有用。


迁移学习包括获取从一个问题中学习到的特征,然后将这些特征用于新的类似问题。例如,来自已学会识别浣熊的模型的特征可能对建立旨在识别狸猫的模型十分有用。


对于数据集中的数据太少而无法从头开始训练完整模型的任务,通常会执行迁移学习。


在深度学习情境中,迁移学习最常见的形式是以下工作流:


  • 从之前训练的模型中获取层。
  • 冻结这些层,以避免在后续训练轮次中破坏它们包含的任何信息。
  • 在已冻结层的顶部添加一些新的可训练层。这些层会学习将旧特征转换为对新数据集的预测。
  • 在您的数据集上训练新层。
  • 最后一个可选步骤是微调,包括解冻上面获得的整个模型(或模型的一部分),然后在新数据上以极低的学习率对该模型进行重新训练。以增量方式使预训练特征适应新数据,有可能实现有意义的改进。


典型的迁移学习工作流
下面将介绍如何在 Keras 中实现典型的迁移学习工作流:


  • 实例化一个基础模型并加载预训练权重。
  • 通过设置 trainable = False 冻结基础模型中的所有层。
  • 根据基础模型中一个(或多个)层的输出创建一个新模型。
  • 在您的新数据集上训练新模型。
    请注意,另一种更轻量的工作流如下:
  • 实例化一个基础模型并加载预训练权重。
  • 通过该模型运行新的数据集,并记录基础模型中一个(或多个)层的输出。这一过程称为特征提取。
  • 使用该输出作为新的较小模型的输入数据。


第二种工作流有一个关键优势,即您只需在自己的数据上运行一次基础模型,而不是每个训练周期都运行一次。因此,它的速度更快,开销也更低。


但是,第二种工作流存在一个问题,即它不允许您在训练期间动态修改新模型的输入数据,在进行数据扩充时,这种修改必不可少。当新数据集的数据太少而无法从头开始训练完整模型时,任务通常会使用迁移学习,在这种情况下,数据扩充非常重要。因此,在接下来的篇幅中,我们将专注于第一种工作流。


完整代码


"""

* Created with PyCharm

* 作者: 阿光

* 日期: 2022/1/4

* 时间: 13:25

* 描述:

"""

import numpy as np

import tensorflow as tf

from keras import Model

from tensorflow import keras

from tensorflow.keras import layers


import tensorflow_datasets as tfds


tfds.disable_progress_bar()


train_ds, validation_ds, test_ds = tfds.load(

   "cats_vs_dogs",

   split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],

   as_supervised=True,

)


print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))

print(

   "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)

)

print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))


size = (150, 150)


train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))

validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))

test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))


batch_size = 32


train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)

validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)

test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)


base_model = keras.applications.Xception(

   weights='imagenet',

   input_shape=(150, 150, 3),

   include_top=False

)


base_model.trainable = False


data_augmentation = keras.Sequential([

   layers.experimental.preprocessing.RandomFlip('horizontal'),

   layers.experimental.preprocessing.RandomRotation(0.1)

])


inputs = keras.Input(shape=(150, 150, 3))

x = data_augmentation(inputs)


norm_layer = keras.layers.experimental.preprocessing.Normalization()

mean = np.array([127.5] * 3)

var = mean ** 2

x = norm_layer(x)

# norm_layer.set_weights([mean, var])

x = base_model(x, training=False)

x = layers.GlobalAveragePooling2D()(x)

x = layers.Dropout(0.2)(x)

outputs = layers.Dense(1)(x)

model = Model(inputs, outputs)


model.compile(

   optimizer=keras.optimizers.Adam(),

   loss=keras.losses.BinaryCrossentropy(from_logits=True),

   metrics=[keras.metrics.BinaryAccuracy()],

)


epochs = 20

model.fit(train_ds, epochs=epochs, validation_data=validation_ds)


base_model.trainable = True

model.summary()


model.compile(

   optimizer=keras.optimizers.Adam(1e-5),

   loss=keras.losses.BinaryCrossentropy(from_logits=True),

   metrics=[keras.metrics.BinaryAccuracy()],

)


epochs = 10

model.fit(train_ds, epochs=epochs, validation_data=validation_ds)

目录
相关文章
|
5月前
|
机器学习/深度学习 数据采集 PyTorch
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
|
5月前
|
搜索推荐 安全 UED
浅谈AARRR模型
浅谈AARRR模型
|
机器学习/深度学习 TensorFlow 算法框架/工具
AIAM 模型
AIAM(Artificial Intelligence and Music)模型是一种基于深度学习的音乐生成模型。
256 3
|
2月前
|
搜索推荐 语音技术
SenseVoice模型建议
8月更文挑战第4天
175 1
|
3月前
|
存储 人工智能 自然语言处理
大模型时代
【7月更文挑战第6天】大模型时代
54 5
|
1月前
|
人工智能 安全 测试技术
MetaLlama大模型
LLaMA 是一组基础语言模型,参数范围从 7B 到 65B,在大量公开数据上训练而成,性能优异。Llama 2 为 LLaMA 的升级版,参数规模扩大至 70 亿至 700 亿,特别优化了对话功能。Code Llama 基于 Llama 2 开发,专注于代码生成,提供不同参数规模的模型。这些模型可在多种平台上运行,包括官方 API、第三方封装库如 llama.cpp 和 ollama,以及通过 Hugging Face 的 transformers 库使用。此外,还提供了详细的模型申请及使用指南,便于开发者快速上手。相关链接包括 Meta 官方页面和 GitHub 仓库。
25 6
MetaLlama大模型
|
2月前
|
人工智能 算法 搜索推荐
你觉得大模型时代该出现什么?
【8月更文挑战第11天】大模型时代展望关键技术与基础设施升级,如量子计算支持、模型优化及专用芯片;模型层面探索多模态融合与自我解释能力;应用场景涵盖智能医疗、教育及城市管理等;社会人文领域则涉及新职业培训与伦理法规建设。
|
5月前
|
机器学习/深度学习 传感器 人工智能
世界模型是什么?
【2月更文挑战第9天】世界模型是什么?
522 3
世界模型是什么?
|
5月前
|
自然语言处理 知识图谱
你了解SCQA模型吗?
你了解SCQA模型吗?
336 0
你了解SCQA模型吗?
|
5月前
使用xxmix9realistic_v40.safetensors模型
使用xxmix9realistic_v40.safetensors模型
229 0
下一篇
无影云桌面