1 Environment
Python 3.6
TensorFlow 2.0
When loading saved weights through TensorFlow's Keras API, the call failed with `AttributeError: 'str' object has no attribute 'decode'`:
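```python
from tensorflow.keras import *
from tensorflow.keras.layers import *

# testmodel(), training_gen(), validation_gen() and SNRdb come from the
# original project and are not shown here.
model = testmodel()
# Swap the final layer for a single linear output unit
model = Model(model.input,
              Dense(1, activation='linear', kernel_initializer='normal')(model.layers[-2].output))
# Assumed compile step (omitted in the original); a new Model must be compiled before fit
model.compile(optimizer='adam', loss='mse')

# Save the full model to HDF5 at the end of every epoch
checkpoint = callbacks.ModelCheckpoint('current_best.h5')
history = model.fit_generator(
    training_gen(1000, SNRdb),
    steps_per_epoch=50,
    epochs=500,
    validation_data=validation_gen(1000, SNRdb),
    validation_steps=1,
    callbacks=[checkpoint],
    verbose=2)

# This is the call that raised the error. load_weights() restores in place,
# so its return value should not be assigned back to model.
model.load_weights('current_best.h5')
```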
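The message itself is plain Python 3 behaviour: `bytes` has a `decode()` method, `str` does not. As far as I can tell, the usual trigger is h5py 3.x, which returns string attributes stored in an HDF5 file as `str`, while the Keras HDF5 loading code shipped with TF 2.0 still calls `.decode('utf8')` on them. The sketch below reproduces that mismatch without TensorFlow; `read_attr_old` and `read_attr_tolerant` are hypothetical helpers written for illustration, not Keras functions.

```python
def read_attr_old(value):
    """Mimics code written for h5py 2.x, where string attributes came back as bytes."""
    return value.decode('utf8')


def read_attr_tolerant(value):
    """Decodes bytes but passes str through unchanged."""
    if isinstance(value, bytes):
        return value.decode('utf8')
    return value


# h5py 2.x style attribute value (bytes): the old helper works
print(read_attr_old(b'dense_1'))          # dense_1

# h5py 3.x style attribute value (str): the old helper fails
try:
    read_attr_old('dense_1')
except AttributeError as err:
    print(err)                            # 'str' object has no attribute 'decode'

# The tolerant version handles both kinds of value
print(read_attr_tolerant('dense_1'))      # dense_1
```

Because the mismatch sits inside the HDF5 reading path, a checkpoint format that never goes through h5py avoids it.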
2 Solution
When saving, write the checkpoint in TensorFlow format (a .tf path) instead of HDF5 (.h5):
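```python
# A path without an .h5/.keras suffix makes Keras write TensorFlow's own
# checkpoint format instead of HDF5, so h5py is not used at all.
# ModelCheckpoint has no save_format argument; save_weights_only=True stores a
# weights checkpoint that load_weights() can read back.
checkpoint = callbacks.ModelCheckpoint('current_best.tf', save_weights_only=True)
```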
When loading, change the call to:
```python
# load_weights() restores the weights in place; don't assign its return value to model
model.load_weights('current_best.tf')
```
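Why the file name matters: when no format is passed explicitly, tf.keras in TF 2.x picks between HDF5 and TensorFlow's checkpoint format by looking at the path suffix. The function below is a hypothetical re-implementation of that rule for illustration only. It assumes (from my reading of the TF 2.x source, not from this post) that `.h5`, `.hdf5` and `.keras` are the suffixes treated as HDF5; `checkpoint_format` is not a Keras API.

```python
# Hypothetical helper mirroring the suffix check tf.keras (TF 2.x) applies
# when no save format is given explicitly. An assumption, not Keras API.
HDF5_SUFFIXES = ('.h5', '.hdf5', '.keras')


def checkpoint_format(filepath):
    """Return 'h5' for paths Keras would write with h5py, else 'tf'."""
    return 'h5' if filepath.endswith(HDF5_SUFFIXES) else 'tf'


for path in ('current_best.h5', 'current_best.tf'):
    print(path, '->', checkpoint_format(path))
# current_best.h5 -> h5
# current_best.tf -> tf
```

Under this rule, renaming the file from `current_best.h5` to `current_best.tf` is what moves the checkpoint off h5py.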