TF:利用TF的train.Saver将训练好的W、b模型文件保存+新建载入刚训练好模型(用于以后预测新的数据)

简介: TF:利用TF的train.Saver将训练好的W、b模型文件保存+新建载入刚训练好模型(用于以后预测新的数据)

输出结果

image.png

代码设计

import tensorflow as tf

import numpy as np

W = tf.Variable([[2,1,8],[1,2,5]], dtype=tf.float32, name='weights')

b = tf.Variable([[1,2,5]], dtype=tf.float32, name='biases')

 

init= tf.global_variables_initializer()  

 

saver = tf.train.Saver()    

 

with tf.Session() as sess:  

   sess.run(init)

   save_path = saver.save(sess, "niu/save_net.ckpt")

   print("Save to path: ", save_path)



#TF:利用TF的train.Saver载入曾经训练好的variables(W、b)以供预测新的数据

import tensorflow as tf

import numpy as np

W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")

b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")

 

saver = tf.train.Saver()

with tf.Session() as sess:

   saver.restore(sess, "niu/save_net.ckpt")

   print("weights:", sess.run(W))

   print("biases:", sess.run(b))


相关文章
|
5月前
|
机器学习/深度学习
大模型训练loss突刺原因和解决办法
【1月更文挑战第19天】大模型训练loss突刺原因和解决办法
893 1
大模型训练loss突刺原因和解决办法
|
2月前
|
API 算法框架/工具
【Tensorflow+keras】使用keras API保存模型权重、plot画loss损失函数、保存训练loss值
使用keras API保存模型权重、plot画loss损失函数、保存训练loss值
22 0
|
5月前
|
PyTorch 算法框架/工具
pytorch - swa_model模型保存的问题
pytorch - swa_model模型保存的问题
82 0
|
12月前
|
Java TensorFlow 算法框架/工具
【tensorflow】TF1.x保存.pb模型 解决模型越训练越大问题
在上一篇博客【tensorflow】TF1.x保存与读取.pb模型写法介绍介绍的保存.pb模型方法中,保存的是模型训练过程中所有的参数,而且训练越久,最终保存的模型就越大。我的模型只有几千参数,可是最终保存的文件有1GB。。。。
|
开发者
onnx 模型修改
已经生成的onnx 模型删除后处理sigmoid mul等层
307 0
|
数据采集 并行计算 PyTorch
【目标检测之数据集加载】利用DataLoader加载已预处理后的数据集【附代码】
在前一篇文章中,已经通过继承Dataset预处理自己的数据集 ,接下来就是使用pytorch提供的DataLoader函数加载数据集。
582 0
【目标检测之数据集加载】利用DataLoader加载已预处理后的数据集【附代码】
|
TensorFlow 算法框架/工具
TensorFlow指定每个epoch验证多少个批次数据集
TensorFlow指定每个epoch验证多少个批次数据集
133 0
|
缓存 NoSQL MongoDB
TensorFlow2.0(10):加载自定义图片数据集到Dataset
TensorFlow2.0(10):加载自定义图片数据集到Dataset
|
移动开发 算法 算法框架/工具
DL之DCGAN(Keras框架):基于keras框架利用深度卷积对抗网络DCGAN算法对MNIST数据集实现图像生成(保存h5模型→加载模型)
DL之DCGAN(Keras框架):基于keras框架利用深度卷积对抗网络DCGAN算法对MNIST数据集实现图像生成(保存h5模型→加载模型)
DL之DCGAN(Keras框架):基于keras框架利用深度卷积对抗网络DCGAN算法对MNIST数据集实现图像生成(保存h5模型→加载模型)
|
机器学习/深度学习 TensorFlow 算法框架/工具
TF之DNN:利用DNN【784→500→10】对MNIST手写数字图片识别数据集(TF自带函数下载)预测(98%)+案例理解DNN过程
TF之DNN:利用DNN【784→500→10】对MNIST手写数字图片识别数据集(TF自带函数下载)预测(98%)+案例理解DNN过程
TF之DNN:利用DNN【784→500→10】对MNIST手写数字图片识别数据集(TF自带函数下载)预测(98%)+案例理解DNN过程