Tensorflow加载预训练模型的特殊操作

本文涉及的产品
交互式建模 PAI-DSW,5000CU*H 3个月
简介: Tensorflow加载预训练模型的特殊操作

Tensorflow加载预训练模型的特殊操作


最近看到一个巨牛的人工智能教程,分享一下给大家。教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家。平时碎片时间可以当小说看,【点这里可以去膜拜一下大神的“小说”】。

在前面的文章【Tensorflow加载预训练模型和保存模型】中介绍了如何保存训练好的模型,已经将预训练好的模型参数加载到当前网络。这些属于常规操作,即预训练的模型与当前网络结构的命名完全一致。

本文介绍一些不常规的操作:

如何只加载部分参数?

如何从两个模型中加载不同部分参数?

当预训练的模型的命名与当前定义的网络中的参数命名不一致时该怎么办?

1 只加载部分参数

举个例子,对已有的网络结构做了细微修改,例如只改了几层卷积通道数。如果从头训练显然没有finetune收敛速度快,但是模型又没法全部加载。此时,只需将未修改部分参数加载到当前网络即可。假设修改过的卷积层名称包含`conv_``,示例代码如下:

import tensorflow as tf
def restore(sess, ckpt_path):
  vars = tf.trainable_variables()
  vars = [v for v vars if not "conv_1" in v.name]
    saver = tf.train.Saver(var_list=vars)
  saver.restore(sess, ckpt_path)

2 从两个预训练模型中加载不同部分参数

如果需要从两个不同的预训练模型中加载不同部分参数,例如,网络中的前半部分用一个预训练模型参数,后半部分用另一个预训练模型中的参数,示例代码如下:

import tensorflow as tf
def restore(sess, ckpt_path):
  vars = tf.trainable_variables()
  model_1_vars = [v for v vars if "model_1" in v.name]
  model_2_vars = [v for v vars if "model_2" in v.name]
    saver_1 = tf.train.Saver(var_list=model_1_vars)
    saver_2 = tf.train.Saver(var_list=model_2_vars)
  saver_1 .restore(sess, ckpt_path)
  saver_2 .restore(sess, ckpt_path)

3 从参数名称不一致的模型中加载参数

举个例子,例如,预训练的模型所有的参数有个前缀name_1,现在定义的网络结构中的参数以name_2作为前缀。那么使用如下示例代码即可加载:

import tensorflow as tf
def restore(sess, ckpt_path):
  vars = tf.trainable_variables()
  vars_dict = dict()
  for v in vars:
      key = v.name.split(':')[0]
      if key.startswith("name_2/"):
          key = key.replace("name_2/", "name_1/")
      vars_dict[key] = v
  saver =tf.train.Saver(var_list=vars_dict)
  saver.restore(sess, ckpt_path)

注意: 使用上面代码时,要确保参数的shape一致,否则会无法加载参数。

如果不知道预训练的ckpt中参数名称,可以使用如下代码打印:

for name, shape in tf.train.list_variables(ckpt_path):
    print(name)


相关文章
|
3月前
|
机器学习/深度学习 算法 TensorFlow
文本分类识别Python+卷积神经网络算法+TensorFlow模型训练+Django可视化界面
文本分类识别Python+卷积神经网络算法+TensorFlow模型训练+Django可视化界面
62 0
文本分类识别Python+卷积神经网络算法+TensorFlow模型训练+Django可视化界面
|
3月前
|
机器学习/深度学习 监控 Python
tensorflow2.x多层感知机模型参数量和计算量的统计
tensorflow2.x多层感知机模型参数量和计算量的统计
|
6月前
|
TensorFlow 算法框架/工具
【tensorflow】TF1.x保存与读取.pb模型写法介绍
由于TF里面的概念比较接地气,所以用tf1.x保存.pb模型时总是怕有什么操作漏掉了,会造成保存的模型是缺少变量数据或者没有保存图,所以先明确一下:用TF1.x保存模型时只需要保存模型的输入输出的变量(多输入就保存多个),不需要保存中间的变量;用TF1.x加载模型时只需要加载保存的模型,然后读一下输入输出变量(多输入就读多个),不需要初始化(反而会重置掉变量的值)。
|
6月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
【tensorflow】连续输入的线性回归模型训练代码
  get_data函数用于生成随机的训练和验证数据集。首先使用np.random.rand生成一个形状为(10000, 10)的随机数据集,来模拟10维的连续输入,然后使用StandardScaler对数据进行标准化。再生成一个(10000,1)的target,表示最终拟合的目标分数。最后使用train_test_split函数将数据集划分为训练集和验证集。
|
6月前
|
机器学习/深度学习 算法 TensorFlow
树叶识别系统python+Django网页界面+TensorFlow+算法模型+数据集+图像识别分类
树叶识别系统python+Django网页界面+TensorFlow+算法模型+数据集+图像识别分类
131 1
|
6月前
|
机器学习/深度学习 移动开发 算法
动物识别系统python+Django网页界面+TensorFlow算法模型+数据集训练
动物识别系统python+Django网页界面+TensorFlow算法模型+数据集训练
89 0
动物识别系统python+Django网页界面+TensorFlow算法模型+数据集训练
|
6月前
|
机器学习/深度学习 算法 TensorFlow
交通标志识别系统python+TensorFlow+算法模型+Django网页+数据集
交通标志识别系统python+TensorFlow+算法模型+Django网页+数据集
62 0
|
3月前
|
机器学习/深度学习 搜索推荐 算法
推荐系统离线评估方法和评估指标,以及在推荐服务器内部实现A/B测试和解决A/B测试资源紧张的方法。还介绍了如何在TensorFlow中进行模型离线评估实践。
推荐系统离线评估方法和评估指标,以及在推荐服务器内部实现A/B测试和解决A/B测试资源紧张的方法。还介绍了如何在TensorFlow中进行模型离线评估实践。
190 0
|
6天前
|
机器学习/深度学习 TensorFlow 调度
优化TensorFlow模型:超参数调整与训练技巧
【4月更文挑战第17天】本文探讨了如何优化TensorFlow模型的性能,重点介绍了超参数调整和训练技巧。超参数如学习率、批量大小和层数对模型性能至关重要。文章提到了三种超参数调整策略:网格搜索、随机搜索和贝叶斯优化。此外,还分享了训练技巧,包括学习率调度、早停、数据增强和正则化,这些都有助于防止过拟合并提高模型泛化能力。结合这些方法,可构建更高效、健壮的深度学习模型。
|
1月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
OpenCV读取tensorflow 2.X模型的方法:将SavedModel转为frozen graph
【2月更文挑战第22天】本文介绍基于Python的tensorflow库,将tensorflow与keras训练好的SavedModel格式神经网络模型转换为frozen graph格式,从而可以用OpenCV库在C++等其他语言中将其打开的方法~
OpenCV读取tensorflow 2.X模型的方法:将SavedModel转为frozen graph