1. Import the required packages
import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline
from lxml import etree
import numpy as np
import glob
from matplotlib.patches import Rectangle

print(tf.__version__)
tf.test.is_gpu_available()
2.0.0
True
2. Data preprocessing
2.1 Build the input pipeline
images = glob.glob("./Image_location/images/*.jpg")
xmls = glob.glob("./Image_location/annotations/xmls/*.xml")

# Get the file names (without extension) of the annotated images
names = [x.split("\\")[-1].split(".xml")[0] for x in xmls]

# Only images that have an annotation go into the training set
imgs_train = [img for img in images if img.split("\\")[-1].split(".jpg")[0] in names]
imgs_test = [img for img in images if img.split("\\")[-1].split(".jpg")[0] not in names]

# Sort images and annotations so they line up by file name
imgs_train.sort(key=lambda x: x.split("\\")[-1].split(".jpg")[0])
xmls.sort(key=lambda x: x.split("\\")[-1].split(".xml")[0])

# Derive the class labels after sorting: capitalized file names map to "cat", lowercase to "dog"
names = [x.split("\\")[-1].split(".xml")[0] for x in xmls]
class_label = ["cat", "dog"]
class_label_index = dict((name, index) for index, name in enumerate(class_label))
label = [class_label_index["cat"] if name.istitle() else class_label_index["dog"]
         for name in names]
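The labelling above relies entirely on the file-naming convention that cat images have capitalized file names and dog images lowercase ones (the istitle() check). A minimal sanity check, shown only as a sketch, is to count the labels and spot-check a few name/label pairs:

# Count how many files were labelled as each class; index 0 = cat, 1 = dog.
print(np.bincount(label))

# Spot-check a few (name, label) pairs to confirm the istitle() heuristic.
for name, lab in list(zip(names, label))[:5]:
    print(name, "->", class_label[lab])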
Read in the annotation data and package it into a dataset
def to_labels(path):
    xml = open(path).read()
    sel = etree.HTML(xml)
    width = int(sel.xpath("//size/width/text()")[0])
    height = int(sel.xpath("//size/height/text()")[0])
    xmin = int(sel.xpath("//bndbox/xmin/text()")[0])
    xmax = int(sel.xpath("//bndbox/xmax/text()")[0])
    ymin = int(sel.xpath("//bndbox/ymin/text()")[0])
    ymax = int(sel.xpath("//bndbox/ymax/text()")[0])
    # Normalize the box coordinates by the image size
    return [xmin / width, ymin / height, xmax / width, ymax / height]

# Read the bounding-box information from every annotation file
labels = [to_labels(path) for path in xmls]
out1, out2, out3, out4 = list(zip(*labels))
out1 = np.array(out1)
out2 = np.array(out2)
out3 = np.array(out3)
out4 = np.array(out4)
label = np.array(label)
label_datasets = tf.data.Dataset.from_tensor_slices((out1, out2, out3, out4, label))
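The XPath queries appear to assume PASCAL VOC-style annotation files (a size block plus a single bndbox per image). A quick, hedged sanity check on the parsed output: every normalized coordinate should fall inside [0, 1].

# Inspect the first parsed annotation; all four values should lie in [0, 1].
sample = to_labels(xmls[0])
print(sample)
assert all(0.0 <= v <= 1.0 for v in sample), "coordinate outside the image"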
Read the images from their paths
def load_image(path):
    img = tf.io.read_file(path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, (224, 224))
    img = img / 127.5 - 1   # scale pixel values to [-1, 1]
    return img

image_dataset = tf.data.Dataset.from_tensor_slices(imgs_train)
AUTOTUNE = tf.data.experimental.AUTOTUNE
image_dataset = image_dataset.map(load_image, num_parallel_calls=AUTOTUNE)
dataset = tf.data.Dataset.zip((image_dataset, label_datasets))
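The manual img / 127.5 - 1 scaling matches the [-1, 1] range that Xception expects. An equivalent sketch, if you prefer to keep the preprocessing tied to the backbone, uses the built-in helper (load_image_v2 is just an illustrative name):

def load_image_v2(path):
    img = tf.io.read_file(path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, (224, 224))
    # xception.preprocess_input maps float pixels in [0, 255] to [-1, 1]
    return tf.keras.applications.xception.preprocess_input(img)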
2.2 Set up the training and test sets
#%% Set the sizes of the training and validation splits
test_count = int(len(imgs_train) * 0.2)
train_count = len(imgs_train) - test_count
print(test_count, train_count)

# Skip the first test_count samples for training, take them for testing
dataset = dataset.shuffle(buffer_size=len(imgs_train))
train_dataset = dataset.skip(test_count)
test_dataset = dataset.take(test_count)

batch_size = 16
# Use a shuffle buffer as large as the training set so the data is thoroughly shuffled
train_ds = train_dataset.shuffle(buffer_size=train_count).repeat().batch(batch_size)
train_ds = train_ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
test_ds = test_dataset.batch(batch_size)
737 2949
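One caveat: Dataset.shuffle() reshuffles on every pass by default, so taking skip()/take() splits downstream of it can let training and validation examples swap once repeat() re-reads the pipeline. A sketch of a pinned split (the seed value is arbitrary, chosen here only for illustration):

# Pin the shuffle order so skip()/take() always produce the same split.
dataset = tf.data.Dataset.zip((image_dataset, label_datasets))
dataset = dataset.shuffle(buffer_size=len(imgs_train), seed=2020,
                          reshuffle_each_iteration=False)
train_dataset = dataset.skip(test_count)
test_dataset = dataset.take(test_count)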
Let's plot one image to check the result.
from matplotlib.patches import Rectangle

for img, label in train_ds.take(1):
    plt.imshow(tf.keras.preprocessing.image.array_to_img(img[0]))
    out1, out2, out3, out4, out5 = label
    xmin, ymin, xmax, ymax = out1[0].numpy() * 224, out2[0].numpy() * 224, \
                             out3[0].numpy() * 224, out4[0].numpy() * 224
    # Draw the ground-truth box, anchored at (xmin, ymin)
    rect = Rectangle((xmin, ymin), (xmax - xmin), (ymax - ymin), fill=False, color="red")
    ax = plt.gca()
    ax.axes.add_patch(rect)
    plt.title((class_label[out5[0]]).title())
    plt.show()
3. Build the model
We build the network on top of a pre-trained Xception model and reuse the multi-output pattern from the previous section, with separate outputs for the class and for the bounding-box coordinates.
xception = tf.keras.applications.Xception(input_shape=(224, 224, 3),
                                          include_top=False,
                                          weights='imagenet')

inputs = tf.keras.layers.Input(shape=(224, 224, 3))
x = xception(inputs)
x = tf.keras.layers.GlobalAveragePooling2D()(x)

# Regression head: four normalized bounding-box coordinates
x1 = tf.keras.layers.Dense(2048, activation='relu')(x)
x1 = tf.keras.layers.Dense(256, activation='relu')(x1)
out1 = tf.keras.layers.Dense(1, name="out1")(x1)
out2 = tf.keras.layers.Dense(1, name="out2")(x1)
out3 = tf.keras.layers.Dense(1, name="out3")(x1)
out4 = tf.keras.layers.Dense(1, name="out4")(x1)

# Classification head: cat vs. dog
x2 = tf.keras.layers.Dense(1024, activation='relu')(x)
x2 = tf.keras.layers.Dense(256, activation='relu')(x2)
out_class = tf.keras.layers.Dense(1, activation='sigmoid', name='out_item')(x2)

prediction = [out1, out2, out3, out4, out_class]
model = tf.keras.models.Model(inputs=inputs, outputs=prediction)
model.summary()
Model: "model" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_2 (InputLayer) [(None, 224, 224, 3) 0 __________________________________________________________________________________________________ xception (Model) (None, 7, 7, 2048) 20861480 input_2[0][0] __________________________________________________________________________________________________ global_average_pooling2d (Globa (None, 2048) 0 xception[1][0] __________________________________________________________________________________________________ dense (Dense) (None, 2048) 4196352 global_average_pooling2d[0][0] __________________________________________________________________________________________________ dense_2 (Dense) (None, 1024) 2098176 global_average_pooling2d[0][0] __________________________________________________________________________________________________ dense_1 (Dense) (None, 256) 524544 dense[0][0] __________________________________________________________________________________________________ dense_3 (Dense) (None, 256) 262400 dense_2[0][0] __________________________________________________________________________________________________ out1 (Dense) (None, 1) 257 dense_1[0][0] __________________________________________________________________________________________________ out2 (Dense) (None, 1) 257 dense_1[0][0] __________________________________________________________________________________________________ out3 (Dense) (None, 1) 257 dense_1[0][0] __________________________________________________________________________________________________ out4 (Dense) (None, 1) 257 dense_1[0][0] __________________________________________________________________________________________________ out_item (Dense) (None, 1) 257 dense_3[0][0] ================================================================================================== Total params: 27,944,237 Trainable params: 27,889,709 Non-trainable params: 54,528 __________________________________________________________________________________________________
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss={'out1': 'mse',
                    'out2': 'mse',
                    'out3': 'mse',
                    'out4': 'mse',
                    'out_item': 'binary_crossentropy'},
              metrics=["mae", "acc"])

steps_per_epoch = train_count // batch_size
validation_steps = test_count // batch_size
history = model.fit(train_ds,
                    epochs=20,
                    steps_per_epoch=steps_per_epoch,
                    validation_data=test_ds,
                    validation_steps=validation_steps)
Train for 184 steps, validate for 46 steps
Epoch 1/20
184/184 [==============================] - 404s 2s/step - loss: 0.2214 - out1_loss: 0.0246 - out2_loss: 0.0166 - out3_loss: 0.0312 - out4_loss: 0.0228 - out_item_loss: 0.1261 - out1_mae: 0.1187 - out2_mae: 0.0961 - out3_mae: 0.1327 - out4_mae: 0.1156 - out_item_mae: 0.0854 - out_item_acc: 0.9467 - val_loss: 0.1970 - val_out1_loss: 0.0185 - val_out2_loss: 0.0214 - val_out3_loss: 0.0752 - val_out4_loss: 0.0735 - val_out_item_loss: 0.0083 - val_out1_mae: 0.1078 - val_out2_mae: 0.1255 - val_out3_mae: 0.2500 - val_out4_mae: 0.2464 - val_out_item_mae: 0.0044 - val_out_item_acc: 0.9973
...
(Epochs 2-19: the training loss falls from 0.0751 to 0.0077 and the validation loss from 0.0923 to 0.0049, with occasional fluctuations around epochs 5, 8 and 12; val_out_item_acc stays at or near 1.0000 from epoch 3 onward.)
...
Epoch 20/20
184/184 [==============================] - 335s 2s/step - loss: 0.0065 - out1_loss: 0.0015 - out2_loss: 0.0013 - out3_loss: 0.0018 - out4_loss: 0.0018 - out_item_loss: 1.4812e-04 - out1_mae: 0.0305 - out2_mae: 0.0285 - out3_mae: 0.0334 - out4_mae: 0.0333 - out_item_mae: 1.4408e-04 - out_item_acc: 1.0000 - val_loss: 0.0060 - val_out1_loss: 0.0017 - val_out2_loss: 0.0011 - val_out3_loss: 0.0019 - val_out4_loss: 0.0012 - val_out_item_loss: 2.1330e-05 - val_out1_mae: 0.0325 - val_out2_mae: 0.0253 - val_out3_mae: 0.0328 - val_out4_mae: 0.0244 - val_out_item_mae: 2.1332e-05 - val_out_item_acc: 1.0000
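The fit() call above returns a History object, so the loss curves can be visualized directly. A minimal sketch (only the total losses are plotted; the per-output keys such as "out1_loss" follow the output layer names):

# Plot training vs. validation total loss over the 20 epochs.
plt.plot(history.epoch, history.history["loss"], label="loss")
plt.plot(history.epoch, history.history["val_loss"], label="val_loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()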
4. Model evaluation
model.save("detect_v1.h5") new_model = tf.keras.models.load_model("detect_v1.h5") plt.figure(figsize=(8,24)) for img,_ in test_ds.skip(8).take(1): out1,out2,out3,out4,out5 = new_model.predict(img) for i in range(6): plt.subplot(6,1,i+1) plt.imshow(tf.keras.preprocessing.image.array_to_img(img[i])) xmin,ymin,xmax,ymax = out1[i]*224,out2[i]*224,out3[i]*224,out4[i]*224 #给定左下角坐标 rect = Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),fill=False,color = "red") ax = plt.gca() ax.axes.add_patch(rect) plt.title((class_label[round(float(out5[i]))]).title()) #plt.show()
Judging from the plotted results, the overall detection performance is quite good.
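For a more quantitative check than eyeballing the plots, one could also measure the overlap between predicted and ground-truth boxes. A minimal intersection-over-union (IoU) sketch, assuming boxes in the normalized (xmin, ymin, xmax, ymax) format used throughout this post; the example boxes at the end are made up for illustration, not taken from the model:

def iou(box_a, box_b):
    # Intersection-over-union of two (xmin, ymin, xmax, ymax) boxes.
    xa = max(box_a[0], box_b[0])
    ya = max(box_a[1], box_b[1])
    xb = min(box_a[2], box_b[2])
    yb = min(box_a[3], box_b[3])
    inter = max(0.0, xb - xa) * max(0.0, yb - ya)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter + 1e-7)

# Illustrative usage with two hand-made boxes:
print(iou([0.1, 0.1, 0.6, 0.6], [0.2, 0.2, 0.7, 0.7]))  # ~0.47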