【Tensorflow+keras】使用keras API保存模型权重、plot画loss损失函数、保存训练loss值

简介: 使用keras API保存模型权重、plot画loss损失函数、保存训练loss值

举例实现

(1)模型实现

import tensorflow  as tf
from tensorflow.keras.layers import *
from tensorflow.keras import *
import json
import numpy
# 这个类解决json.dump(dict)时报错Object of type 'float32' is not JSON serializable
class NumpyEncoder(json.JSONEncoder):  
    def default(self, obj):  
        if isinstance(obj, (numpy.int_, numpy.intc, numpy.intp, numpy.int8,  
            numpy.int16, numpy.int32, numpy.int64, numpy.uint8,  
            numpy.uint16, numpy.uint32, numpy.uint64)):  
            return int(obj)  
        elif isinstance(obj, (numpy.float_, numpy.float16, numpy.float32,numpy.float64)):  
            return float(obj)  
        elif isinstance(obj, (numpy.ndarray,)):  
            return obj.tolist()  
        return json.JSONEncoder.default(self, obj)  
def main()
    # 搭建模型
    inputs = tf.keras.layers.Input(shape=(3,))
    d = tf.keras.layers.Dense(2, name='out')
    output_1 = d(inputs)
    output_2 = d(inputs)
    model = tf.keras.models.Model(
    inputs=inputs, outputs=[output_1, output_2])
    model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"])
    # 保存模型权重
    checkpoint = callbacks.ModelCheckpoint('real_weight_10.tf',save_format='tf', monitor='val_acc',verbose=0, save_best_only=True, mode='min', save_weights_only=True)
    history = model.fit(x, (y, y)))
    # 画loss曲线
    epochs=range(len(history['bit_err']))
    plt.figure()
    plt.plot(epochs,history['bit_err'],'b',label='Training bit_error')
    plt.plot(epochs,history['val_bit_err'],'r',label='Validation bit_error')
    plt.title('Traing and Validation bit_error')
    plt.legend()
    plt.savefig('figure/model_bit_err_SNR10.jpg')
    plot.show()
    plt.figure()
    plt.plot(epochs,history['loss'],'b',label='Training loss')
    plt.plot(epochs,history['val_loss'],'r',label='Validation val_loss')
    plt.title('Traing and Validation loss')
    plt.legend()
    plt.savefig('figure/model_loss_SNR10.jpg')
    plt.show()
    # 保存loss值
    history_dict = history.history
    json.dump(history_dict, open('model_history/history.json', 'w'),cls=NumpyEncoder)

if __name__ == '__main__':
   # freeze_support() here if program needs to be frozen
    main()

(2)单独加载模型loss值

import numpy as np 

import scipy.io as sio
import matplotlib.pyplot as plt
import json

history = json.load(open('model_history/history.json', 'r'))
epochs=range(len(history['bit_err']))
plt.figure()
plt.plot(epochs,history['bit_err'],'b',label='Training bit_error')
plt.plot(epochs,history['val_bit_err'],'r',label='Validation bit_error')
plt.title('Traing and Validation bit_error')
plt.legend()
# plt.savefig('figure/model_bit_err_SNR10.jpg')
plot.show()

plt.figure()
plt.plot(epochs,history['loss'],'b',label='Training loss')
plt.plot(epochs,history['val_loss'],'r',label='Validation val_loss')
plt.title('Traing and Validation loss')
plt.legend()
# plt.savefig('figure/model_loss_SNR10.jpg')
plt.show()
目录
相关文章
|
5天前
|
人工智能 自然语言处理 前端开发
【2025.3.08更新】Linkreate wordpressAI智能插件|自动生成SEO文章/图片/视频+长尾词优化 内置DeepSeek多模型支持与API扩展
Linkreate WordPress AI插件提供强大的自动化文章生成、SEO优化、关键词管理和内容采集功能。它能根据关键词自动生成高质量文章,支持多语言和批量生成,内置长尾关键词生成工具,并可定时自动发布文章。插件还集成了多种AI服务,支持前端AI客服窗口及媒体生成,帮助用户高效管理网站内容,提升SEO效果。
|
10天前
|
人工智能 自然语言处理 前端开发
Linkreate wordpressAI智能插件|自动生成SEO文章/图片/视频+长尾词优化 内置DeepSeek多模型支持与API扩展
Linkreate WordPress AI插件提供强大的文章生成与优化功能,支持自动化生成高质量文章、批量生成、SEO优化及双标题定制。关键词生成管理方面,可批量生成长尾关键词并自定义参数。内容采集功能支持单篇和批量采集指定网站内容,可视化规则生成器方便使用。定时任务实现全自动文章生成,24小时稳定运行。API集成兼容多种AI服务,如DeepSeek、OpenAI等,并支持前端AI客服窗口。媒体生成功能包括自动为文章生成图片和短视频,提升内容丰富度。官网提供插件演示及下载:[https://idc.xymww.com/](https://idc.xymww.com/)
|
13天前
|
人工智能 物联网 API
又又又上新啦!魔搭免费模型推理API支持DeepSeek-R1,Qwen2.5-VL,Flux.1 dev及Lora等
又又又上新啦!魔搭免费模型推理API支持DeepSeek-R1,Qwen2.5-VL,Flux.1 dev及Lora等
|
13天前
|
人工智能 自然语言处理 API
零门槛,即刻拥有DeepSeek-R1满血版——调用API及部署各尺寸模型
本文介绍了如何利用阿里云技术快速部署和使用DeepSeek系列模型,涵盖满血版API调用和云端部署两种方案。DeepSeek在数学、代码和自然语言处理等复杂任务中表现出色,支持私有化部署和企业级加密,确保数据安全。通过详细的步骤和代码示例,帮助开发者轻松上手,提升工作效率和模型性能。解决方案链接:[阿里云DeepSeek方案](https://www.aliyun.com/solution/tech-solution/deepseek-r1-for-platforms?utm_content=g_1000401616)。
零门槛,即刻拥有DeepSeek-R1满血版——调用API及部署各尺寸模型
|
25天前
|
机器学习/深度学习 PyTorch TensorFlow
深度学习工具和框架详细指南:PyTorch、TensorFlow、Keras
在深度学习的世界中,PyTorch、TensorFlow和Keras是最受欢迎的工具和框架,它们为研究者和开发者提供了强大且易于使用的接口。在本文中,我们将深入探索这三个框架,涵盖如何用它们实现经典深度学习模型,并通过代码实例详细讲解这些工具的使用方法。
|
2月前
|
机器学习/深度学习 人工智能 安全
GLM-Zero:智谱AI推出与 OpenAI-o1-Preview 旗鼓相当的深度推理模型,开放在线免费使用和API调用
GLM-Zero 是智谱AI推出的深度推理模型,专注于提升数理逻辑、代码编写和复杂问题解决能力,支持多模态输入与完整推理过程输出。
262 24
GLM-Zero:智谱AI推出与 OpenAI-o1-Preview 旗鼓相当的深度推理模型,开放在线免费使用和API调用
|
2月前
|
自然语言处理 安全 API
API First:模型驱动的阿里云API保障体系
本文介绍了阿里云在API设计和管理方面的最佳实践。首先,通过API First和模型驱动的方式确保API的安全、稳定和效率。其次,分享了阿里云内部如何使用CloudSpec IDL语言及配套工具保障API质量,并实现自动化生成多语言SDK等工具。接着,描述了API从设计到上线的完整生命周期,包括规范校验、企业级能力接入、测试和发布等环节。最后,展望了未来,强调了持续提升API质量和开源CloudSpec IDL的重要性,以促进社区共建更好的API生态。
|
7天前
|
存储 缓存 监控
如何高效爬取天猫商品数据?官方API与非官方接口全解析
本文介绍两种天猫商品数据爬取方案:官方API和非官方接口。官方API合法合规,适合企业长期使用,需申请企业资质;非官方接口适合快速验证需求,但需应对反爬机制。详细内容涵盖开发步骤、Python实现示例、反爬策略、数据解析与存储、注意事项及扩展应用场景。推荐工具链包括Playwright、aiohttp、lxml等。如需进一步帮助,请联系作者。
|
8天前
|
机器学习/深度学习 JSON 算法
淘宝拍立淘按图搜索API接口系列的应用与数据解析
淘宝拍立淘按图搜索API接口是阿里巴巴旗下淘宝平台提供的一项基于图像识别技术的创新服务。以下是对该接口系列的应用与数据解析的详细分析
|
21天前
|
JSON API 数据格式
淘宝商品评论数据API接口详解及JSON示例返回
淘宝商品评论数据API接口是淘宝开放平台提供的一项服务,旨在帮助开发者通过编程方式获取淘宝商品的评论数据。这些数据包括评论内容、评论时间、评论者信息、评分等,对于电商分析、用户行为研究、竞品分析等领域都具有极高的价值。

热门文章

最新文章