TensorFlow进行不同模型和数据集之间的迁移学习和模型微调

简介: 迁移学习包括获取从一个问题中学习到的特征,然后将这些特征用于新的类似问题。例如,来自已学会识别浣熊的模型的特征可能对建立旨在识别狸猫的模型十分有用。


迁移学习包括获取从一个问题中学习到的特征,然后将这些特征用于新的类似问题。例如,来自已学会识别浣熊的模型的特征可能对建立旨在识别狸猫的模型十分有用。


对于数据集中的数据太少而无法从头开始训练完整模型的任务,通常会执行迁移学习。


在深度学习情境中,迁移学习最常见的形式是以下工作流:


  • 从之前训练的模型中获取层。
  • 冻结这些层,以避免在后续训练轮次中破坏它们包含的任何信息。
  • 在已冻结层的顶部添加一些新的可训练层。这些层会学习将旧特征转换为对新数据集的预测。
  • 在您的数据集上训练新层。
  • 最后一个可选步骤是微调,包括解冻上面获得的整个模型(或模型的一部分),然后在新数据上以极低的学习率对该模型进行重新训练。以增量方式使预训练特征适应新数据,有可能实现有意义的改进。


典型的迁移学习工作流
下面将介绍如何在 Keras 中实现典型的迁移学习工作流:


  • 实例化一个基础模型并加载预训练权重。
  • 通过设置 trainable = False 冻结基础模型中的所有层。
  • 根据基础模型中一个(或多个)层的输出创建一个新模型。
  • 在您的新数据集上训练新模型。
    请注意,另一种更轻量的工作流如下:
  • 实例化一个基础模型并加载预训练权重。
  • 通过该模型运行新的数据集,并记录基础模型中一个(或多个)层的输出。这一过程称为特征提取。
  • 使用该输出作为新的较小模型的输入数据。


第二种工作流有一个关键优势,即您只需在自己的数据上运行一次基础模型,而不是每个训练周期都运行一次。因此,它的速度更快,开销也更低。


但是,第二种工作流存在一个问题,即它不允许您在训练期间动态修改新模型的输入数据,在进行数据扩充时,这种修改必不可少。当新数据集的数据太少而无法从头开始训练完整模型时,任务通常会使用迁移学习,在这种情况下,数据扩充非常重要。因此,在接下来的篇幅中,我们将专注于第一种工作流。


完整代码


"""

* Created with PyCharm

* 作者: 阿光

* 日期: 2022/1/4

* 时间: 13:25

* 描述:

"""

import numpy as np

import tensorflow as tf

from keras import Model

from tensorflow import keras

from tensorflow.keras import layers


import tensorflow_datasets as tfds


tfds.disable_progress_bar()


train_ds, validation_ds, test_ds = tfds.load(

   "cats_vs_dogs",

   split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],

   as_supervised=True,

)


print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))

print(

   "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)

)

print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))


size = (150, 150)


train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))

validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))

test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))


batch_size = 32


train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)

validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)

test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)


base_model = keras.applications.Xception(

   weights='imagenet',

   input_shape=(150, 150, 3),

   include_top=False

)


base_model.trainable = False


data_augmentation = keras.Sequential([

   layers.experimental.preprocessing.RandomFlip('horizontal'),

   layers.experimental.preprocessing.RandomRotation(0.1)

])


inputs = keras.Input(shape=(150, 150, 3))

x = data_augmentation(inputs)


norm_layer = keras.layers.experimental.preprocessing.Normalization()

mean = np.array([127.5] * 3)

var = mean ** 2

x = norm_layer(x)

# norm_layer.set_weights([mean, var])

x = base_model(x, training=False)

x = layers.GlobalAveragePooling2D()(x)

x = layers.Dropout(0.2)(x)

outputs = layers.Dense(1)(x)

model = Model(inputs, outputs)


model.compile(

   optimizer=keras.optimizers.Adam(),

   loss=keras.losses.BinaryCrossentropy(from_logits=True),

   metrics=[keras.metrics.BinaryAccuracy()],

)


epochs = 20

model.fit(train_ds, epochs=epochs, validation_data=validation_ds)


base_model.trainable = True

model.summary()


model.compile(

   optimizer=keras.optimizers.Adam(1e-5),

   loss=keras.losses.BinaryCrossentropy(from_logits=True),

   metrics=[keras.metrics.BinaryAccuracy()],

)


epochs = 10

model.fit(train_ds, epochs=epochs, validation_data=validation_ds)

目录
相关文章
|
7月前
|
搜索推荐 安全 UED
浅谈AARRR模型
浅谈AARRR模型
|
7月前
|
机器学习/深度学习 数据采集 PyTorch
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
|
5月前
|
存储 人工智能 自然语言处理
大模型时代
【7月更文挑战第6天】大模型时代
70 5
|
1月前
|
机器学习/深度学习 自然语言处理
MGTE系列模型
【10月更文挑战第15天】
66 9
|
4月前
|
人工智能 算法 搜索推荐
你觉得大模型时代该出现什么?
【8月更文挑战第11天】大模型时代展望关键技术与基础设施升级,如量子计算支持、模型优化及专用芯片;模型层面探索多模态融合与自我解释能力;应用场景涵盖智能医疗、教育及城市管理等;社会人文领域则涉及新职业培训与伦理法规建设。
|
7月前
|
机器学习/深度学习 人工智能 PyTorch
LLM 大模型学习必知必会系列(四):LLM训练理论篇以及Transformer结构模型详解
LLM 大模型学习必知必会系列(四):LLM训练理论篇以及Transformer结构模型详解
LLM 大模型学习必知必会系列(四):LLM训练理论篇以及Transformer结构模型详解
|
7月前
|
机器学习/深度学习 传感器 人工智能
世界模型是什么?
【2月更文挑战第9天】世界模型是什么?
758 3
世界模型是什么?
|
7月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch与迁移学习:利用预训练模型提升性能
【4月更文挑战第18天】PyTorch支持迁移学习,助力提升深度学习性能。预训练模型(如ResNet、VGG)在大规模数据集(如ImageNet)训练后,可在新任务中加速训练,提高准确率。通过选择模型、加载预训练权重、修改结构和微调,可适应不同任务需求。迁移学习节省资源,但也需考虑源任务与目标任务的相似度及超参数选择。实践案例显示,预训练模型能有效提升小数据集上的图像分类任务性能。未来,迁移学习将继续在深度学习领域发挥重要作用。
|
机器学习/深度学习 算法 PyTorch
使用Pytorch实现对比学习SimCLR 进行自监督预训练
SimCLR(Simple Framework for Contrastive Learning of Representations)是一种学习图像表示的自监督技术。 与传统的监督学习方法不同,SimCLR 不依赖标记数据来学习有用的表示。 它利用对比学习框架来学习一组有用的特征,这些特征可以从未标记的图像中捕获高级语义信息。
1115 1
|
机器学习/深度学习 TensorFlow 算法框架/工具
【深度学习】基于tensorflow的小型物体识别训练(数据集:CIFAR-10)
【深度学习】基于tensorflow的小型物体识别训练(数据集:CIFAR-10)
363 0