【Tensorflow+keras】使用keras API保存模型权重、plot画loss损失函数、保存训练loss值

简介: 使用keras API保存模型权重、plot画loss损失函数、保存训练loss值

举例实现

(1)模型实现

import tensorflow  as tf
from tensorflow.keras.layers import *
from tensorflow.keras import *
import json
import numpy
# 这个类解决json.dump(dict)时报错Object of type 'float32' is not JSON serializable
class NumpyEncoder(json.JSONEncoder):  
    def default(self, obj):  
        if isinstance(obj, (numpy.int_, numpy.intc, numpy.intp, numpy.int8,  
            numpy.int16, numpy.int32, numpy.int64, numpy.uint8,  
            numpy.uint16, numpy.uint32, numpy.uint64)):  
            return int(obj)  
        elif isinstance(obj, (numpy.float_, numpy.float16, numpy.float32,numpy.float64)):  
            return float(obj)  
        elif isinstance(obj, (numpy.ndarray,)):  
            return obj.tolist()  
        return json.JSONEncoder.default(self, obj)  
def main()
    # 搭建模型
    inputs = tf.keras.layers.Input(shape=(3,))
    d = tf.keras.layers.Dense(2, name='out')
    output_1 = d(inputs)
    output_2 = d(inputs)
    model = tf.keras.models.Model(
    inputs=inputs, outputs=[output_1, output_2])
    model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"])
    # 保存模型权重
    checkpoint = callbacks.ModelCheckpoint('real_weight_10.tf',save_format='tf', monitor='val_acc',verbose=0, save_best_only=True, mode='min', save_weights_only=True)
    history = model.fit(x, (y, y)))
    # 画loss曲线
    epochs=range(len(history['bit_err']))
    plt.figure()
    plt.plot(epochs,history['bit_err'],'b',label='Training bit_error')
    plt.plot(epochs,history['val_bit_err'],'r',label='Validation bit_error')
    plt.title('Traing and Validation bit_error')
    plt.legend()
    plt.savefig('figure/model_bit_err_SNR10.jpg')
    plot.show()
    plt.figure()
    plt.plot(epochs,history['loss'],'b',label='Training loss')
    plt.plot(epochs,history['val_loss'],'r',label='Validation val_loss')
    plt.title('Traing and Validation loss')
    plt.legend()
    plt.savefig('figure/model_loss_SNR10.jpg')
    plt.show()
    # 保存loss值
    history_dict = history.history
    json.dump(history_dict, open('model_history/history.json', 'w'),cls=NumpyEncoder)

if __name__ == '__main__':
   # freeze_support() here if program needs to be frozen
    main()

(2)单独加载模型loss值

import numpy as np 

import scipy.io as sio
import matplotlib.pyplot as plt
import json

history = json.load(open('model_history/history.json', 'r'))
epochs=range(len(history['bit_err']))
plt.figure()
plt.plot(epochs,history['bit_err'],'b',label='Training bit_error')
plt.plot(epochs,history['val_bit_err'],'r',label='Validation bit_error')
plt.title('Traing and Validation bit_error')
plt.legend()
# plt.savefig('figure/model_bit_err_SNR10.jpg')
plot.show()

plt.figure()
plt.plot(epochs,history['loss'],'b',label='Training loss')
plt.plot(epochs,history['val_loss'],'r',label='Validation val_loss')
plt.title('Traing and Validation loss')
plt.legend()
# plt.savefig('figure/model_loss_SNR10.jpg')
plt.show()
目录
相关文章
|
1月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
将Keras训练好的.hdf5模型转换为TensorFlow的.pb模型,然后再转换为TensorRT支持的.uff格式,并提供了转换代码和测试步骤。
82 3
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
|
1月前
|
机器学习/深度学习 TensorFlow API
使用 TensorFlow 和 Keras 构建图像分类器
【10月更文挑战第2天】使用 TensorFlow 和 Keras 构建图像分类器
|
2月前
|
人工智能 Serverless API
一键服务化:从魔搭开源模型到OpenAI API服务
在多样化大模型的背后,OpenAI得益于在领域的先发优势,其API接口今天也成为了业界的一个事实标准。
一键服务化:从魔搭开源模型到OpenAI API服务
|
1月前
|
机器学习/深度学习 移动开发 TensorFlow
深度学习之格式转换笔记(四):Keras(.h5)模型转化为TensorFlow(.pb)模型
本文介绍了如何使用Python脚本将Keras模型转换为TensorFlow的.pb格式模型,包括加载模型、重命名输出节点和量化等步骤,以便在TensorFlow中进行部署和推理。
76 0
|
3月前
|
UED 开发工具 iOS开发
Uno Platform大揭秘:如何在你的跨平台应用中,巧妙融入第三方库与服务,一键解锁无限可能,让应用功能飙升,用户体验爆棚!
【8月更文挑战第31天】Uno Platform 让开发者能用同一代码库打造 Windows、iOS、Android、macOS 甚至 Web 的多彩应用。本文介绍如何在 Uno Platform 中集成第三方库和服务,如 Mapbox 或 Google Maps 的 .NET SDK,以增强应用功能并提升用户体验。通过 NuGet 安装所需库,并在 XAML 页面中添加相应控件,即可实现地图等功能。尽管 Uno 平台减少了平台差异,但仍需关注版本兼容性和性能问题,确保应用在多平台上表现一致。掌握正确方法,让跨平台应用更出色。
50 0
|
3月前
|
Apache 开发者 Java
Apache Wicket揭秘:如何巧妙利用模型与表单机制,实现Web应用高效开发?
【8月更文挑战第31天】本文深入探讨了Apache Wicket的模型与表单处理机制。Wicket作为一个组件化的Java Web框架,提供了多种模型实现,如CompoundPropertyModel等,充当组件与数据间的桥梁。文章通过示例介绍了模型创建及使用方法,并详细讲解了表单组件、提交处理及验证机制,帮助开发者更好地理解如何利用Wicket构建高效、易维护的Web应用程序。
46 0
|
3月前
|
机器学习/深度学习 API TensorFlow
深入解析TensorFlow 2.x中的Keras API:快速搭建深度学习模型的实战指南
【8月更文挑战第31天】本文通过搭建手写数字识别模型的实例,详细介绍了如何利用TensorFlow 2.x中的Keras API简化深度学习模型构建流程。从环境搭建到数据准备,再到模型训练与评估,展示了Keras API的强大功能与易用性,适合初学者快速上手。通过简单的代码,即可完成卷积神经网络的构建与训练,显著降低了深度学习的技术门槛。无论是新手还是专业人士,都能从中受益,高效实现模型开发。
29 0
|
3月前
|
SQL Shell API
python Django教程 之 模型(数据库)、自定义Field、数据表更改、QuerySet API
python Django教程 之 模型(数据库)、自定义Field、数据表更改、QuerySet API
|
3月前
|
TensorFlow 算法框架/工具
【Tensorflow+Keras】学习率指数、分段、逆时间、多项式衰减及自定义学习率衰减的完整实例
使用Tensorflow和Keras实现学习率衰减的完整实例,包括指数衰减、分段常数衰减、多项式衰减、逆时间衰减以及如何通过callbacks自定义学习率衰减策略。
63 0
|
4天前
|
JSON API 数据格式
淘宝 / 天猫官方商品 / 订单订单 API 接口丨商品上传接口对接步骤
要对接淘宝/天猫官方商品或订单API,需先注册淘宝开放平台账号,创建应用获取App Key和App Secret。之后,详细阅读API文档,了解接口功能及权限要求,编写认证、构建请求、发送请求和处理响应的代码。最后,在沙箱环境中测试与调试,确保API调用的正确性和稳定性。