TF:利用TF的train.Saver将训练好的W、b模型文件保存+新建载入刚训练好模型(用于以后预测新的数据)

简介: TF:利用TF的train.Saver将训练好的W、b模型文件保存+新建载入刚训练好模型(用于以后预测新的数据)

输出结果

image.png

代码设计

import tensorflow as tf

import numpy as np

W = tf.Variable([[2,1,8],[1,2,5]], dtype=tf.float32, name='weights')

b = tf.Variable([[1,2,5]], dtype=tf.float32, name='biases')

 

init= tf.global_variables_initializer()  

 

saver = tf.train.Saver()    

 

with tf.Session() as sess:  

   sess.run(init)

   save_path = saver.save(sess, "niu/save_net.ckpt")

   print("Save to path: ", save_path)



#TF:利用TF的train.Saver载入曾经训练好的variables(W、b)以供预测新的数据

import tensorflow as tf

import numpy as np

W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")

b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")

 

saver = tf.train.Saver()

with tf.Session() as sess:

   saver.restore(sess, "niu/save_net.ckpt")

   print("weights:", sess.run(W))

   print("biases:", sess.run(b))


相关文章
|
TensorFlow 算法框架/工具
tensorflow/train训练指令
tensorflow/train训练指令
79 0
|
3月前
|
计算机视觉
数据集学习笔记(三):COCO创建dataloader用于训练
如何使用COCO数据集创建dataloader进行训练,包括安装环境、加载数据集代码、定义数据转换、创建数据集对象以及创建dataloader。
66 5
|
5月前
|
API 算法框架/工具
【Tensorflow+keras】使用keras API保存模型权重、plot画loss损失函数、保存训练loss值
使用keras API保存模型权重、plot画loss损失函数、保存训练loss值
42 0
|
8月前
|
PyTorch 算法框架/工具
pytorch - swa_model模型保存的问题
pytorch - swa_model模型保存的问题
114 0
|
Java TensorFlow 算法框架/工具
【tensorflow】TF1.x保存.pb模型 解决模型越训练越大问题
在上一篇博客【tensorflow】TF1.x保存与读取.pb模型写法介绍介绍的保存.pb模型方法中,保存的是模型训练过程中所有的参数,而且训练越久,最终保存的模型就越大。我的模型只有几千参数,可是最终保存的文件有1GB。。。。
|
机器学习/深度学习 存储 PyTorch
怎么调用pytorch中mnist数据集
怎么调用pytorch中mnist数据集
237 0
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch中如何使用DataLoader对数据集进行批训练
Pytorch中如何使用DataLoader对数据集进行批训练
148 0
|
PyTorch 算法框架/工具
【PyTorch】自定义数据集处理/dataset/DataLoader等
【PyTorch】自定义数据集处理/dataset/DataLoader等
225 0
|
数据采集 并行计算 PyTorch
【目标检测之数据集加载】利用DataLoader加载已预处理后的数据集【附代码】
在前一篇文章中,已经通过继承Dataset预处理自己的数据集 ,接下来就是使用pytorch提供的DataLoader函数加载数据集。
657 0
【目标检测之数据集加载】利用DataLoader加载已预处理后的数据集【附代码】
|
PyTorch 算法框架/工具
【pytorch】pytorch代码中实现MNIST、cifar10等数据集本地读取
pytorch代码中实现MNIST、cifar10等数据集本地读取
【pytorch】pytorch代码中实现MNIST、cifar10等数据集本地读取

相关实验场景

更多