# 实现示例
# (1) 模型实现：搭建并训练模型，保存最优权重、loss 曲线和训练历史
import json
import os

import matplotlib.pyplot as plt
import numpy
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras import *
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also accepts NumPy scalars and arrays.

    A plain ``json.dump(dict)`` fails with "Object of type 'float32' is not
    JSON serializable" when the dict (e.g. a Keras ``History.history``)
    contains NumPy values. Pass ``cls=NumpyEncoder`` to convert them to
    built-in Python types first.
    """

    def default(self, obj):
        """Return a JSON-serializable replacement for *obj*.

        NumPy integers become ``int``, NumPy floating values ``float``,
        NumPy booleans ``bool`` and arrays (nested) lists. Any other object is
        delegated to the base class, which raises ``TypeError``.
        """
        # The abstract scalar base classes cover every sized type, and unlike
        # the ``numpy.float_`` alias (removed in NumPy 2.0) they still exist.
        if isinstance(obj, numpy.integer):
            return int(obj)
        if isinstance(obj, numpy.floating):
            return float(obj)
        if isinstance(obj, numpy.bool_):
            return bool(obj)
        if isinstance(obj, numpy.ndarray):
            return obj.tolist()
        return super().default(obj)
def _build_model():
    """Build and compile the demo two-output model.

    Both outputs come from the *same* ``Dense(2)`` layer applied to the same
    input, so they are always identical; the model only demonstrates
    multi-output training. ``loss='mse'`` is applied to every output.
    """
    inputs = tf.keras.layers.Input(shape=(3,))
    shared_dense = tf.keras.layers.Dense(2, name='out')
    output_1 = shared_dense(inputs)
    output_2 = shared_dense(inputs)
    model = tf.keras.models.Model(inputs=inputs, outputs=[output_1, output_2])
    # NOTE(review): "acc" is likely not meaningful for an MSE regression
    # target; kept from the original example.
    model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"])
    return model


def _plot_curve(history_dict, key, label, save_path):
    """Plot the training (and, if recorded, validation) curve of one metric.

    Args:
        history_dict: ``History.history`` mapping metric name -> per-epoch values.
        key: metric name, e.g. ``'loss'``; the validation curve is ``'val_' + key``.
        label: human-readable metric name used in the legend and title.
        save_path: image file the figure is written to; its directory is created.
    """
    epochs = range(len(history_dict[key]))
    plt.figure()
    plt.plot(epochs, history_dict[key], 'b', label='Training ' + label)
    val_key = 'val_' + key
    if val_key in history_dict:
        plt.plot(epochs, history_dict[val_key], 'r', label='Validation ' + label)
    plt.title('Training and Validation ' + label)
    plt.legend()
    directory = os.path.dirname(save_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    plt.savefig(save_path)
    plt.show()


def main(x=None, y=None, epochs=1):
    """Train the demo model, plot its learning curves and save its history.

    Args:
        x: training inputs with 3 features per sample. When omitted, random
            demo data is generated.
        y: training targets with 2 values per sample; used for both outputs.
        epochs: number of training epochs (1 is what the original ``fit``
            call used implicitly).
    """
    if x is None or y is None:
        # The original example never defined x / y. Fall back to random demo
        # data so the script runs end to end -- replace with real data.
        x = numpy.random.rand(1000, 3).astype('float32')
        y = numpy.random.rand(1000, 2).astype('float32')

    model = _build_model()

    # Keep only the best weights. The original monitored 'val_acc' with
    # mode='min' (accuracy would need 'max'). For a two-output model Keras
    # prefixes per-output metrics with the output name, so no plain 'val_acc'
    # key is produced. Monitor the total validation loss, which is minimised.
    # ModelCheckpoint has no save_format argument; the format is inferred from
    # the file name. NOTE(review): Keras 3 requires a '.weights.h5' suffix here.
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        'real_weight_10.tf', monitor='val_loss', verbose=0,
        save_best_only=True, mode='min', save_weights_only=True)

    # validation_split creates the val_* entries that the checkpoint and the
    # plots rely on. The checkpoint only runs if it is passed to fit().
    # One target per output, matching outputs=[output_1, output_2].
    history = model.fit(x, [y, y], epochs=epochs, validation_split=0.2,
                        callbacks=[checkpoint])
    history_dict = history.history

    # 'bit_err' is only recorded when a custom metric with that name is
    # compiled into the model; skip its plot otherwise.
    if 'bit_err' in history_dict:
        _plot_curve(history_dict, 'bit_err', 'bit_error',
                    'figure/model_bit_err_SNR10.jpg')
    _plot_curve(history_dict, 'loss', 'loss', 'figure/model_loss_SNR10.jpg')

    # Save the per-epoch values so they can be re-plotted without retraining.
    os.makedirs('model_history', exist_ok=True)
    with open('model_history/history.json', 'w') as history_file:
        json.dump(history_dict, history_file, cls=NumpyEncoder)
if __name__ == '__main__':
    # Call multiprocessing.freeze_support() here if the program needs to be frozen.
    main()
# (2) 单独加载已保存的 loss 值（训练历史）并重新绘图
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import json
# Re-plot the learning curves from the history saved by the training script,
# without retraining the model.
SAVE_FIGURES = False  # set True to also write the figures (figure/ must exist)

with open('model_history/history.json', 'r') as history_file:
    history = json.load(history_file)

# metric key in the saved history -> (label for legend/title, output image)
curves = {
    'bit_err': ('bit_error', 'figure/model_bit_err_SNR10.jpg'),
    'loss': ('loss', 'figure/model_loss_SNR10.jpg'),
}
for key, (label, save_path) in curves.items():
    # 'bit_err' is only recorded when the model was compiled with a custom
    # metric of that name; skip curves that are not in the history.
    if key not in history:
        continue
    epochs = range(len(history[key]))
    plt.figure()
    plt.plot(epochs, history[key], 'b', label='Training ' + label)
    val_key = 'val_' + key
    if val_key in history:
        plt.plot(epochs, history[val_key], 'r', label='Validation ' + label)
    plt.title('Training and Validation ' + label)
    plt.legend()
    if SAVE_FIGURES:
        plt.savefig(save_path)
    plt.show()