TensorFlow 2.0: Image Localization + Classification (Oxford-IIIT Pet Dataset)

Summary: A pure classification problem is easy to understand: given an image, we output a class label. A localization problem additionally has to output four numbers (x, y, w, h): the coordinates (x, y) of a point of the object plus the width and height of its box. With these four numbers it is easy to draw the object's bounding box.
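
The corner form stored in the annotations used below, (xmin, ymin, xmax, ymax), and the (x, y, w, h) form described above are interchangeable; a minimal sketch of the conversion (the numbers are made up purely for illustration):

def corners_to_xywh(xmin, ymin, xmax, ymax):
    # (x, y) is one corner of the box, w and h are its width and height
    return xmin, ymin, xmax - xmin, ymax - ymin

print(corners_to_xywh(48, 240, 195, 371))   # (48, 240, 147, 131)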

1. Import the required packages

import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline
from lxml import etree
import numpy as np
import glob
from matplotlib.patches import Rectangle
print(tf.__version__)
tf.test.is_gpu_available()
2.0.0
True

2. Data preprocessing

2.1 Create the input pipeline

images = glob.glob("./Image_location/images/*.jpg")
xmls = glob.glob("./Image_location/annotations/xmls/*.xml")
# get the annotation file names (without extension)
names = [x.split("\\")[-1].split(".xml")[0] for x in xmls]
imgs_train = [img for img in images if img.split("\\")[-1].split(".jpg")[0] in names]
imgs_test = [img for img in images if img.split("\\")[-1].split(".jpg")[0] not in names]
# sort so that images and annotations line up
imgs_train.sort(key=lambda x :x.split("\\")[-1].split(".jpg")[0])
xmls.sort(key=lambda x :x.split("\\")[-1].split(".xml")[0])
# derive the class labels after sorting
names = [x.split("\\")[-1].split(".xml")[0] for x in xmls]
class_label = ["cat","dog"]
class_label_index = dict((name,index) for index,name in enumerate(class_label))
# in the Oxford-IIIT Pet dataset, cat file names are capitalized and dog file names are lowercase
label = [class_label_index[class_label[0]] if name.istitle() else class_label_index[class_label[1]] for name in names]
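
A quick optional sanity check (not in the original code): after sorting, the i-th image should correspond to the i-th annotation, which holds as long as every XML file has a matching .jpg:

# hypothetical check: each (image, xml) pair should share the same base file name
assert len(imgs_train) == len(xmls)
for img_path, xml_path in zip(imgs_train, xmls):
    assert img_path.split("\\")[-1].split(".jpg")[0] == xml_path.split("\\")[-1].split(".xml")[0]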

Read the annotation files and package the bounding-box labels

def to_labels(path):
    xml = open("{}".format(path)).read()
    sel = etree.HTML(xml)
    width = int(sel.xpath("//size/width/text()")[0])
    height = int(sel.xpath("//size/height/text()")[0])
    xmin = int(sel.xpath("//bndbox/xmin/text()")[0])
    xmax = int(sel.xpath("//bndbox/xmax/text()")[0])
    ymin = int(sel.xpath("//bndbox/ymin/text()")[0])
    ymax = int(sel.xpath("//bndbox/ymax/text()")[0])
    # normalize by the image size so the box no longer depends on the original resolution
    return [xmin/width,ymin/height,xmax/width,ymax/height]
# parse the bounding box of every annotation file
labels = [to_labels(path) for path in xmls]
out1,out2,out3,out4 = list(zip(*labels))
out1 = np.array(out1)
out2 = np.array(out2)
out3 = np.array(out3)
out4 = np.array(out4)
label = np.array(label)
label_datasets = tf.data.Dataset.from_tensor_slices((out1,out2,out3,out4,label))

Read the images from the paths collected above

def load_image(path):
    img = tf.io.read_file(path)
    img = tf.image.decode_jpeg(img,channels=3)
    img = tf.image.resize(img,(224,224))
    img = img/127.5 - 1 # normalize to the range [-1, 1]
    return img
image_dataset = tf.data.Dataset.from_tensor_slices(imgs_train)
AUTOTUNE = tf.data.experimental.AUTOTUNE
image_dataset = image_dataset.map(load_image,num_parallel_calls=AUTOTUNE)
dataset = tf.data.Dataset.zip((image_dataset,label_datasets))
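
Optionally, pulling one element from the zipped dataset confirms the structure every training sample will have: an image tensor plus the four normalized box coordinates and the class id.

for img, (xmin, ymin, xmax, ymax, cls) in dataset.take(1):
    print(img.shape)   # (224, 224, 3)
    print(xmin.numpy(), ymin.numpy(), xmax.numpy(), ymax.numpy(), cls.numpy())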

2.2 Split into training and test sets

#%% set the sizes of the training and validation data
test_count = int(len(imgs_train)*0.2)
train_count = len(imgs_train) - test_count
print(test_count,train_count)
# shuffle once with a fixed permutation so the skip/take split stays the same across epochs
dataset = dataset.shuffle(buffer_size=len(imgs_train),reshuffle_each_iteration=False)
# the training set skips the first test_count samples
train_dataset = dataset.skip(test_count)
test_dataset = dataset.take(test_count)
batch_size = 16
# use a shuffle buffer as large as the training set so the data is thoroughly shuffled
train_ds = train_dataset.shuffle(buffer_size=train_count).repeat().batch(batch_size)
train_ds = train_ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
test_ds = test_dataset.batch(batch_size)
737 2949

Let's plot one image to check the result.

from matplotlib.patches import Rectangle
for img,label in train_ds.take(1):
    plt.imshow(tf.keras.preprocessing.image.array_to_img(img[0]))
    out1,out2,out3,out4,out5= label
    xmin,ymin,xmax,ymax = out1[0].numpy()*224,out2[0].numpy()*224,out3[0].numpy()*224,out4[0].numpy()*224
    # the Rectangle anchor is (xmin, ymin), followed by the box width and height
    rect = Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),fill=False,color = "red")
    ax = plt.gca()
    ax.axes.add_patch(rect)
    plt.title((class_label[out5[0]]).title())
    plt.show()

3. Build the model

We build the network on a pretrained Xception backbone and, as in the previous section, use a multi-output model that predicts the class and the four location values separately.
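
Note that all Xception layers are left trainable in the model below (the summary shows roughly 27.9 M trainable parameters). On a smaller dataset or weaker GPU one might freeze the backbone first and only train the new Dense heads; a minimal sketch of that variant (an assumption, not what this article does):

# hypothetical variant: keep the pretrained backbone fixed and train only the new heads
base = tf.keras.applications.Xception(input_shape=(224, 224, 3),
                                      include_top=False,
                                      weights='imagenet')
base.trainable = False   # Xception weights are not updated during training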

xception = tf.keras.applications.Xception(input_shape=(224, 224, 3),
                                          include_top=False,
                                          weights='imagenet')
inputs = tf.keras.layers.Input(shape=(224,224,3))
x = xception(inputs)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x1 = tf.keras.layers.Dense(2048, activation='relu')(x)
x1 = tf.keras.layers.Dense(256, activation='relu')(x1)
out1 = tf.keras.layers.Dense(1,name="out1")(x1)
out2 = tf.keras.layers.Dense(1,name="out2")(x1)
out3 = tf.keras.layers.Dense(1,name="out3")(x1)
out4 = tf.keras.layers.Dense(1,name="out4")(x1)
x2 = tf.keras.layers.Dense(1024, activation='relu')(x)
x2 = tf.keras.layers.Dense(256, activation='relu')(x2)
out_class = tf.keras.layers.Dense(1,activation='sigmoid',name='out_item')(x2)
prediction = [out1,out2,out3,out4,out_class]
model = tf.keras.models.Model(inputs=inputs,outputs=prediction)
model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
xception (Model)                (None, 7, 7, 2048)   20861480    input_2[0][0]                    
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 2048)         0           xception[1][0]                   
__________________________________________________________________________________________________
dense (Dense)                   (None, 2048)         4196352     global_average_pooling2d[0][0]   
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 1024)         2098176     global_average_pooling2d[0][0]   
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 256)          524544      dense[0][0]                      
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 256)          262400      dense_2[0][0]                    
__________________________________________________________________________________________________
out1 (Dense)                    (None, 1)            257         dense_1[0][0]                    
__________________________________________________________________________________________________
out2 (Dense)                    (None, 1)            257         dense_1[0][0]                    
__________________________________________________________________________________________________
out3 (Dense)                    (None, 1)            257         dense_1[0][0]                    
__________________________________________________________________________________________________
out4 (Dense)                    (None, 1)            257         dense_1[0][0]                    
__________________________________________________________________________________________________
out_item (Dense)                (None, 1)            257         dense_3[0][0]                    
==================================================================================================
Total params: 27,944,237
Trainable params: 27,889,709
Non-trainable params: 54,528
__________________________________________________________________________________________________
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss={'out1':'mse',
                    'out2':'mse',
                    'out3':'mse',
                    'out4':'mse',
                    'out_item':'binary_crossentropy'},
              metrics=["mae","acc"])
steps_per_epoch = train_count//batch_size
validation_steps = test_count//batch_size
history = model.fit(train_ds,epochs=20,steps_per_epoch=steps_per_epoch,validation_data=test_ds,validation_steps=validation_steps)
Train for 184 steps, validate for 46 steps
Epoch 1/20
184/184 [==============================] - 404s 2s/step - loss: 0.2214 - out1_loss: 0.0246 - out2_loss: 0.0166 - out3_loss: 0.0312 - out4_loss: 0.0228 - out_item_loss: 0.1261 - out1_mae: 0.1187 - out1_acc: 0.0000e+00 - out2_mae: 0.0961 - out2_acc: 0.0000e+00 - out3_mae: 0.1327 - out3_acc: 0.0095 - out4_mae: 0.1156 - out4_acc: 0.0048 - out_item_mae: 0.0854 - out_item_acc: 0.9467 - val_loss: 0.1970 - val_out1_loss: 0.0185 - val_out2_loss: 0.0214 - val_out3_loss: 0.0752 - val_out4_loss: 0.0735 - val_out_item_loss: 0.0083 - val_out1_mae: 0.1078 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.1255 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.2500 - val_out3_acc: 0.0054 - val_out4_mae: 0.2464 - val_out4_acc: 0.0014 - val_out_item_mae: 0.0044 - val_out_item_acc: 0.9973
Epoch 2/20
184/184 [==============================] - 400s 2s/step - loss: 0.0751 - out1_loss: 0.0111 - out2_loss: 0.0071 - out3_loss: 0.0132 - out4_loss: 0.0125 - out_item_loss: 0.0313 - out1_mae: 0.0831 - out1_acc: 0.0000e+00 - out2_mae: 0.0657 - out2_acc: 0.0000e+00 - out3_mae: 0.0899 - out3_acc: 0.0085 - out4_mae: 0.0878 - out4_acc: 0.0034 - out_item_mae: 0.0167 - out_item_acc: 0.9901 - val_loss: 0.0923 - val_out1_loss: 0.0098 - val_out2_loss: 0.0047 - val_out3_loss: 0.0400 - val_out4_loss: 0.0230 - val_out_item_loss: 0.0149 - val_out1_mae: 0.0802 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0540 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.1778 - val_out3_acc: 0.0122 - val_out4_mae: 0.1267 - val_out4_acc: 0.0027 - val_out_item_mae: 0.0051 - val_out_item_acc: 0.9959
Epoch 3/20
184/184 [==============================] - 408s 2s/step - loss: 0.0460 - out1_loss: 0.0080 - out2_loss: 0.0050 - out3_loss: 0.0098 - out4_loss: 0.0082 - out_item_loss: 0.0151 - out1_mae: 0.0698 - out1_acc: 0.0000e+00 - out2_mae: 0.0555 - out2_acc: 0.0000e+00 - out3_mae: 0.0776 - out3_acc: 0.0078 - out4_mae: 0.0711 - out4_acc: 0.0031 - out_item_mae: 0.0074 - out_item_acc: 0.9959 - val_loss: 0.0274 - val_out1_loss: 0.0064 - val_out2_loss: 0.0037 - val_out3_loss: 0.0077 - val_out4_loss: 0.0075 - val_out_item_loss: 0.0021 - val_out1_mae: 0.0622 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0479 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0677 - val_out3_acc: 0.0054 - val_out4_mae: 0.0662 - val_out4_acc: 0.0027 - val_out_item_mae: 0.0021 - val_out_item_acc: 1.0000
Epoch 4/20
184/184 [==============================] - 379s 2s/step - loss: 0.0440 - out1_loss: 0.0064 - out2_loss: 0.0043 - out3_loss: 0.0087 - out4_loss: 0.0077 - out_item_loss: 0.0169 - out1_mae: 0.0618 - out1_acc: 0.0000e+00 - out2_mae: 0.0513 - out2_acc: 0.0000e+00 - out3_mae: 0.0720 - out3_acc: 0.0078 - out4_mae: 0.0686 - out4_acc: 0.0037 - out_item_mae: 0.0096 - out_item_acc: 0.9929 - val_loss: 0.0239 - val_out1_loss: 0.0033 - val_out2_loss: 0.0023 - val_out3_loss: 0.0060 - val_out4_loss: 0.0074 - val_out_item_loss: 0.0049 - val_out1_mae: 0.0429 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0380 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0586 - val_out3_acc: 0.0122 - val_out4_mae: 0.0645 - val_out4_acc: 0.0041 - val_out_item_mae: 0.0028 - val_out_item_acc: 0.9986
Epoch 5/20
184/184 [==============================] - 421s 2s/step - loss: 0.0355 - out1_loss: 0.0051 - out2_loss: 0.0033 - out3_loss: 0.0064 - out4_loss: 0.0057 - out_item_loss: 0.0149 - out1_mae: 0.0561 - out1_acc: 0.0000e+00 - out2_mae: 0.0454 - out2_acc: 0.0000e+00 - out3_mae: 0.0623 - out3_acc: 0.0078 - out4_mae: 0.0595 - out4_acc: 0.0048 - out_item_mae: 0.0069 - out_item_acc: 0.9966 - val_loss: 0.0493 - val_out1_loss: 0.0034 - val_out2_loss: 0.0029 - val_out3_loss: 0.0247 - val_out4_loss: 0.0092 - val_out_item_loss: 0.0091 - val_out1_mae: 0.0451 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0425 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.1334 - val_out3_acc: 0.0082 - val_out4_mae: 0.0734 - val_out4_acc: 0.0095 - val_out_item_mae: 0.0048 - val_out_item_acc: 0.9959
Epoch 6/20
184/184 [==============================] - 424s 2s/step - loss: 0.0281 - out1_loss: 0.0044 - out2_loss: 0.0031 - out3_loss: 0.0061 - out4_loss: 0.0055 - out_item_loss: 0.0091 - out1_mae: 0.0515 - out1_acc: 0.0000e+00 - out2_mae: 0.0441 - out2_acc: 0.0000e+00 - out3_mae: 0.0613 - out3_acc: 0.0082 - out4_mae: 0.0576 - out4_acc: 0.0048 - out_item_mae: 0.0047 - out_item_acc: 0.9973 - val_loss: 0.0124 - val_out1_loss: 0.0026 - val_out2_loss: 0.0017 - val_out3_loss: 0.0035 - val_out4_loss: 0.0033 - val_out_item_loss: 0.0013 - val_out1_mae: 0.0382 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0318 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0429 - val_out3_acc: 0.0068 - val_out4_mae: 0.0432 - val_out4_acc: 0.0068 - val_out_item_mae: 0.0012 - val_out_item_acc: 1.0000
Epoch 7/20
184/184 [==============================] - 395s 2s/step - loss: 0.0174 - out1_loss: 0.0033 - out2_loss: 0.0027 - out3_loss: 0.0046 - out4_loss: 0.0044 - out_item_loss: 0.0025 - out1_mae: 0.0450 - out1_acc: 0.0000e+00 - out2_mae: 0.0409 - out2_acc: 0.0000e+00 - out3_mae: 0.0536 - out3_acc: 0.0075 - out4_mae: 0.0516 - out4_acc: 0.0037 - out_item_mae: 0.0018 - out_item_acc: 0.9993 - val_loss: 0.0113 - val_out1_loss: 0.0026 - val_out2_loss: 0.0015 - val_out3_loss: 0.0040 - val_out4_loss: 0.0030 - val_out_item_loss: 1.6839e-04 - val_out1_mae: 0.0401 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0304 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0508 - val_out3_acc: 0.0109 - val_out4_mae: 0.0410 - val_out4_acc: 0.0068 - val_out_item_mae: 1.6825e-04 - val_out_item_acc: 1.0000
Epoch 8/20
184/184 [==============================] - 413s 2s/step - loss: 0.0235 - out1_loss: 0.0039 - out2_loss: 0.0023 - out3_loss: 0.0047 - out4_loss: 0.0040 - out_item_loss: 0.0085 - out1_mae: 0.0491 - out1_acc: 0.0000e+00 - out2_mae: 0.0375 - out2_acc: 0.0000e+00 - out3_mae: 0.0534 - out3_acc: 0.0085 - out4_mae: 0.0499 - out4_acc: 0.0037 - out_item_mae: 0.0028 - out_item_acc: 0.9983 - val_loss: 0.0539 - val_out1_loss: 0.0033 - val_out2_loss: 0.0041 - val_out3_loss: 0.0152 - val_out4_loss: 0.0109 - val_out_item_loss: 0.0205 - val_out1_mae: 0.0431 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0503 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.1060 - val_out3_acc: 0.0082 - val_out4_mae: 0.0869 - val_out4_acc: 0.0041 - val_out_item_mae: 0.0092 - val_out_item_acc: 0.9932
Epoch 9/20
184/184 [==============================] - 407s 2s/step - loss: 0.0282 - out1_loss: 0.0031 - out2_loss: 0.0023 - out3_loss: 0.0046 - out4_loss: 0.0040 - out_item_loss: 0.0142 - out1_mae: 0.0442 - out1_acc: 0.0000e+00 - out2_mae: 0.0371 - out2_acc: 0.0000e+00 - out3_mae: 0.0538 - out3_acc: 0.0078 - out4_mae: 0.0495 - out4_acc: 0.0034 - out_item_mae: 0.0063 - out_item_acc: 0.9959 - val_loss: 0.0248 - val_out1_loss: 0.0026 - val_out2_loss: 0.0014 - val_out3_loss: 0.0139 - val_out4_loss: 0.0038 - val_out_item_loss: 0.0031 - val_out1_mae: 0.0382 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0284 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.1003 - val_out3_acc: 0.0068 - val_out4_mae: 0.0459 - val_out4_acc: 0.0054 - val_out_item_mae: 0.0016 - val_out_item_acc: 0.9986
Epoch 10/20
184/184 [==============================] - 402s 2s/step - loss: 0.0134 - out1_loss: 0.0037 - out2_loss: 0.0019 - out3_loss: 0.0039 - out4_loss: 0.0030 - out_item_loss: 8.8005e-04 - out1_mae: 0.0477 - out1_acc: 0.0000e+00 - out2_mae: 0.0338 - out2_acc: 0.0000e+00 - out3_mae: 0.0493 - out3_acc: 0.0082 - out4_mae: 0.0429 - out4_acc: 0.0048 - out_item_mae: 8.1029e-04 - out_item_acc: 1.0000 - val_loss: 0.0118 - val_out1_loss: 0.0034 - val_out2_loss: 0.0011 - val_out3_loss: 0.0049 - val_out4_loss: 0.0024 - val_out_item_loss: 1.1599e-04 - val_out1_mae: 0.0468 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0250 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0528 - val_out3_acc: 0.0054 - val_out4_mae: 0.0368 - val_out4_acc: 0.0014 - val_out_item_mae: 1.1586e-04 - val_out_item_acc: 1.0000
Epoch 11/20
184/184 [==============================] - 397s 2s/step - loss: 0.0177 - out1_loss: 0.0029 - out2_loss: 0.0021 - out3_loss: 0.0038 - out4_loss: 0.0037 - out_item_loss: 0.0053 - out1_mae: 0.0421 - out1_acc: 0.0000e+00 - out2_mae: 0.0355 - out2_acc: 0.0000e+00 - out3_mae: 0.0490 - out3_acc: 0.0088 - out4_mae: 0.0476 - out4_acc: 0.0041 - out_item_mae: 0.0033 - out_item_acc: 0.9973 - val_loss: 0.0132 - val_out1_loss: 0.0023 - val_out2_loss: 0.0015 - val_out3_loss: 0.0038 - val_out4_loss: 0.0052 - val_out_item_loss: 4.3660e-04 - val_out1_mae: 0.0383 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0296 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0487 - val_out3_acc: 0.0068 - val_out4_mae: 0.0543 - val_out4_acc: 0.0054 - val_out_item_mae: 4.3213e-04 - val_out_item_acc: 1.0000
Epoch 12/20
184/184 [==============================] - 385s 2s/step - loss: 0.0182 - out1_loss: 0.0031 - out2_loss: 0.0019 - out3_loss: 0.0033 - out4_loss: 0.0031 - out_item_loss: 0.0068 - out1_mae: 0.0430 - out1_acc: 0.0000e+00 - out2_mae: 0.0341 - out2_acc: 0.0000e+00 - out3_mae: 0.0455 - out3_acc: 0.0082 - out4_mae: 0.0435 - out4_acc: 0.0044 - out_item_mae: 0.0033 - out_item_acc: 0.9980 - val_loss: 0.0533 - val_out1_loss: 0.0100 - val_out2_loss: 0.0031 - val_out3_loss: 0.0211 - val_out4_loss: 0.0153 - val_out_item_loss: 0.0038 - val_out1_mae: 0.0685 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0422 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0879 - val_out3_acc: 0.0041 - val_out4_mae: 0.0747 - val_out4_acc: 0.0041 - val_out_item_mae: 0.0026 - val_out_item_acc: 0.9986
Epoch 13/20
184/184 [==============================] - 380s 2s/step - loss: 0.0176 - out1_loss: 0.0032 - out2_loss: 0.0018 - out3_loss: 0.0037 - out4_loss: 0.0029 - out_item_loss: 0.0061 - out1_mae: 0.0443 - out1_acc: 0.0000e+00 - out2_mae: 0.0329 - out2_acc: 0.0000e+00 - out3_mae: 0.0479 - out3_acc: 0.0082 - out4_mae: 0.0418 - out4_acc: 0.0041 - out_item_mae: 0.0023 - out_item_acc: 0.9990 - val_loss: 0.0119 - val_out1_loss: 0.0026 - val_out2_loss: 0.0011 - val_out3_loss: 0.0047 - val_out4_loss: 0.0035 - val_out_item_loss: 1.2042e-04 - val_out1_mae: 0.0390 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0250 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0511 - val_out3_acc: 0.0122 - val_out4_mae: 0.0437 - val_out4_acc: 0.0027 - val_out_item_mae: 1.2033e-04 - val_out_item_acc: 1.0000
Epoch 14/20
184/184 [==============================] - 380s 2s/step - loss: 0.0158 - out1_loss: 0.0024 - out2_loss: 0.0017 - out3_loss: 0.0032 - out4_loss: 0.0028 - out_item_loss: 0.0057 - out1_mae: 0.0380 - out1_acc: 0.0000e+00 - out2_mae: 0.0327 - out2_acc: 0.0000e+00 - out3_mae: 0.0445 - out3_acc: 0.0095 - out4_mae: 0.0412 - out4_acc: 0.0044 - out_item_mae: 0.0025 - out_item_acc: 0.9986 - val_loss: 0.0155 - val_out1_loss: 0.0031 - val_out2_loss: 0.0027 - val_out3_loss: 0.0046 - val_out4_loss: 0.0042 - val_out_item_loss: 9.4620e-04 - val_out1_mae: 0.0416 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0439 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0497 - val_out3_acc: 0.0122 - val_out4_mae: 0.0532 - val_out4_acc: 0.0027 - val_out_item_mae: 8.0296e-04 - val_out_item_acc: 1.0000
Epoch 15/20
184/184 [==============================] - 399s 2s/step - loss: 0.0149 - out1_loss: 0.0023 - out2_loss: 0.0015 - out3_loss: 0.0028 - out4_loss: 0.0026 - out_item_loss: 0.0056 - out1_mae: 0.0374 - out1_acc: 0.0000e+00 - out2_mae: 0.0308 - out2_acc: 0.0000e+00 - out3_mae: 0.0413 - out3_acc: 0.0078 - out4_mae: 0.0403 - out4_acc: 0.0044 - out_item_mae: 0.0021 - out_item_acc: 0.9986 - val_loss: 0.0109 - val_out1_loss: 0.0021 - val_out2_loss: 0.0015 - val_out3_loss: 0.0038 - val_out4_loss: 0.0030 - val_out_item_loss: 4.1310e-04 - val_out1_mae: 0.0343 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0308 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0481 - val_out3_acc: 0.0068 - val_out4_mae: 0.0395 - val_out4_acc: 0.0068 - val_out_item_mae: 4.1129e-04 - val_out_item_acc: 1.0000
Epoch 16/20
184/184 [==============================] - 344s 2s/step - loss: 0.0148 - out1_loss: 0.0021 - out2_loss: 0.0017 - out3_loss: 0.0030 - out4_loss: 0.0026 - out_item_loss: 0.0055 - out1_mae: 0.0363 - out1_acc: 0.0000e+00 - out2_mae: 0.0319 - out2_acc: 0.0000e+00 - out3_mae: 0.0432 - out3_acc: 0.0085 - out4_mae: 0.0396 - out4_acc: 0.0041 - out_item_mae: 0.0027 - out_item_acc: 0.9980 - val_loss: 0.0123 - val_out1_loss: 0.0020 - val_out2_loss: 0.0011 - val_out3_loss: 0.0050 - val_out4_loss: 0.0034 - val_out_item_loss: 6.8025e-04 - val_out1_mae: 0.0327 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0251 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0508 - val_out3_acc: 0.0095 - val_out4_mae: 0.0405 - val_out4_acc: 0.0041 - val_out_item_mae: 5.8624e-04 - val_out_item_acc: 1.0000
Epoch 17/20
184/184 [==============================] - 334s 2s/step - loss: 0.0180 - out1_loss: 0.0026 - out2_loss: 0.0017 - out3_loss: 0.0032 - out4_loss: 0.0028 - out_item_loss: 0.0076 - out1_mae: 0.0400 - out1_acc: 0.0000e+00 - out2_mae: 0.0322 - out2_acc: 0.0000e+00 - out3_mae: 0.0446 - out3_acc: 0.0075 - out4_mae: 0.0415 - out4_acc: 0.0027 - out_item_mae: 0.0036 - out_item_acc: 0.9980 - val_loss: 0.0113 - val_out1_loss: 0.0021 - val_out2_loss: 0.0014 - val_out3_loss: 0.0041 - val_out4_loss: 0.0031 - val_out_item_loss: 6.0786e-04 - val_out1_mae: 0.0326 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0302 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0394 - val_out3_acc: 0.0095 - val_out4_mae: 0.0396 - val_out4_acc: 0.0068 - val_out_item_mae: 5.8699e-04 - val_out_item_acc: 1.0000
Epoch 18/20
184/184 [==============================] - 332s 2s/step - loss: 0.0091 - out1_loss: 0.0023 - out2_loss: 0.0014 - out3_loss: 0.0025 - out4_loss: 0.0024 - out_item_loss: 4.3962e-04 - out1_mae: 0.0375 - out1_acc: 0.0000e+00 - out2_mae: 0.0290 - out2_acc: 0.0000e+00 - out3_mae: 0.0397 - out3_acc: 0.0085 - out4_mae: 0.0382 - out4_acc: 0.0048 - out_item_mae: 4.1533e-04 - out_item_acc: 1.0000 - val_loss: 0.0083 - val_out1_loss: 0.0016 - val_out2_loss: 0.0012 - val_out3_loss: 0.0032 - val_out4_loss: 0.0023 - val_out_item_loss: 4.4895e-05 - val_out1_mae: 0.0282 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0250 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0439 - val_out3_acc: 0.0054 - val_out4_mae: 0.0339 - val_out4_acc: 0.0054 - val_out_item_mae: 4.4879e-05 - val_out_item_acc: 1.0000
Epoch 19/20
184/184 [==============================] - 333s 2s/step - loss: 0.0077 - out1_loss: 0.0022 - out2_loss: 0.0014 - out3_loss: 0.0021 - out4_loss: 0.0020 - out_item_loss: 1.6014e-04 - out1_mae: 0.0359 - out1_acc: 0.0000e+00 - out2_mae: 0.0287 - out2_acc: 0.0000e+00 - out3_mae: 0.0357 - out3_acc: 0.0078 - out4_mae: 0.0346 - out4_acc: 0.0041 - out_item_mae: 1.5708e-04 - out_item_acc: 1.0000 - val_loss: 0.0049 - val_out1_loss: 8.3605e-04 - val_out2_loss: 9.8815e-04 - val_out3_loss: 0.0016 - val_out4_loss: 0.0015 - val_out_item_loss: 2.3111e-05 - val_out1_mae: 0.0220 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0233 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0306 - val_out3_acc: 0.0054 - val_out4_mae: 0.0281 - val_out4_acc: 0.0014 - val_out_item_mae: 2.3108e-05 - val_out_item_acc: 1.0000
Epoch 20/20
184/184 [==============================] - 335s 2s/step - loss: 0.0065 - out1_loss: 0.0015 - out2_loss: 0.0013 - out3_loss: 0.0018 - out4_loss: 0.0018 - out_item_loss: 1.4812e-04 - out1_mae: 0.0305 - out1_acc: 0.0000e+00 - out2_mae: 0.0285 - out2_acc: 0.0000e+00 - out3_mae: 0.0334 - out3_acc: 0.0085 - out4_mae: 0.0333 - out4_acc: 0.0041 - out_item_mae: 1.4408e-04 - out_item_acc: 1.0000 - val_loss: 0.0060 - val_out1_loss: 0.0017 - val_out2_loss: 0.0011 - val_out3_loss: 0.0019 - val_out4_loss: 0.0012 - val_out_item_loss: 2.1330e-05 - val_out1_mae: 0.0325 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0253 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0328 - val_out3_acc: 0.0054 - val_out4_mae: 0.0244 - val_out4_acc: 0.0027 - val_out_item_mae: 2.1332e-05 - val_out_item_acc: 1.0000
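
A side note on the loss: the five output losses above are simply summed with equal weight. If the binary cross-entropy term were to dominate the four MSE terms (or the other way round), the compile call could rebalance them with Keras's loss_weights argument; a minimal sketch, with weight values chosen purely for illustration:

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss={'out1':'mse','out2':'mse','out3':'mse','out4':'mse',
                    'out_item':'binary_crossentropy'},
              # hypothetical weights: emphasize the four box outputs over the class output
              loss_weights={'out1':1.0,'out2':1.0,'out3':1.0,'out4':1.0,'out_item':0.5},
              metrics=["mae","acc"])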

4. Model evaluation

model.save("detect_v1.h5")
new_model = tf.keras.models.load_model("detect_v1.h5")
plt.figure(figsize=(8,24))
for img,_ in test_ds.skip(8).take(1):
    out1,out2,out3,out4,out5 = new_model.predict(img)
    for i in range(6):
        plt.subplot(6,1,i+1)
        plt.imshow(tf.keras.preprocessing.image.array_to_img(img[i]))
        xmin,ymin,xmax,ymax = out1[i]*224,out2[i]*224,out3[i]*224,out4[i]*224
        # the Rectangle anchor is (xmin, ymin), followed by the box width and height
        rect = Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),fill=False,color = "red")
        ax = plt.gca()
        ax.axes.add_patch(rect)
        plt.title((class_label[round(float(out5[i]))]).title())
       #plt.show()

Judging by these images, the overall result looks quite good.
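
To back up the visual impression with a number, one can compute the mean IoU (intersection over union) between predicted and ground-truth boxes on the test set; a minimal sketch, reusing the new_model and test_ds defined above:

def iou(box_a, box_b):
    # boxes are [xmin, ymin, xmax, ymax] in normalized coordinates
    ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter + 1e-7)

ious = []
for imgs, (t1, t2, t3, t4, _) in test_ds:
    p1, p2, p3, p4, _ = new_model.predict(imgs)
    for i in range(len(imgs)):
        pred = [float(p1[i]), float(p2[i]), float(p3[i]), float(p4[i])]
        true = [float(t1[i]), float(t2[i]), float(t3[i]), float(t4[i])]
        ious.append(iou(pred, true))
print("mean IoU on the test set:", np.mean(ious))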
