同样是保存模型,model.save()和model. save_weights ()有何区别

简介: 同样是保存模型,model.save()和model. save_weights ()有何区别

model.save()保存了模型的图结构和模型的参数,保存模型的后缀是.hdf5。


model. save_weights ()只保存了模型的参数,并没有保存模型的图结构,保存模型的后缀使用.h5。


所以使用save_weights保存的模型比使用save() 保存的模型的大小要小。同时加载模型时的方法也不同。model.save()保存了模型的图结构,直接使用load_model()方法就可加载模型然后做测试,例:


from  tensorflow.keras.models import load_model


model=load_model("my_model_.hdf5")

加载save_weights保存的模型就稍微复杂了一些,还需要再次描述模型结构信息才能加载模型。例:



def bn_prelu(x):


   x = BatchNormalization(epsilon=1e-5)(x)


   x = PReLU()(x)


   return x



def build_model(out_dims, input_shape=(norm_size, norm_size, 3)):


   inputs_dim = Input(input_shape)


   x = Conv2D(32, (3, 3), strides=(2, 2), padding='same')(inputs_dim)


   x = bn_prelu(x)


   x = Conv2D(32, (3, 3), strides=(1, 1), padding='same')(x)


   x = bn_prelu(x)


   x = MaxPooling2D(pool_size=(2, 2))(x)


   x = Conv2D(64, (3, 3), strides=(1, 1), padding='same')(x)


   x = bn_prelu(x)


   x = Conv2D(64, (3, 3), strides=(1, 1), padding='same')(x)


   x = bn_prelu(x)


   x = MaxPooling2D(pool_size=(2, 2))(x)


   x = Conv2D(128, (3, 3), strides=(1, 1), padding='same')(x)


   x = bn_prelu(x)


   x = Conv2D(128, (3, 3), strides=(1, 1), padding='same')(x)


   x = bn_prelu(x)


   x = MaxPooling2D(pool_size=(2, 2))(x)


   x = Conv2D(256, (3, 3), strides=(1, 1), padding='same')(x)


   x = bn_prelu(x)


   x = Conv2D(256, (3, 3), strides=(1, 1), padding='same')(x)


   x = bn_prelu(x)


   x = GlobalAveragePooling2D()(x)


   dp_1 = Dropout(0.5)(x)


   fc2 = Dense(out_dims)(dp_1)


   fc2 = Activation('softmax')(fc2) #此处注意,为sigmoid函数


   model = Model(inputs=inputs_dim, outputs=fc2)


   return model


model=build_model(labelnum)


model. load_weights(“my_model_.h5”);


目录
相关文章
|
5月前
|
TensorFlow 算法框架/工具 Python
【Tensorflow】Found unexpected keys that do not correspond to any Model output: dict_keys([‘model_outp
文章讨论了在使用Tensorflow 2.3时遇到的错误信息:"Found unexpected keys that do not correspond to any Model output: dict_keys(['model_output']). Expected: ['dense']"。这个问题通常发生在模型的输出层命名与model.fit_generator的生成器函数中返回的值的键不匹配时。
58 1
|
5月前
|
JavaScript 开发者
v-model学习
v-model学习
93 0
|
5月前
|
TensorFlow API 算法框架/工具
【Tensorflow+keras】解决使用model.load_weights时报错 ‘str‘ object has no attribute ‘decode‘
python 3.6,Tensorflow 2.0,在使用Tensorflow 的keras API,加载权重模型时,报错’str’ object has no attribute ‘decode’
74 0
|
8月前
|
PyTorch 算法框架/工具
pytorch - swa_model模型保存的问题
pytorch - swa_model模型保存的问题
114 0
|
7月前
|
存储 机器学习/深度学习 PyTorch
【从零开始学习深度学习】19. Pytorch中如何存储与读取模型:torch.save、torch.load与state_dict对象
【从零开始学习深度学习】19. Pytorch中如何存储与读取模型:torch.save、torch.load与state_dict对象
|
7月前
|
JavaScript 前端开发
v-model
v-model
56 0
|
8月前
添加数据:(model.py)
添加数据:(model.py)。
34 2
|
8月前
|
机器学习/深度学习 PyTorch 算法框架/工具
通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()
通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()
106 0
|
8月前
|
JavaScript
v-model和:model的区别
v-model和:model的区别
281 0
|
API 数据格式
TensorFlow2._:model.summary() Output Shape为multiple解决方法
TensorFlow2._:model.summary() Output Shape为multiple解决方法
292 0
TensorFlow2._:model.summary() Output Shape为multiple解决方法