We first define an upsampling computation block:
class Connect(tf.keras.layers.Layer):
    def __init__(self, filters=256, name='Connect', **kwargs):
        super(Connect, self).__init__(name=name, **kwargs)
        # Stride-2 transposed convolution doubles the spatial resolution
        self.Conv_Transpose = tf.keras.layers.Convolution2DTranspose(
            filters=filters, kernel_size=3, strides=2,
            padding="same", activation="relu")
        # A regular convolution refines the upsampled feature map
        self.conv_out = tf.keras.layers.Conv2D(
            filters=filters, kernel_size=3,
            padding="same", activation="relu")

    def call(self, inputs):
        x = self.Conv_Transpose(inputs)
        return self.conv_out(x)
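A quick shape check (not from the original text, using a random dummy tensor) confirms what the block does: the stride-2 transposed convolution doubles the height and width, while the trailing convolution keeps the channel count at `filters`.

# Hypothetical sanity check with a random input tensor
import tensorflow as tf

x = tf.random.normal((1, 7, 7, 512))               # e.g. the shape of VGG16's block5_pool output
y = Connect(filters=512, name="connect_demo")(x)
print(y.shape)                                      # expected: (1, 14, 14, 512)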
To extract more general features, we take the outputs of layers at different depths of the VGG16 network ("block5_conv3", "block4_conv3", "block3_conv3") and combine them through skip connections.
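The code below refers to a `vgg16` backbone defined earlier in the article. As a reminder, it is presumably created along these lines (a sketch, assuming ImageNet weights and no classification head, which matches the 14,714,688 non-trainable parameters in the summary further down):

# Assumed setup of the VGG16 feature extractor (not shown in this section)
vgg16 = tf.keras.applications.VGG16(weights="imagenet",
                                    include_top=False,
                                    input_shape=(224, 224, 3))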
layer_names = ["block5_conv3", "block4_conv3", "block3_conv3", "block5_pool"]
# Collect the 4 intermediate outputs
layers_out = [vgg16.get_layer(layer_name).output for layer_name in layer_names]
multi_out_model = tf.keras.models.Model(inputs=vgg16.input, outputs=layers_out)
multi_out_model.trainable = False

# Create the input
inputs = tf.keras.layers.Input(shape=(224, 224, 3))
out_block5_conv3, out_block4_conv3, out_block3_conv3, out = multi_out_model(inputs)
print(out_block5_conv3.shape)

x1 = Connect(512, name="connect_1")(out)
x1 = tf.add(x1, out_block5_conv3)   # element-wise addition
x2 = Connect(512, name="connect_2")(x1)
x2 = tf.add(x2, out_block4_conv3)   # element-wise addition
x3 = Connect(256, name="connect_3")(x2)
x3 = tf.add(x3, out_block3_conv3)   # element-wise addition
x4 = Connect(128, name="connect_4")(x3)

prediction = tf.keras.layers.Convolution2DTranspose(
    filters=3, kernel_size=3, strides=2,
    padding="same", activation="softmax")(x4)

model = tf.keras.models.Model(inputs=inputs, outputs=prediction)
model.summary()
Model: "model_5" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_6 (InputLayer) [(None, 224, 224, 3) 0 __________________________________________________________________________________________________ model_4 (Model) [(None, 14, 14, 512) 14714688 input_6[0][0] __________________________________________________________________________________________________ connect_1 (Connect) (None, 14, 14, 512) 4719616 model_4[1][3] __________________________________________________________________________________________________ tf_op_layer_Add_6 (TensorFlowOp [(None, 14, 14, 512) 0 connect_1[0][0] model_4[1][0] __________________________________________________________________________________________________ connect_2 (Connect) (None, 28, 28, 512) 4719616 tf_op_layer_Add_6[0][0] __________________________________________________________________________________________________ tf_op_layer_Add_7 (TensorFlowOp [(None, 28, 28, 512) 0 connect_2[0][0] model_4[1][1] __________________________________________________________________________________________________ connect_3 (Connect) (None, 56, 56, 256) 1769984 tf_op_layer_Add_7[0][0] __________________________________________________________________________________________________ tf_op_layer_Add_8 (TensorFlowOp [(None, 56, 56, 256) 0 connect_3[0][0] model_4[1][2] __________________________________________________________________________________________________ connect_4 (Connect) (None, 112, 112, 128 442624 tf_op_layer_Add_8[0][0] __________________________________________________________________________________________________ conv2d_transpose_14 (Conv2DTran (None, 224, 224, 3) 3459 connect_4[0][0] ================================================================================================== Total params: 26,369,987 Trainable params: 11,655,299 Non-trainable params: 14,714,688 __________________________________________________________________________________________________
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss="sparse_categorical_crossentropy",
              metrics=["acc"])

steps_per_epoch = train_count // batch_size
validation_steps = test_count // batch_size

history = model.fit(train_ds,
                    epochs=3,
                    steps_per_epoch=steps_per_epoch,
                    validation_data=test_ds,
                    validation_steps=validation_steps)
Epoch 1/3
739/739 [==============================] - 293s 396ms/step - loss: 0.3794 - acc: 0.8461 - val_loss: 0.2967 - val_acc: 0.8797
Epoch 2/3
739/739 [==============================] - 292s 395ms/step - loss: 0.2823 - acc: 0.8848 - val_loss: 0.2743 - val_acc: 0.8897
Epoch 3/3
739/739 [==============================] - 292s 395ms/step - loss: 0.2572 - acc: 0.8947 - val_loss: 0.2631 - val_acc: 0.8935
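Note that `train_ds`, `test_ds`, `train_count`, `test_count`, and `batch_size` come from the data-loading part of the article. As a rough sketch of what that pipeline is assumed to look like (the paths, helper names, and batch size here are placeholders): each batch yields a float image of shape (224, 224, 3) and an integer mask with class ids in {0, 1, 2}, which is what `sparse_categorical_crossentropy` with a 3-channel softmax head expects.

# Hypothetical input pipeline sketch; train_image_paths / train_mask_paths are placeholders
def load_pair(image_path, mask_path):
    image = tf.io.decode_jpeg(tf.io.read_file(image_path), channels=3)
    image = tf.image.resize(image, (224, 224)) / 255.0                  # float image in [0, 1]
    mask = tf.io.decode_png(tf.io.read_file(mask_path), channels=1)
    mask = tf.image.resize(mask, (224, 224), method="nearest")          # keep integer class ids
    return image, mask

batch_size = 32   # assumed value
train_ds = (tf.data.Dataset.from_tensor_slices((train_image_paths, train_mask_paths))
            .map(load_pair, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            .shuffle(300)
            .batch(batch_size)
            .repeat()
            .prefetch(tf.data.experimental.AUTOTUNE))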
4. Model Evaluation
In the training above, the model reaches close to 90% accuracy after only three epochs, which is a decent result overall. Below we visualize one image to see what the predictions actually look like.
for image, mask in test_ds.take(1):
    pred_mask = model.predict(image)
    pred_mask = tf.argmax(pred_mask, axis=-1)       # per-pixel class index
    pred_mask = pred_mask[..., tf.newaxis]          # restore the channel dimension

    plt.figure(figsize=(10, 10))
    plt.subplot(1, 3, 1)
    plt.imshow(tf.keras.preprocessing.image.array_to_img(image[0]))     # input image
    plt.subplot(1, 3, 2)
    plt.imshow(tf.keras.preprocessing.image.array_to_img(mask[0]))      # ground-truth mask
    plt.subplot(1, 3, 3)
    plt.imshow(tf.keras.preprocessing.image.array_to_img(pred_mask[0])) # predicted mask
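Pixel accuracy can look flattering when the background dominates, so a complementary check (not part of the original text) is mean IoU over a few test batches, using the built-in tf.keras.metrics.MeanIoU with num_classes=3 to match the 3-channel softmax head:

# Supplementary evaluation sketch: mean IoU over a handful of test batches
miou = tf.keras.metrics.MeanIoU(num_classes=3)
for image, mask in test_ds.take(10):
    pred_mask = tf.argmax(model.predict(image), axis=-1)
    miou.update_state(mask, pred_mask)
print("mean IoU:", miou.result().numpy())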