1. Image Preprocessing
In this section we build the image dataset with the from_tensor_slices method. Compared with the queue-based input pipeline used in TF 1.x, this approach is simpler and easier to understand.
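As a quick illustration before the real pipeline below (a minimal sketch with made-up file names, not part of the original code), from_tensor_slices simply turns a Python list or array into a dataset whose elements are the individual slices:

import tensorflow as tf

# Hypothetical toy paths, purely for illustration
toy_paths = ["cat.0.jpg", "cat.1.jpg", "dog.0.jpg"]
toy_ds = tf.data.Dataset.from_tensor_slices(toy_paths)

for p in toy_ds:
    print(p.numpy())  # b'cat.0.jpg', b'cat.1.jpg', b'dog.0.jpg'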
# Build a tf.data.Dataset.
# The easiest way to build a tf.data.Dataset of images is the from_tensor_slices method.
# Slice the array of path strings to get a dataset of strings:
train_path_ds = tf.data.Dataset.from_tensor_slices(train_all_image_path)
print(train_path_ds)
test_path_ds = tf.data.Dataset.from_tensor_slices(test_all_image_path)

# Now create a new dataset that loads and formats images on the fly
# by mapping preprocess_image over the dataset of paths.
AUTOTUNE = tf.data.experimental.AUTOTUNE
train_image_ds = train_path_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
test_image_ds = test_path_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)

train_lable_ds = tf.data.Dataset.from_tensor_slices(tf.cast(train_image_label, tf.int64))
test_lable_ds = tf.data.Dataset.from_tensor_slices(tf.cast(test_image_label, tf.int64))
for label in train_lable_ds.take(5):
    print(lable_names[label.numpy()])

#%% Build a dataset of (image, label) pairs
# Because these datasets are in the same order, we can simply zip them together.
image_label_ds = tf.data.Dataset.zip((train_image_ds, train_lable_ds))
test_data = tf.data.Dataset.zip((test_image_ds, test_lable_ds))
print(test_data)
<TensorSliceDataset shapes: (), types: tf.string>
cat
dog
dog
cat
dog
<ZipDataset shapes: ((256, 256, 3), ()), types: (tf.float32, tf.int64)>
2. Training Phase
batch_size = 32
# Set a shuffle buffer size as large as the dataset so that the data is fully shuffled.
train_ds = image_label_ds.shuffle(buffer_size=train_image_count).batch(batch_size)
test_ds = test_data.batch(batch_size)
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
test_ds = test_ds.prefetch(buffer_size=AUTOTUNE)
2.1 Building and Training the Model
model = tf.keras.Sequential()  # Sequential model
model.add(tf.keras.layers.Conv2D(64, (3, 3), input_shape=(256, 256, 3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
#model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.MaxPooling2D())
#model.add(tf.keras.layers.Conv2D(128, (3, 3), activation='relu'))
model.add(tf.keras.layers.Conv2D(128, (3, 3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.MaxPooling2D())
#model.add(tf.keras.layers.Conv2D(256, (3, 3), activation='relu'))
model.add(tf.keras.layers.Conv2D(256, (3, 3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.MaxPooling2D())
model.add(tf.keras.layers.Conv2D(512, (3, 3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.MaxPooling2D())
model.add(tf.keras.layers.Conv2D(512, (3, 3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.MaxPooling2D())
model.add(tf.keras.layers.Conv2D(1024, (3, 3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Dense(1024, activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Dense(256, activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Dense(1))
#%%
model.summary()
Model: "sequential_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_66 (Conv2D)           (None, 254, 254, 64)      1792
_________________________________________________________________
batch_normalization (BatchNo (None, 254, 254, 64)      256
_________________________________________________________________
max_pooling2d_40 (MaxPooling (None, 127, 127, 64)      0
_________________________________________________________________
conv2d_67 (Conv2D)           (None, 125, 125, 128)     73856
_________________________________________________________________
batch_normalization_1 (Batch (None, 125, 125, 128)     512
_________________________________________________________________
max_pooling2d_41 (MaxPooling (None, 62, 62, 128)       0
_________________________________________________________________
conv2d_68 (Conv2D)           (None, 60, 60, 256)       295168
_________________________________________________________________
batch_normalization_2 (Batch (None, 60, 60, 256)       1024
_________________________________________________________________
max_pooling2d_42 (MaxPooling (None, 30, 30, 256)       0
_________________________________________________________________
conv2d_69 (Conv2D)           (None, 28, 28, 512)       1180160
_________________________________________________________________
batch_normalization_3 (Batch (None, 28, 28, 512)       2048
_________________________________________________________________
max_pooling2d_43 (MaxPooling (None, 14, 14, 512)       0
_________________________________________________________________
conv2d_70 (Conv2D)           (None, 12, 12, 512)       2359808
_________________________________________________________________
batch_normalization_4 (Batch (None, 12, 12, 512)       2048
_________________________________________________________________
max_pooling2d_44 (MaxPooling (None, 6, 6, 512)         0
_________________________________________________________________
conv2d_71 (Conv2D)           (None, 4, 4, 1024)        4719616
_________________________________________________________________
batch_normalization_5 (Batch (None, 4, 4, 1024)        4096
_________________________________________________________________
global_average_pooling2d_8 ( (None, 1024)              0
_________________________________________________________________
dense_24 (Dense)             (None, 1024)              1049600
_________________________________________________________________
batch_normalization_6 (Batch (None, 1024)              4096
_________________________________________________________________
dense_25 (Dense)             (None, 256)               262400
_________________________________________________________________
batch_normalization_7 (Batch (None, 256)               1024
_________________________________________________________________
dense_26 (Dense)             (None, 1)                 257
=================================================================
Total params: 9,957,761
Trainable params: 9,950,209
Non-trainable params: 7,552
_________________________________________________________________
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)

# Custom loss: BinaryCrossentropy is a callable object.
# Because the final Dense layer has no activation, we set from_logits=True.
loss_fuc = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# Define a loss helper (note the argument order: y_true first, then the prediction).
def loss(model, x, y):
    y_ = model(x)
    y = tf.expand_dims(tf.cast(y, dtype=tf.float32), axis=1)
    return loss_fuc(y, y_)

# Define running means of the loss and accuracy on the training and test sets.
train_loss = tf.keras.metrics.Mean("train_loss")
train_accuracy = tf.keras.metrics.Accuracy("train_accuracy")
test_loss = tf.keras.metrics.Mean("test_loss")
test_accuracy = tf.keras.metrics.Accuracy("test_accuracy")

# Define train_step
def train_step(model, image, labels):
    with tf.GradientTape() as t:
        pred = model(image, training=True)  # training=True so BatchNormalization uses batch statistics
        loss_step = loss_fuc(tf.expand_dims(tf.cast(labels, dtype=tf.float32), axis=1), pred)
    grads = t.gradient(loss_step, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_loss(loss_step)
    train_accuracy(labels, tf.cast(pred > 0, tf.int64))

# Define test_step
def test_step(model, image, labels):
    pred = model(image, training=False)
    loss_step = loss_fuc(tf.expand_dims(tf.cast(labels, dtype=tf.float32), axis=1), pred)
    test_loss(loss_step)
    test_accuracy(labels, tf.cast(pred > 0, tf.int64))

def train(model, train_ds, test_ds):
    train_loss_sca = []
    test_loss_sca = []
    train_acc_sca = []
    test_acc_sca = []
    for epoch in range(30):
        for (batch, (image, labels)) in enumerate(train_ds):
            #print(".")
            # Run one training step
            train_step(model, image, labels)
        for (batch, (image, labels)) in enumerate(test_ds):
            test_step(model, image, labels)
        train_loss_sca.append(train_loss.result())
        test_loss_sca.append(test_loss.result())
        train_acc_sca.append(train_accuracy.result())
        test_acc_sca.append(test_accuracy.result())
        print("epoch{} train_loss is {};train_accuracy is {};test_loss is {};test_accuracy is {}".format(
            epoch + 1,
            train_loss.result(),
            train_accuracy.result(),
            test_loss.result(),
            test_accuracy.result()))
        train_loss.reset_states()
        train_accuracy.reset_states()
        test_loss.reset_states()
        test_accuracy.reset_states()
    return (train_loss_sca, test_loss_sca, train_acc_sca, test_acc_sca)

# Train
train_loss_sca, test_loss_sca, train_acc_sca, test_acc_sca = train(model, train_ds, test_ds)
epoch1 train_loss is 0.6930190324783325;train_accuracy is 0.5039578080177307;test_loss is 0.6916858553886414;test_accuracy is 0.5
epoch2 train_loss is 0.6807847023010254;train_accuracy is 0.5712401270866394;test_loss is 0.6401413083076477;test_accuracy is 0.6240000128746033
epoch3 train_loss is 0.6488288640975952;train_accuracy is 0.6144459247589111;test_loss is 0.6341478824615479;test_accuracy is 0.6399999856948853
epoch4 train_loss is 0.6222826838493347;train_accuracy is 0.6510553956031799;test_loss is 0.608300507068634;test_accuracy is 0.6790000200271606
epoch5 train_loss is 0.6008259654045105;train_accuracy is 0.6721636056900024;test_loss is 0.6023556590080261;test_accuracy is 0.6660000085830688
epoch6 train_loss is 0.5754649639129639;train_accuracy is 0.6952506303787231;test_loss is 0.5704880952835083;test_accuracy is 0.7039999961853027
epoch7 train_loss is 0.5785166025161743;train_accuracy is 0.6995382308959961;test_loss is 0.5662873983383179;test_accuracy is 0.7039999961853027
epoch8 train_loss is 0.5402986407279968;train_accuracy is 0.7272427678108215;test_loss is 0.5656307935714722;test_accuracy is 0.7049999833106995
epoch9 train_loss is 0.5293075442314148;train_accuracy is 0.7358179688453674;test_loss is 0.5394512414932251;test_accuracy is 0.7279999852180481
epoch10 train_loss is 0.5094398260116577;train_accuracy is 0.7569261193275452;test_loss is 0.5466101169586182;test_accuracy is 0.7369999885559082
epoch11 train_loss is 0.49413660168647766;train_accuracy is 0.7575857639312744;test_loss is 0.4992522597312927;test_accuracy is 0.7630000114440918
epoch12 train_loss is 0.46310922503471375;train_accuracy is 0.7833113670349121;test_loss is 0.48759791254997253;test_accuracy is 0.7620000243186951
epoch13 train_loss is 0.43853995203971863;train_accuracy is 0.7974933981895447;test_loss is 0.4716934263706207;test_accuracy is 0.7950000166893005
epoch14 train_loss is 0.3964250981807709;train_accuracy is 0.8258575201034546;test_loss is 0.48196572065353394;test_accuracy is 0.781000018119812
epoch15 train_loss is 0.3641780912876129;train_accuracy is 0.8390501141548157;test_loss is 0.4671226739883423;test_accuracy is 0.8080000281333923
epoch16 train_loss is 0.3265441060066223;train_accuracy is 0.857189953327179;test_loss is 0.4624425172805786;test_accuracy is 0.7960000038146973
epoch17 train_loss is 0.3088551461696625;train_accuracy is 0.8680738806724548;test_loss is 0.48780468106269836;test_accuracy is 0.7919999957084656
epoch18 train_loss is 0.2565159499645233;train_accuracy is 0.899076521396637;test_loss is 0.4692841172218323;test_accuracy is 0.8169999718666077
epoch19 train_loss is 0.23743189871311188;train_accuracy is 0.9027044773101807;test_loss is 0.4935546815395355;test_accuracy is 0.8209999799728394
epoch20 train_loss is 0.22497089207172394;train_accuracy is 0.9099604487419128;test_loss is 0.514469563961029;test_accuracy is 0.8100000023841858
epoch21 train_loss is 0.16318537294864655;train_accuracy is 0.9370052814483643;test_loss is 0.5123884677886963;test_accuracy is 0.8330000042915344
epoch22 train_loss is 0.13359478116035461;train_accuracy is 0.9495382308959961;test_loss is 0.5613532066345215;test_accuracy is 0.8190000057220459
epoch23 train_loss is 0.10728871077299118;train_accuracy is 0.9617414474487305;test_loss is 0.5727055668830872;test_accuracy is 0.8399999737739563
epoch24 train_loss is 0.08103378862142563;train_accuracy is 0.9696570038795471;test_loss is 0.608410656452179;test_accuracy is 0.8320000171661377
epoch25 train_loss is 0.04331161081790924;train_accuracy is 0.9858179688453674;test_loss is 0.7248072624206543;test_accuracy is 0.8309999704360962
epoch26 train_loss is 0.06562910228967667;train_accuracy is 0.977902352809906;test_loss is 0.7443087100982666;test_accuracy is 0.777999997138977
epoch27 train_loss is 0.05616709962487221;train_accuracy is 0.9818601608276367;test_loss is 0.7444125413894653;test_accuracy is 0.8190000057220459
epoch28 train_loss is 0.02497768960893154;train_accuracy is 0.9910950064659119;test_loss is 0.8051329851150513;test_accuracy is 0.8180000185966492
epoch29 train_loss is 0.029464619234204292;train_accuracy is 0.9887862801551819;test_loss is 0.9197209477424622;test_accuracy is 0.8240000009536743
epoch30 train_loss is 0.012036236003041267;train_accuracy is 0.9970316886901855;test_loss is 0.8360891342163086;test_accuracy is 0.8450000286102295
2.2 Model Evaluation
Comparing the accuracy and loss curves on the training and validation sets, we can see that the network overfits somewhat.
plt.plot([i for i in range(30)], train_loss_sca, label='train_loss')
plt.plot([i for i in range(30)], test_loss_sca, label='test_loss')
plt.legend()
plt.plot([i for i in range(30)], train_acc_sca, label='train_acc')
plt.plot([i for i in range(30)], test_acc_sca, label='test_acc')
plt.legend()
To summarize: this experiment overfits somewhat. With relatively few images the network cannot learn more advanced features, so test accuracy reaches only about 83%. Going forward we can enlarge the dataset, apply data augmentation, and introduce regularization and dropout to suppress overfitting and improve accuracy, as sketched below.
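To illustrate the regularization and dropout idea, the fully connected head of the model above could be written as follows. This is a minimal sketch, not the configuration actually trained in this post; the L2 factor and dropout rates are placeholder assumptions.

# Hypothetical regularized classifier head; layer sizes follow the model above,
# but the L2 factor and dropout rates are illustrative assumptions.
regularizer = tf.keras.regularizers.l2(1e-4)

head = tf.keras.Sequential([
    tf.keras.layers.Dense(1024, activation='relu', kernel_regularizer=regularizer),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dropout(0.5),   # randomly drop half of the units during training
    tf.keras.layers.Dense(256, activation='relu', kernel_regularizer=regularizer),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1),       # logits, consistent with from_logits=True above
])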
3. Data Augmentation
Data augmentation can be implemented simply by modifying the image-loading code, as shown below.
# Determine the label of each image
lable_names = sorted(item.name for item in data_dir.glob("train/*/"))
# Assign an index to each label and build a dictionary
lable_to_index = dict((name, index) for index, name in enumerate(lable_names))
print(lable_to_index)
# Create a list containing the label index of every file
train_image_label = [lable_to_index[pathlib.Path(path).parent.name] for path in train_all_image_path]
test_image_label = [lable_to_index[pathlib.Path(path).parent.name] for path in test_all_image_path]

# Wrap the image preprocessing in a function for later use
def preprocess_image(image, is_train=True):
    image = tf.image.decode_jpeg(image, channels=3)
    if is_train:
        image = tf.image.resize(image, [360, 360])
        # Random crop (the channel dimension must be given as well)
        image = tf.image.random_crop(image, [256, 256, 3])
        image = tf.image.random_flip_left_right(image)  # random horizontal flip
        image = tf.image.random_flip_up_down(image)     # random vertical flip
        #image = tf.image.random_brightness(image, 0.5)
        #image = tf.image.random_contrast(image, 0, 1)
    else:
        image = tf.image.resize(image, [256, 256])
    image /= 255.0  # normalize to the [0, 1] range
    return image

# Load the images
def load_and_preprocess_train_image(path, is_train=True):
    image = tf.io.read_file(path)
    return preprocess_image(image, is_train)

def load_and_preprocess_test_image(path, is_train=False):
    image = tf.io.read_file(path)
    return preprocess_image(image, is_train)

image_path = test_all_image_path[11]
label = test_image_label[11]

plt.imshow(load_and_preprocess_train_image(image_path))
plt.grid(False)
##plt.xlabel(caption_image(image_path))
plt.title(lable_names[label].title())
plt.axis("off")
print()
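To plug these augmented loaders into the input pipeline, the dataset construction from section 1 can be repeated with the new functions. The following is a minimal sketch that reuses the variable names defined earlier in this post (train_all_image_path, train_image_count, etc.) and mirrors the batching and shuffling settings from section 2.

# Rebuild the datasets with the augmented loaders (sketch; variable names
# such as train_all_image_path and train_image_count come from earlier sections).
AUTOTUNE = tf.data.experimental.AUTOTUNE

train_path_ds = tf.data.Dataset.from_tensor_slices(train_all_image_path)
test_path_ds = tf.data.Dataset.from_tensor_slices(test_all_image_path)

train_image_ds = train_path_ds.map(load_and_preprocess_train_image, num_parallel_calls=AUTOTUNE)
test_image_ds = test_path_ds.map(load_and_preprocess_test_image, num_parallel_calls=AUTOTUNE)

train_lable_ds = tf.data.Dataset.from_tensor_slices(tf.cast(train_image_label, tf.int64))
test_lable_ds = tf.data.Dataset.from_tensor_slices(tf.cast(test_image_label, tf.int64))

train_ds = (tf.data.Dataset.zip((train_image_ds, train_lable_ds))
            .shuffle(buffer_size=train_image_count)
            .batch(32)
            .prefetch(buffer_size=AUTOTUNE))
test_ds = (tf.data.Dataset.zip((test_image_ds, test_lable_ds))
           .batch(32)
           .prefetch(buffer_size=AUTOTUNE))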
From the accuracy and loss curves on the training and validation sets, we can see that data augmentation not only improves accuracy but also does a good job of suppressing overfitting.