TensorFlow 2.0 Image Localization + Classification (Oxford-IIIT Pet Dataset)

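Summary: A pure classification problem is easy to understand: given an image, we output a single class label. A localization problem additionally asks for four numbers that describe where the object is, for example (x, y, w, h): the coordinates (x, y) of one corner of the bounding box plus the box's width and height. With these four numbers the bounding box is fully determined. The code below uses the equivalent corner form (xmin, ymin, xmax, ymax), normalized by the image width and height.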
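Both encodings carry the same information. As a minimal sketch (the helper names are hypothetical and assume (x, y) is the top-left corner in pixel coordinates), converting between corner-plus-size and normalized corners looks like this:

```python
def xywh_to_corners(x, y, w, h):
    """(x, y, w, h) with (x, y) the top-left corner -> (xmin, ymin, xmax, ymax)."""
    return x, y, x + w, y + h


def corners_to_xywh(xmin, ymin, xmax, ymax):
    """(xmin, ymin, xmax, ymax) -> (x, y, w, h) with (x, y) the top-left corner."""
    return xmin, ymin, xmax - xmin, ymax - ymin


def normalize(box, width, height):
    """Divide corner coordinates by the image size so they fall in [0, 1]."""
    xmin, ymin, xmax, ymax = box
    return xmin / width, ymin / height, xmax / width, ymax / height


if __name__ == "__main__":
    corners = xywh_to_corners(50, 40, 100, 80)
    print(corners)                       # (50, 40, 150, 120)
    print(normalize(corners, 500, 400))  # (0.1, 0.1, 0.3, 0.3)
```

Normalizing makes the targets independent of the original image size, which is what allows the model to predict them from a resized 224×224 input.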

1. Import the required packages

import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline
from lxml import etree
import numpy as np
import glob
from matplotlib.patches import Rectangle  # import the Rectangle class, not the whole module
print(tf.__version__)
tf.test.is_gpu_available()
2.0.0
True

2. Data preprocessing

2.1 Build the input pipeline

import os

def stem(path):
    # File name without directory and extension; unlike split("\\") this is not Windows-only
    return os.path.splitext(os.path.basename(path))[0]

images = glob.glob("./Image_location/images/*.jpg")
xmls = glob.glob("./Image_location/annotations/xmls/*.xml")
# File names (without extension) of all annotated images
names = set(stem(x) for x in xmls)
# Only images with an XML annotation can be used for training; the rest are unlabeled
imgs_train = [img for img in images if stem(img) in names]
imgs_test = [img for img in images if stem(img) not in names]
# Sort both lists by file name so that image i matches annotation i
imgs_train.sort(key=stem)
xmls.sort(key=stem)
# Determine the class from the sorted file names
names = [stem(x) for x in xmls]
class_label = ["cat","dog"]
class_label_index = dict((name,index) for index,name in enumerate(class_label))
# Oxford-IIIT Pet convention: cat breeds are capitalized (e.g. "Abyssinian_1"), dog breeds are lowercase
label = [class_label_index["cat"] if name[0].isupper() else class_label_index["dog"] for name in names]
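Because images and annotations are matched purely by position after sorting, a cheap sanity check before building the dataset can catch a silent misalignment. The sketch below is illustrative only: `check_alignment` and `pet_label` are hypothetical helpers, and the sample paths are made up.

```python
import os


def stem(path):
    """File name without directory or extension."""
    return os.path.splitext(os.path.basename(path))[0]


def check_alignment(image_paths, xml_paths):
    """Raise ValueError if the i-th image and the i-th annotation do not share a file stem."""
    if len(image_paths) != len(xml_paths):
        raise ValueError(f"{len(image_paths)} images vs {len(xml_paths)} annotations")
    for img, xml in zip(image_paths, xml_paths):
        if stem(img) != stem(xml):
            raise ValueError(f"mismatch: {img} vs {xml}")


def pet_label(name):
    """Capitalized names are cats (0), lowercase names are dogs (1)."""
    return 0 if name[0].isupper() else 1


if __name__ == "__main__":
    imgs = ["images/Abyssinian_1.jpg", "images/beagle_3.jpg"]
    xmls = ["xmls/Abyssinian_1.xml", "xmls/beagle_3.xml"]
    check_alignment(imgs, xmls)
    print([pet_label(stem(p)) for p in xmls])  # [0, 1]
```

In the tutorial this would be called as `check_alignment(imgs_train, xmls)` right after sorting.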

Read the annotation data and wrap it in a dataset:

def to_labels(path):
    # Parse the Pascal-VOC-style XML annotation file (etree.parse closes the file for us)
    sel = etree.parse(path)
    width = int(sel.xpath("//size/width/text()")[0])
    height = int(sel.xpath("//size/height/text()")[0])
    # Only the first bounding box in the file is used
    xmin = int(sel.xpath("//bndbox/xmin/text()")[0])
    xmax = int(sel.xpath("//bndbox/xmax/text()")[0])
    ymin = int(sel.xpath("//bndbox/ymin/text()")[0])
    ymax = int(sel.xpath("//bndbox/ymax/text()")[0])
    # Normalize the corners to [0, 1] by the image size
    return [xmin/width,ymin/height,xmax/width,ymax/height]
# Read the box coordinates from every annotation file
labels = [to_labels(path) for path in xmls]
# Split into four coordinate arrays, one per model output
out1,out2,out3,out4 = list(zip(*labels))
out1 = np.array(out1)
out2 = np.array(out2)
out3 = np.array(out3)
out4 = np.array(out4)
label = np.array(label)
label_datasets = tf.data.Dataset.from_tensor_slices((out1,out2,out3,out4,label))
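The annotation files use the Pascal VOC layout (`<size>` plus `<object><bndbox>`). As a dependency-free sketch, the same normalization can be done with the standard library's `xml.etree.ElementTree` instead of lxml. The sample XML string is made up for illustration, and like `to_labels` the sketch reads only the first box.

```python
import xml.etree.ElementTree as ET

SAMPLE_XML = """<annotation>
  <size><width>500</width><height>400</height><depth>3</depth></size>
  <object>
    <name>cat</name>
    <bndbox><xmin>50</xmin><ymin>40</ymin><xmax>150</xmax><ymax>120</ymax></bndbox>
  </object>
</annotation>"""


def parse_box(xml_text):
    """Return the first bounding box as [xmin, ymin, xmax, ymax] normalized to [0, 1]."""
    root = ET.fromstring(xml_text)
    width = int(root.findtext("size/width"))
    height = int(root.findtext("size/height"))
    box = root.find(".//bndbox")  # first box only, as in to_labels
    xmin, ymin, xmax, ymax = (int(box.findtext(tag)) for tag in ("xmin", "ymin", "xmax", "ymax"))
    return [xmin / width, ymin / height, xmax / width, ymax / height]


print(parse_box(SAMPLE_XML))  # [0.1, 0.1, 0.3, 0.3]
```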

Read the images from their file paths:

def load_image(path):
    img = tf.io.read_file(path)
    img = tf.image.decode_jpeg(img,channels=3)
    img = tf.image.resize(img,(224,224))
    img = img/127.5 - 1  # normalize pixel values to [-1, 1]
    return img
image_dataset = tf.data.Dataset.from_tensor_slices(imgs_train)
AUTOTUNE = tf.data.experimental.AUTOTUNE
image_dataset = image_dataset.map(load_image,num_parallel_calls=AUTOTUNE)
# Pair each image with its labels; both lists are in the same sorted order
dataset = tf.data.Dataset.zip((image_dataset,label_datasets))
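`img/127.5 - 1` maps pixel values from [0, 255] to [-1, 1], which is also the range Keras' own Xception preprocessing produces. For display, the inverse mapping is `(img + 1) * 127.5`. A small NumPy check of both directions (the function names are illustrative, not from the original code):

```python
import numpy as np


def to_model_range(img):
    """Scale pixel values in [0, 255] to [-1, 1]."""
    return np.asarray(img, dtype=np.float32) / 127.5 - 1.0


def to_display_range(img):
    """Inverse of to_model_range: map [-1, 1] back to uint8 pixels in [0, 255]."""
    return np.clip(np.round((np.asarray(img) + 1.0) * 127.5), 0, 255).astype(np.uint8)


pixels = np.array([0, 127.5, 255])
scaled = to_model_range(pixels)
print(scaled.tolist())                    # [-1.0, 0.0, 1.0]
print(to_display_range(scaled).tolist())  # [0, 128, 255]
```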

2.2 Split into training and test sets

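# Sizes of the training and validation (test) sets
test_count = int(len(imgs_train)*0.2)
train_count = len(imgs_train) - test_count
print(test_count,train_count)
# Shuffle once in a fixed order; with the default reshuffle_each_iteration=True,
# skip/take would select a different (overlapping) split every epoch
dataset = dataset.shuffle(buffer_size=len(imgs_train), seed=42, reshuffle_each_iteration=False)
# Skip the first test_count examples for training; take them as the test set
train_dataset = dataset.skip(test_count)
test_dataset = dataset.take(test_count)
batch_size = 16
# A shuffle buffer as large as the training set makes sure the data is fully shuffled
train_ds = train_dataset.shuffle(buffer_size=train_count).repeat().batch(batch_size)
train_ds = train_ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
test_ds = test_dataset.batch(batch_size)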
737 2949
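The split only stays fixed if the examples are permuted once and then cut. The framework-agnostic sketch below (a hypothetical NumPy helper, not the tf.data code above) shows that logic: draw one permutation with a fixed seed, slice it into disjoint test and train index sets, and shuffle only inside the training part afterwards.

```python
import numpy as np


def split_indices(n, test_fraction=0.2, seed=0):
    """Return disjoint (train_idx, test_idx) taken from a single fixed permutation of range(n)."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)
    test_count = int(n * test_fraction)
    return perm[test_count:], perm[:test_count]


train_idx, test_idx = split_indices(3686)
print(len(train_idx), len(test_idx))  # 2949 737
```

Calling the function again with the same seed returns the same split, which is the property `reshuffle_each_iteration=False` is meant to provide for the tf.data pipeline.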

Let's plot one training image with its bounding box to check the result:

from matplotlib.patches import Rectangle
for img,lbl in train_ds.take(1):
    # array_to_img min-max rescales the [-1, 1] values into a displayable image
    plt.imshow(tf.keras.preprocessing.image.array_to_img(img[0]))
    out1,out2,out3,out4,out5 = lbl
    # The coordinates are normalized, so scale them back to the 224x224 input size
    xmin,ymin,xmax,ymax = out1[0].numpy()*224,out2[0].numpy()*224,out3[0].numpy()*224,out4[0].numpy()*224
    # imshow puts the y axis pointing down, so (xmin, ymin) is the top-left corner
    rect = Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),fill=False,color="red")
    ax = plt.gca()
    ax.add_patch(rect)
    plt.title(class_label[int(out5[0])].title())
    plt.show()
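Matplotlib's `Rectangle` takes an anchor point plus width and height rather than two corners. A small helper (hypothetical, not part of the original code) that turns a normalized box into those arguments for a 224×224 image:

```python
def box_to_rect_args(box, size=224):
    """Convert a normalized (xmin, ymin, xmax, ymax) box into Rectangle's ((x, y), width, height)."""
    xmin, ymin, xmax, ymax = (float(v) * size for v in box)
    return (xmin, ymin), xmax - xmin, ymax - ymin


print(box_to_rect_args((0.25, 0.5, 0.75, 1.0)))  # ((56.0, 112.0), 112.0, 112.0)
```

Inside the loop above, `Rectangle(*box_to_rect_args((out1[0], out2[0], out3[0], out4[0])), fill=False, color="red")` should draw the same box, since `float()` accepts scalar tensors.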

3. Build the model

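We build the network on an Xception model pretrained on ImageNet and reuse the multi-output pattern from the previous section: one branch predicts the class and the other the four box coordinates.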

# Xception backbone pretrained on ImageNet, without its classification top
xception = tf.keras.applications.Xception(input_shape=(224, 224, 3),
                                          include_top=False,
                                          weights='imagenet')
inputs = tf.keras.layers.Input(shape=(224,224,3))
x = xception(inputs)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
# Regression branch: one linear output per box coordinate
x1 = tf.keras.layers.Dense(2048, activation='relu')(x)
x1 = tf.keras.layers.Dense(256, activation='relu')(x1)
out1 = tf.keras.layers.Dense(1,name="out1")(x1)
out2 = tf.keras.layers.Dense(1,name="out2")(x1)
out3 = tf.keras.layers.Dense(1,name="out3")(x1)
out4 = tf.keras.layers.Dense(1,name="out4")(x1)
# Classification branch: sigmoid probability of "dog" (class index 1)
x2 = tf.keras.layers.Dense(1024, activation='relu')(x)
x2 = tf.keras.layers.Dense(256, activation='relu')(x2)
out_class = tf.keras.layers.Dense(1,activation='sigmoid',name='out_item')(x2)
prediction = [out1,out2,out3,out4,out_class]
model = tf.keras.models.Model(inputs=inputs,outputs=prediction)
model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
xception (Model)                (None, 7, 7, 2048)   20861480    input_2[0][0]                    
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 2048)         0           xception[1][0]                   
__________________________________________________________________________________________________
dense (Dense)                   (None, 2048)         4196352     global_average_pooling2d[0][0]   
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 1024)         2098176     global_average_pooling2d[0][0]   
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 256)          524544      dense[0][0]                      
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 256)          262400      dense_2[0][0]                    
__________________________________________________________________________________________________
out1 (Dense)                    (None, 1)            257         dense_1[0][0]                    
__________________________________________________________________________________________________
out2 (Dense)                    (None, 1)            257         dense_1[0][0]                    
__________________________________________________________________________________________________
out3 (Dense)                    (None, 1)            257         dense_1[0][0]                    
__________________________________________________________________________________________________
out4 (Dense)                    (None, 1)            257         dense_1[0][0]                    
__________________________________________________________________________________________________
out_item (Dense)                (None, 1)            257         dense_3[0][0]                    
==================================================================================================
Total params: 27,944,237
Trainable params: 27,889,709
Non-trainable params: 54,528
__________________________________________________________________________________________________
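The parameter counts in the summary can be checked by hand: a `Dense` layer with `n_in` inputs and `n_out` units has `n_in * n_out + n_out` weights (kernel plus bias). The sketch below recomputes both heads and the total. The Xception backbone count is copied from the summary, not derived.

```python
def dense_params(n_in, n_out):
    """Kernel (n_in x n_out) plus bias (n_out)."""
    return n_in * n_out + n_out


backbone = 20_861_480  # Xception without top, taken from model.summary()
box_head = dense_params(2048, 2048) + dense_params(2048, 256) + 4 * dense_params(256, 1)
class_head = dense_params(2048, 1024) + dense_params(1024, 256) + dense_params(256, 1)
total = backbone + box_head + class_head
print(total)  # 27944237
```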
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss={'out1':'mse',
                    'out2':'mse',
                    'out3':'mse',
                    'out4':'mse',
                    'out_item':'binary_crossentropy'},
              # "acc" is only meaningful for out_item; for the continuous box outputs it stays near 0
              metrics=["mae","acc"])
steps_per_epoch = train_count//batch_size
validation_steps = test_count//batch_size
history = model.fit(train_ds,epochs=20,steps_per_epoch=steps_per_epoch,validation_data=test_ds,validation_steps=validation_steps)
Train for 184 steps, validate for 46 steps
Epoch 1/20
184/184 [==============================] - 404s 2s/step - loss: 0.2214 - out1_loss: 0.0246 - out2_loss: 0.0166 - out3_loss: 0.0312 - out4_loss: 0.0228 - out_item_loss: 0.1261 - out1_mae: 0.1187 - out1_acc: 0.0000e+00 - out2_mae: 0.0961 - out2_acc: 0.0000e+00 - out3_mae: 0.1327 - out3_acc: 0.0095 - out4_mae: 0.1156 - out4_acc: 0.0048 - out_item_mae: 0.0854 - out_item_acc: 0.9467 - val_loss: 0.1970 - val_out1_loss: 0.0185 - val_out2_loss: 0.0214 - val_out3_loss: 0.0752 - val_out4_loss: 0.0735 - val_out_item_loss: 0.0083 - val_out1_mae: 0.1078 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.1255 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.2500 - val_out3_acc: 0.0054 - val_out4_mae: 0.2464 - val_out4_acc: 0.0014 - val_out_item_mae: 0.0044 - val_out_item_acc: 0.9973
Epoch 2/20
184/184 [==============================] - 400s 2s/step - loss: 0.0751 - out1_loss: 0.0111 - out2_loss: 0.0071 - out3_loss: 0.0132 - out4_loss: 0.0125 - out_item_loss: 0.0313 - out1_mae: 0.0831 - out1_acc: 0.0000e+00 - out2_mae: 0.0657 - out2_acc: 0.0000e+00 - out3_mae: 0.0899 - out3_acc: 0.0085 - out4_mae: 0.0878 - out4_acc: 0.0034 - out_item_mae: 0.0167 - out_item_acc: 0.9901 - val_loss: 0.0923 - val_out1_loss: 0.0098 - val_out2_loss: 0.0047 - val_out3_loss: 0.0400 - val_out4_loss: 0.0230 - val_out_item_loss: 0.0149 - val_out1_mae: 0.0802 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0540 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.1778 - val_out3_acc: 0.0122 - val_out4_mae: 0.1267 - val_out4_acc: 0.0027 - val_out_item_mae: 0.0051 - val_out_item_acc: 0.9959
Epoch 3/20
184/184 [==============================] - 408s 2s/step - loss: 0.0460 - out1_loss: 0.0080 - out2_loss: 0.0050 - out3_loss: 0.0098 - out4_loss: 0.0082 - out_item_loss: 0.0151 - out1_mae: 0.0698 - out1_acc: 0.0000e+00 - out2_mae: 0.0555 - out2_acc: 0.0000e+00 - out3_mae: 0.0776 - out3_acc: 0.0078 - out4_mae: 0.0711 - out4_acc: 0.0031 - out_item_mae: 0.0074 - out_item_acc: 0.9959 - val_loss: 0.0274 - val_out1_loss: 0.0064 - val_out2_loss: 0.0037 - val_out3_loss: 0.0077 - val_out4_loss: 0.0075 - val_out_item_loss: 0.0021 - val_out1_mae: 0.0622 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0479 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0677 - val_out3_acc: 0.0054 - val_out4_mae: 0.0662 - val_out4_acc: 0.0027 - val_out_item_mae: 0.0021 - val_out_item_acc: 1.0000
Epoch 4/20
184/184 [==============================] - 379s 2s/step - loss: 0.0440 - out1_loss: 0.0064 - out2_loss: 0.0043 - out3_loss: 0.0087 - out4_loss: 0.0077 - out_item_loss: 0.0169 - out1_mae: 0.0618 - out1_acc: 0.0000e+00 - out2_mae: 0.0513 - out2_acc: 0.0000e+00 - out3_mae: 0.0720 - out3_acc: 0.0078 - out4_mae: 0.0686 - out4_acc: 0.0037 - out_item_mae: 0.0096 - out_item_acc: 0.9929 - val_loss: 0.0239 - val_out1_loss: 0.0033 - val_out2_loss: 0.0023 - val_out3_loss: 0.0060 - val_out4_loss: 0.0074 - val_out_item_loss: 0.0049 - val_out1_mae: 0.0429 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0380 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0586 - val_out3_acc: 0.0122 - val_out4_mae: 0.0645 - val_out4_acc: 0.0041 - val_out_item_mae: 0.0028 - val_out_item_acc: 0.9986
Epoch 5/20
184/184 [==============================] - 421s 2s/step - loss: 0.0355 - out1_loss: 0.0051 - out2_loss: 0.0033 - out3_loss: 0.0064 - out4_loss: 0.0057 - out_item_loss: 0.0149 - out1_mae: 0.0561 - out1_acc: 0.0000e+00 - out2_mae: 0.0454 - out2_acc: 0.0000e+00 - out3_mae: 0.0623 - out3_acc: 0.0078 - out4_mae: 0.0595 - out4_acc: 0.0048 - out_item_mae: 0.0069 - out_item_acc: 0.9966 - val_loss: 0.0493 - val_out1_loss: 0.0034 - val_out2_loss: 0.0029 - val_out3_loss: 0.0247 - val_out4_loss: 0.0092 - val_out_item_loss: 0.0091 - val_out1_mae: 0.0451 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0425 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.1334 - val_out3_acc: 0.0082 - val_out4_mae: 0.0734 - val_out4_acc: 0.0095 - val_out_item_mae: 0.0048 - val_out_item_acc: 0.9959
Epoch 6/20
184/184 [==============================] - 424s 2s/step - loss: 0.0281 - out1_loss: 0.0044 - out2_loss: 0.0031 - out3_loss: 0.0061 - out4_loss: 0.0055 - out_item_loss: 0.0091 - out1_mae: 0.0515 - out1_acc: 0.0000e+00 - out2_mae: 0.0441 - out2_acc: 0.0000e+00 - out3_mae: 0.0613 - out3_acc: 0.0082 - out4_mae: 0.0576 - out4_acc: 0.0048 - out_item_mae: 0.0047 - out_item_acc: 0.9973 - val_loss: 0.0124 - val_out1_loss: 0.0026 - val_out2_loss: 0.0017 - val_out3_loss: 0.0035 - val_out4_loss: 0.0033 - val_out_item_loss: 0.0013 - val_out1_mae: 0.0382 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0318 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0429 - val_out3_acc: 0.0068 - val_out4_mae: 0.0432 - val_out4_acc: 0.0068 - val_out_item_mae: 0.0012 - val_out_item_acc: 1.0000
Epoch 7/20
184/184 [==============================] - 395s 2s/step - loss: 0.0174 - out1_loss: 0.0033 - out2_loss: 0.0027 - out3_loss: 0.0046 - out4_loss: 0.0044 - out_item_loss: 0.0025 - out1_mae: 0.0450 - out1_acc: 0.0000e+00 - out2_mae: 0.0409 - out2_acc: 0.0000e+00 - out3_mae: 0.0536 - out3_acc: 0.0075 - out4_mae: 0.0516 - out4_acc: 0.0037 - out_item_mae: 0.0018 - out_item_acc: 0.9993 - val_loss: 0.0113 - val_out1_loss: 0.0026 - val_out2_loss: 0.0015 - val_out3_loss: 0.0040 - val_out4_loss: 0.0030 - val_out_item_loss: 1.6839e-04 - val_out1_mae: 0.0401 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0304 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0508 - val_out3_acc: 0.0109 - val_out4_mae: 0.0410 - val_out4_acc: 0.0068 - val_out_item_mae: 1.6825e-04 - val_out_item_acc: 1.0000
Epoch 8/20
184/184 [==============================] - 413s 2s/step - loss: 0.0235 - out1_loss: 0.0039 - out2_loss: 0.0023 - out3_loss: 0.0047 - out4_loss: 0.0040 - out_item_loss: 0.0085 - out1_mae: 0.0491 - out1_acc: 0.0000e+00 - out2_mae: 0.0375 - out2_acc: 0.0000e+00 - out3_mae: 0.0534 - out3_acc: 0.0085 - out4_mae: 0.0499 - out4_acc: 0.0037 - out_item_mae: 0.0028 - out_item_acc: 0.9983 - val_loss: 0.0539 - val_out1_loss: 0.0033 - val_out2_loss: 0.0041 - val_out3_loss: 0.0152 - val_out4_loss: 0.0109 - val_out_item_loss: 0.0205 - val_out1_mae: 0.0431 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0503 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.1060 - val_out3_acc: 0.0082 - val_out4_mae: 0.0869 - val_out4_acc: 0.0041 - val_out_item_mae: 0.0092 - val_out_item_acc: 0.9932
Epoch 9/20
184/184 [==============================] - 407s 2s/step - loss: 0.0282 - out1_loss: 0.0031 - out2_loss: 0.0023 - out3_loss: 0.0046 - out4_loss: 0.0040 - out_item_loss: 0.0142 - out1_mae: 0.0442 - out1_acc: 0.0000e+00 - out2_mae: 0.0371 - out2_acc: 0.0000e+00 - out3_mae: 0.0538 - out3_acc: 0.0078 - out4_mae: 0.0495 - out4_acc: 0.0034 - out_item_mae: 0.0063 - out_item_acc: 0.9959 - val_loss: 0.0248 - val_out1_loss: 0.0026 - val_out2_loss: 0.0014 - val_out3_loss: 0.0139 - val_out4_loss: 0.0038 - val_out_item_loss: 0.0031 - val_out1_mae: 0.0382 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0284 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.1003 - val_out3_acc: 0.0068 - val_out4_mae: 0.0459 - val_out4_acc: 0.0054 - val_out_item_mae: 0.0016 - val_out_item_acc: 0.9986
Epoch 10/20
184/184 [==============================] - 402s 2s/step - loss: 0.0134 - out1_loss: 0.0037 - out2_loss: 0.0019 - out3_loss: 0.0039 - out4_loss: 0.0030 - out_item_loss: 8.8005e-04 - out1_mae: 0.0477 - out1_acc: 0.0000e+00 - out2_mae: 0.0338 - out2_acc: 0.0000e+00 - out3_mae: 0.0493 - out3_acc: 0.0082 - out4_mae: 0.0429 - out4_acc: 0.0048 - out_item_mae: 8.1029e-04 - out_item_acc: 1.0000 - val_loss: 0.0118 - val_out1_loss: 0.0034 - val_out2_loss: 0.0011 - val_out3_loss: 0.0049 - val_out4_loss: 0.0024 - val_out_item_loss: 1.1599e-04 - val_out1_mae: 0.0468 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0250 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0528 - val_out3_acc: 0.0054 - val_out4_mae: 0.0368 - val_out4_acc: 0.0014 - val_out_item_mae: 1.1586e-04 - val_out_item_acc: 1.0000
Epoch 11/20
184/184 [==============================] - 397s 2s/step - loss: 0.0177 - out1_loss: 0.0029 - out2_loss: 0.0021 - out3_loss: 0.0038 - out4_loss: 0.0037 - out_item_loss: 0.0053 - out1_mae: 0.0421 - out1_acc: 0.0000e+00 - out2_mae: 0.0355 - out2_acc: 0.0000e+00 - out3_mae: 0.0490 - out3_acc: 0.0088 - out4_mae: 0.0476 - out4_acc: 0.0041 - out_item_mae: 0.0033 - out_item_acc: 0.9973 - val_loss: 0.0132 - val_out1_loss: 0.0023 - val_out2_loss: 0.0015 - val_out3_loss: 0.0038 - val_out4_loss: 0.0052 - val_out_item_loss: 4.3660e-04 - val_out1_mae: 0.0383 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0296 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0487 - val_out3_acc: 0.0068 - val_out4_mae: 0.0543 - val_out4_acc: 0.0054 - val_out_item_mae: 4.3213e-04 - val_out_item_acc: 1.0000
Epoch 12/20
184/184 [==============================] - 385s 2s/step - loss: 0.0182 - out1_loss: 0.0031 - out2_loss: 0.0019 - out3_loss: 0.0033 - out4_loss: 0.0031 - out_item_loss: 0.0068 - out1_mae: 0.0430 - out1_acc: 0.0000e+00 - out2_mae: 0.0341 - out2_acc: 0.0000e+00 - out3_mae: 0.0455 - out3_acc: 0.0082 - out4_mae: 0.0435 - out4_acc: 0.0044 - out_item_mae: 0.0033 - out_item_acc: 0.9980 - val_loss: 0.0533 - val_out1_loss: 0.0100 - val_out2_loss: 0.0031 - val_out3_loss: 0.0211 - val_out4_loss: 0.0153 - val_out_item_loss: 0.0038 - val_out1_mae: 0.0685 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0422 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0879 - val_out3_acc: 0.0041 - val_out4_mae: 0.0747 - val_out4_acc: 0.0041 - val_out_item_mae: 0.0026 - val_out_item_acc: 0.9986
Epoch 13/20
184/184 [==============================] - 380s 2s/step - loss: 0.0176 - out1_loss: 0.0032 - out2_loss: 0.0018 - out3_loss: 0.0037 - out4_loss: 0.0029 - out_item_loss: 0.0061 - out1_mae: 0.0443 - out1_acc: 0.0000e+00 - out2_mae: 0.0329 - out2_acc: 0.0000e+00 - out3_mae: 0.0479 - out3_acc: 0.0082 - out4_mae: 0.0418 - out4_acc: 0.0041 - out_item_mae: 0.0023 - out_item_acc: 0.9990 - val_loss: 0.0119 - val_out1_loss: 0.0026 - val_out2_loss: 0.0011 - val_out3_loss: 0.0047 - val_out4_loss: 0.0035 - val_out_item_loss: 1.2042e-04 - val_out1_mae: 0.0390 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0250 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0511 - val_out3_acc: 0.0122 - val_out4_mae: 0.0437 - val_out4_acc: 0.0027 - val_out_item_mae: 1.2033e-04 - val_out_item_acc: 1.0000
Epoch 14/20
184/184 [==============================] - 380s 2s/step - loss: 0.0158 - out1_loss: 0.0024 - out2_loss: 0.0017 - out3_loss: 0.0032 - out4_loss: 0.0028 - out_item_loss: 0.0057 - out1_mae: 0.0380 - out1_acc: 0.0000e+00 - out2_mae: 0.0327 - out2_acc: 0.0000e+00 - out3_mae: 0.0445 - out3_acc: 0.0095 - out4_mae: 0.0412 - out4_acc: 0.0044 - out_item_mae: 0.0025 - out_item_acc: 0.9986 - val_loss: 0.0155 - val_out1_loss: 0.0031 - val_out2_loss: 0.0027 - val_out3_loss: 0.0046 - val_out4_loss: 0.0042 - val_out_item_loss: 9.4620e-04 - val_out1_mae: 0.0416 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0439 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0497 - val_out3_acc: 0.0122 - val_out4_mae: 0.0532 - val_out4_acc: 0.0027 - val_out_item_mae: 8.0296e-04 - val_out_item_acc: 1.0000
Epoch 15/20
184/184 [==============================] - 399s 2s/step - loss: 0.0149 - out1_loss: 0.0023 - out2_loss: 0.0015 - out3_loss: 0.0028 - out4_loss: 0.0026 - out_item_loss: 0.0056 - out1_mae: 0.0374 - out1_acc: 0.0000e+00 - out2_mae: 0.0308 - out2_acc: 0.0000e+00 - out3_mae: 0.0413 - out3_acc: 0.0078 - out4_mae: 0.0403 - out4_acc: 0.0044 - out_item_mae: 0.0021 - out_item_acc: 0.9986 - val_loss: 0.0109 - val_out1_loss: 0.0021 - val_out2_loss: 0.0015 - val_out3_loss: 0.0038 - val_out4_loss: 0.0030 - val_out_item_loss: 4.1310e-04 - val_out1_mae: 0.0343 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0308 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0481 - val_out3_acc: 0.0068 - val_out4_mae: 0.0395 - val_out4_acc: 0.0068 - val_out_item_mae: 4.1129e-04 - val_out_item_acc: 1.0000
Epoch 16/20
184/184 [==============================] - 344s 2s/step - loss: 0.0148 - out1_loss: 0.0021 - out2_loss: 0.0017 - out3_loss: 0.0030 - out4_loss: 0.0026 - out_item_loss: 0.0055 - out1_mae: 0.0363 - out1_acc: 0.0000e+00 - out2_mae: 0.0319 - out2_acc: 0.0000e+00 - out3_mae: 0.0432 - out3_acc: 0.0085 - out4_mae: 0.0396 - out4_acc: 0.0041 - out_item_mae: 0.0027 - out_item_acc: 0.9980 - val_loss: 0.0123 - val_out1_loss: 0.0020 - val_out2_loss: 0.0011 - val_out3_loss: 0.0050 - val_out4_loss: 0.0034 - val_out_item_loss: 6.8025e-04 - val_out1_mae: 0.0327 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0251 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0508 - val_out3_acc: 0.0095 - val_out4_mae: 0.0405 - val_out4_acc: 0.0041 - val_out_item_mae: 5.8624e-04 - val_out_item_acc: 1.0000
Epoch 17/20
184/184 [==============================] - 334s 2s/step - loss: 0.0180 - out1_loss: 0.0026 - out2_loss: 0.0017 - out3_loss: 0.0032 - out4_loss: 0.0028 - out_item_loss: 0.0076 - out1_mae: 0.0400 - out1_acc: 0.0000e+00 - out2_mae: 0.0322 - out2_acc: 0.0000e+00 - out3_mae: 0.0446 - out3_acc: 0.0075 - out4_mae: 0.0415 - out4_acc: 0.0027 - out_item_mae: 0.0036 - out_item_acc: 0.9980 - val_loss: 0.0113 - val_out1_loss: 0.0021 - val_out2_loss: 0.0014 - val_out3_loss: 0.0041 - val_out4_loss: 0.0031 - val_out_item_loss: 6.0786e-04 - val_out1_mae: 0.0326 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0302 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0394 - val_out3_acc: 0.0095 - val_out4_mae: 0.0396 - val_out4_acc: 0.0068 - val_out_item_mae: 5.8699e-04 - val_out_item_acc: 1.0000
Epoch 18/20
184/184 [==============================] - 332s 2s/step - loss: 0.0091 - out1_loss: 0.0023 - out2_loss: 0.0014 - out3_loss: 0.0025 - out4_loss: 0.0024 - out_item_loss: 4.3962e-04 - out1_mae: 0.0375 - out1_acc: 0.0000e+00 - out2_mae: 0.0290 - out2_acc: 0.0000e+00 - out3_mae: 0.0397 - out3_acc: 0.0085 - out4_mae: 0.0382 - out4_acc: 0.0048 - out_item_mae: 4.1533e-04 - out_item_acc: 1.0000 - val_loss: 0.0083 - val_out1_loss: 0.0016 - val_out2_loss: 0.0012 - val_out3_loss: 0.0032 - val_out4_loss: 0.0023 - val_out_item_loss: 4.4895e-05 - val_out1_mae: 0.0282 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0250 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0439 - val_out3_acc: 0.0054 - val_out4_mae: 0.0339 - val_out4_acc: 0.0054 - val_out_item_mae: 4.4879e-05 - val_out_item_acc: 1.0000
Epoch 19/20
184/184 [==============================] - 333s 2s/step - loss: 0.0077 - out1_loss: 0.0022 - out2_loss: 0.0014 - out3_loss: 0.0021 - out4_loss: 0.0020 - out_item_loss: 1.6014e-04 - out1_mae: 0.0359 - out1_acc: 0.0000e+00 - out2_mae: 0.0287 - out2_acc: 0.0000e+00 - out3_mae: 0.0357 - out3_acc: 0.0078 - out4_mae: 0.0346 - out4_acc: 0.0041 - out_item_mae: 1.5708e-04 - out_item_acc: 1.0000 - val_loss: 0.0049 - val_out1_loss: 8.3605e-04 - val_out2_loss: 9.8815e-04 - val_out3_loss: 0.0016 - val_out4_loss: 0.0015 - val_out_item_loss: 2.3111e-05 - val_out1_mae: 0.0220 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0233 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0306 - val_out3_acc: 0.0054 - val_out4_mae: 0.0281 - val_out4_acc: 0.0014 - val_out_item_mae: 2.3108e-05 - val_out_item_acc: 1.0000
Epoch 20/20
184/184 [==============================] - 335s 2s/step - loss: 0.0065 - out1_loss: 0.0015 - out2_loss: 0.0013 - out3_loss: 0.0018 - out4_loss: 0.0018 - out_item_loss: 1.4812e-04 - out1_mae: 0.0305 - out1_acc: 0.0000e+00 - out2_mae: 0.0285 - out2_acc: 0.0000e+00 - out3_mae: 0.0334 - out3_acc: 0.0085 - out4_mae: 0.0333 - out4_acc: 0.0041 - out_item_mae: 1.4408e-04 - out_item_acc: 1.0000 - val_loss: 0.0060 - val_out1_loss: 0.0017 - val_out2_loss: 0.0011 - val_out3_loss: 0.0019 - val_out4_loss: 0.0012 - val_out_item_loss: 2.1330e-05 - val_out1_mae: 0.0325 - val_out1_acc: 0.0000e+00 - val_out2_mae: 0.0253 - val_out2_acc: 0.0000e+00 - val_out3_mae: 0.0328 - val_out3_acc: 0.0054 - val_out4_mae: 0.0244 - val_out4_acc: 0.0027 - val_out_item_mae: 2.1332e-05 - val_out_item_acc: 1.0000
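With one loss per output and no `loss_weights`, Keras minimizes the plain sum of the five losses, so `loss` in the log equals the sum of `out1_loss` to `out_item_loss` up to rounding. The sketch below reproduces the total from the rounded epoch-1 values and shows how a weighting, as `model.compile(..., loss_weights={...})` would apply, rescales it. The weights are purely illustrative, not tuned.

```python
def total_loss(losses, weights=None):
    """Weighted sum of per-output losses; missing weights default to 1.0 as in Keras."""
    weights = weights or {}
    return sum(weights.get(name, 1.0) * value for name, value in losses.items())


# Rounded per-output training losses from epoch 1 of the log above
epoch1 = {"out1": 0.0246, "out2": 0.0166, "out3": 0.0312, "out4": 0.0228, "out_item": 0.1261}
print(round(total_loss(epoch1), 4))  # 0.2213 (the log shows 0.2214 because of rounding)

# Hypothetical weighting that up-weights the four box coordinates
box_weights = {"out1": 5.0, "out2": 5.0, "out3": 5.0, "out4": 5.0}
print(round(total_loss(epoch1, box_weights), 4))
```

Because the box MSE values are an order of magnitude smaller than the early classification loss, the classification branch dominates the first epochs unless the box losses are weighted up.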

4. Model evaluation

model.save("detect_v1.h5")
new_model = tf.keras.models.load_model("detect_v1.h5")
plt.figure(figsize=(8,24))
for img,_ in test_ds.skip(8).take(1):
    # predict returns one (batch, 1) array per output
    out1,out2,out3,out4,out5 = new_model.predict(img)
    for i in range(6):
        plt.subplot(6,1,i+1)
        plt.imshow(tf.keras.preprocessing.image.array_to_img(img[i]))
        xmin,ymin,xmax,ymax = out1[i][0]*224,out2[i][0]*224,out3[i][0]*224,out4[i][0]*224
        # (xmin, ymin) is the top-left corner in image coordinates
        rect = Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),fill=False,color="red")
        ax = plt.gca()
        ax.add_patch(rect)
        # A sigmoid output above 0.5 rounds to 1 ("dog"), otherwise 0 ("cat")
        plt.title(class_label[int(round(float(out5[i][0])))].title())
plt.show()

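Judging from the plotted predictions, the results look good overall. Keep in mind that if the dataset is shuffled before `skip`/`take` without `reshuffle_each_iteration=False`, the train/test split changes every epoch, so test images leak into training and the validation metrics may be optimistic.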
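Visual inspection is only a rough check. A common quantitative measure of box quality is intersection over union (IoU) between predicted and ground-truth boxes. It is not used in the original post; the sketch below is one straightforward NumPy implementation for boxes in the normalized (xmin, ymin, xmax, ymax) format used here, working on single boxes or on arrays of shape (batch, 4).

```python
import numpy as np


def iou(pred, true):
    """IoU of boxes given as (..., 4) arrays of (xmin, ymin, xmax, ymax)."""
    pred = np.asarray(pred, dtype=float)
    true = np.asarray(true, dtype=float)
    # Width and height of the overlap, clipped at 0 when the boxes do not intersect
    ix = np.clip(np.minimum(pred[..., 2], true[..., 2]) - np.maximum(pred[..., 0], true[..., 0]), 0, None)
    iy = np.clip(np.minimum(pred[..., 3], true[..., 3]) - np.maximum(pred[..., 1], true[..., 1]), 0, None)
    inter = ix * iy
    area_p = (pred[..., 2] - pred[..., 0]) * (pred[..., 3] - pred[..., 1])
    area_t = (true[..., 2] - true[..., 0]) * (true[..., 3] - true[..., 1])
    union = np.maximum(area_p + area_t - inter, 1e-12)  # guard against zero-area boxes
    return inter / union


print(iou([0.0, 0.0, 0.5, 0.5], [0.25, 0.25, 0.75, 0.75]))  # about 0.142857, i.e. 1/7
```

For a predicted batch, `np.concatenate([out1, out2, out3, out4], axis=1)` gives the (batch, 4) array expected here; averaging `iou` over the test set gives a single number to compare models with.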
