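Custom Datasets
Deep learning projects usually rely on a dataset you build yourself rather than on a public one. So how do you build your own dataset with TensorFlow 2.0 and feed it to a neural network? This article walks through the process.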
Pokemon Datasets
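The dataset used in this article is the Pokemon dataset, a collection of characters from the Pikachu (Pokémon) films, as shown in the figure below:
Dataset download
Link: https://pan.baidu.com/s/1V_ZJ7ufjUUFZwD2NHSNMFw
Extraction code: dsxl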
Dataset split
As the figure above shows, 60% of the dataset is used for training, 20% for validation, and the remaining 20% for testing.
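The 60/20/20 split described above can be sketched in a few lines. This is a minimal illustration, not the code from pokemon.py: the function name split_dataset, the fixed seed, and the placeholder file names are all assumptions made for the example.

```python
import random


def split_dataset(samples, train_frac=0.6, val_frac=0.2, seed=42):
    """Shuffle samples and cut them into train/val/test lists.

    The test set receives whatever remains after train and val.
    """
    items = list(samples)
    random.Random(seed).shuffle(items)  # seeded so the split is reproducible
    n_train = round(train_frac * len(items))
    n_val = round(val_frac * len(items))
    train = items[:n_train]
    val = items[n_train:n_train + n_val]
    test = items[n_train + n_val:]
    return train, val, test


# Hypothetical file names standing in for real image paths.
samples = [f"img_{i:03d}.jpg" for i in range(100)]
train, val, test = split_dataset(samples)
print(len(train), len(val), len(test))
```

Shuffling before cutting keeps each subset from being dominated by whichever class happens to come first in the file listing.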
Four steps
- Load data: load the dataset
- Build model: build the network
- Train-Val-Test: train, validate, and test
- Transfer Learning: transfer a pretrained model
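Loading the data
First, preprocess the data: convert the pixel values from NumPy arrays to tensors and normalize them to [0, 1], then one-hot encode the labels.

    def preprocess(x, y):
        # x: path to the image file, y: integer label of the image
        x = tf.io.read_file(x)
        x = tf.image.decode_jpeg(x, channels=3)  # decode as 3-channel RGB
        x = tf.image.resize(x, [244, 244])
        return x, y
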
Standard dataset pipeline
The load_pokemon function used in the code was written specifically for this dataset; see pokemon.py for details.

    # batchsz (the batch size) is set elsewhere in the full script
    # Create the training Dataset object
    images, labels, table = load_pokemon('pokemon', mode='train')
    db_train = tf.data.Dataset.from_tensor_slices((images, labels))
    db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz)
    # Create the validation Dataset object
    images2, labels2, table = load_pokemon('pokemon', mode='val')
    db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))
    db_val = db_val.map(preprocess).batch(batchsz)
    # Create the test Dataset object
    images3, labels3, table = load_pokemon('pokemon', mode='test')
    db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))
    db_test = db_test.map(preprocess).batch(batchsz)

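One plausible way for a loader such as load_pokemon to produce the integer labels and the table it returns is to map each class sub-directory to an index. The sketch below is an assumption about that approach, not the contents of pokemon.py: build_label_table, list_images, and the one-folder-per-class layout are hypothetical.

```python
import os


def build_label_table(root):
    """Map each class sub-directory of root to an integer code, in sorted order."""
    names = sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )
    return {name: idx for idx, name in enumerate(names)}


def list_images(root, table, exts=(".jpg", ".jpeg", ".png")):
    """Return parallel lists of image paths and integer labels."""
    images, labels = [], []
    for name, idx in table.items():
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            # Skip anything that does not look like an image file.
            if fname.lower().endswith(exts):
                images.append(os.path.join(class_dir, fname))
                labels.append(idx)
    return images, labels
```

Parallel lists of paths and labels in this form are what tf.data.Dataset.from_tensor_slices expects in the pipeline above.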
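Image augmentation and normalization
When the dataset is small, data augmentation is generally needed to enlarge it and keep the network from overfitting. Typical operations include rotation and cropping. The pixel values are then scaled to [0, 1] and standardized, and the labels are one-hot encoded. The code is as follows:

    # x = tf.image.random_flip_left_right(x)
    x = tf.image.random_flip_up_down(x)
    x = tf.image.random_crop(x, [224, 224, 3])
    # x: [0, 255] => [0, 1], then standardized by normalize(),
    # a helper that is not shown in this excerpt
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = normalize(x)
    y = tf.convert_to_tensor(y)
    y = tf.one_hot(y, depth=5)
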
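The normalize helper is not shown in the excerpt. A common choice is per-channel standardization with the ImageNet mean and standard deviation, and the sketch below works through that arithmetic on a single pixel in plain Python, along with the one-hot encoding of a label. The statistics and the names normalize_pixel and one_hot are assumptions for illustration, not the article's actual helper.

```python
# ImageNet channel statistics: an assumption, since normalize() is not shown.
IMG_MEAN = (0.485, 0.456, 0.406)
IMG_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=IMG_MEAN, std=IMG_STD):
    """Scale one 0-255 RGB pixel to [0, 1], then standardize each channel."""
    return [(c / 255.0 - m) / s for c, m, s in zip(rgb, mean, std)]


def one_hot(label, depth=5):
    """Return a list of length depth with 1.0 at index label, 0.0 elsewhere."""
    return [1.0 if i == label else 0.0 for i in range(depth)]


print(normalize_pixel((255, 128, 0)))
print(one_hot(2))
```

After standardization the values are no longer confined to [0, 1]; they are centered near zero, which is the input distribution ImageNet-pretrained models expect.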
Building the network
The network is trained from scratch, with a ResNet backbone based on Mu Li's implementation. See resnet.py for the full code; an excerpt follows:

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers

    # ResnetBlock is defined in resnet.py and is not part of this excerpt
    class ResNet(keras.Model):

        def __init__(self, num_classes, initial_filters=16, **kwargs):
            super(ResNet, self).__init__(**kwargs)
            self.stem = layers.Conv2D(initial_filters, 3, strides=3, padding='valid')
            self.blocks = keras.models.Sequential([
                ResnetBlock(initial_filters * 2, strides=3),
                ResnetBlock(initial_filters * 2, strides=1),
                # layers.Dropout(rate=0.5),
                ResnetBlock(initial_filters * 4, strides=3),
                ResnetBlock(initial_filters * 4, strides=1),
                ResnetBlock(initial_filters * 8, strides=2),
                ResnetBlock(initial_filters * 8, strides=1),
                ResnetBlock(initial_filters * 16, strides=2),
                ResnetBlock(initial_filters * 16, strides=1),
            ])
            self.final_bn = layers.BatchNormalization()
            # Despite the attribute name, this layer performs global max pooling
            self.avg_pool = layers.GlobalMaxPool2D()
            # Outputs raw logits; softmax is applied inside the loss
            self.fc = layers.Dense(num_classes)

        def call(self, inputs, training=None):
            # print('x:', inputs.shape)
            out = self.stem(inputs, training=training)
            out = tf.nn.relu(out)
            # print('stem:', out.shape)
            out = self.blocks(out, training=training)
            # print('res:', out.shape)
            out = self.final_bn(out, training=training)
            # out = tf.nn.relu(out)
            out = self.avg_pool(out)
            # print('avg_pool:', out.shape)
            out = self.fc(out)
            # print('out:', out.shape)
            return out

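To see how quickly the strides in this network shrink a 224x224 input, the spatial size can be traced layer by layer. The sketch below uses the stem settings from the excerpt (kernel 3, stride 3, 'valid' padding) and assumes that every ResnetBlock uses 'same' padding; that detail lives in resnet.py and is not shown here, so treat the block sizes as an estimate. The helper names are hypothetical.

```python
import math


def out_size_valid(n, kernel, stride):
    """Output length of a convolution with 'valid' padding."""
    return (n - kernel) // stride + 1


def out_size_same(n, stride):
    """Output length of a convolution with 'same' padding."""
    return math.ceil(n / stride)


def trace_spatial_size(input_size=224):
    """Return the feature-map side length after the stem and after each block."""
    sizes = [out_size_valid(input_size, kernel=3, stride=3)]  # stem
    # Block strides as listed in the excerpt; 'same' padding is an assumption.
    for stride in (3, 1, 3, 1, 2, 1, 2, 1):
        sizes.append(out_size_same(sizes[-1], stride))
    return sizes


print(trace_spatial_size())
```

Under these assumptions the feature map is only a few pixels wide by the time it reaches the global pooling layer, which is why pooling followed by a single Dense layer is enough as a classification head.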
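Training and testing
An excerpt of the code:

    from tensorflow.keras import optimizers, losses
    from tensorflow.keras.callbacks import EarlyStopping

    # A small baseline CNN; it is replaced by the ResNet created right below
    resnet = keras.Sequential([
        layers.Conv2D(16, 5, 3),
        layers.MaxPool2D(3, 3),
        layers.ReLU(),
        layers.Conv2D(64, 5, 3),
        layers.MaxPool2D(2, 2),
        layers.ReLU(),
        layers.Flatten(),
        layers.Dense(64),
        layers.ReLU(),
        layers.Dense(5)
    ])

    resnet = ResNet(5)
    resnet.build(input_shape=(4, 224, 224, 3))
    resnet.summary()

    # Stop when val_accuracy has not improved by at least 0.001 for 5 epochs
    early_stopping = EarlyStopping(
        monitor='val_accuracy',
        min_delta=0.001,
        patience=5
    )

    # learning_rate replaces the deprecated lr argument
    resnet.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
                   loss=losses.CategoricalCrossentropy(from_logits=True),
                   metrics=['accuracy'])
    resnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=100,
               callbacks=[early_stopping])
    resnet.evaluate(db_test)
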
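The interaction between min_delta and patience in the early-stopping callback is easier to follow on a concrete history. The function below is a simplified re-statement of that logic for a metric that should increase, such as val_accuracy. It is an illustration only, not Keras's actual implementation, and early_stop_epoch and the sample history are hypothetical.

```python
def early_stop_epoch(val_scores, min_delta=0.001, patience=5):
    """Return the 0-based epoch at which training would stop, or None.

    An epoch counts as an improvement only if the score beats the best
    score so far by more than min_delta.
    """
    best = float("-inf")
    wait = 0  # consecutive epochs without improvement
    for epoch, score in enumerate(val_scores):
        if score - min_delta > best:
            best = score
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Hypothetical validation accuracies: the curve plateaus around 0.60.
history = [0.50, 0.60, 0.6005, 0.60, 0.59, 0.60, 0.60, 0.70]
print(early_stop_epoch(history))
```

In this history the gain from 0.60 to 0.6005 is smaller than min_delta, so it does not reset the counter, and training stops before the later jump to 0.70 is ever reached. A larger patience trades training time for robustness against such plateaus.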
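Transfer learning
A network can be trained from scratch, or it can start from the parameters of a model that has already been trained. This walkthrough loads the pretrained VGG19 model that ships with TensorFlow, which speeds up training.
The principle of transfer learning is shown in the figure below:
An excerpt of the code:

    # Load VGG19 without its classifier head; global max pooling yields one feature vector per image
    net = keras.applications.VGG19(weights='imagenet', include_top=False, pooling='max')
    net.trainable = False  # freeze the pretrained weights
    newnet = keras.Sequential([
        net,
        layers.Dense(5)  # new classification head for the 5 classes
    ])
    newnet.build(input_shape=(4, 224, 224, 3))
    newnet.summary()

    early_stopping = EarlyStopping(
        monitor='val_accuracy',
        min_delta=0.001,
        patience=5
    )

    # learning_rate replaces the deprecated lr argument
    newnet.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
                   loss=losses.CategoricalCrossentropy(from_logits=True),
                   metrics=['accuracy'])
    newnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=100,
               callbacks=[early_stopping])
    newnet.evaluate(db_test)
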
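A quick parameter count shows why freezing the pretrained base speeds things up. The sketch below tallies the weights of the standard VGG19 convolutional stack (sixteen 3x3 convolutions, no fully connected top) and compares them with the single Dense(5) head placed on its 512-dimensional pooled output. The helper names are hypothetical and the calculation is plain arithmetic; it does not load the model.

```python
# Output channels of the sixteen 3x3 conv layers in the standard VGG19 layout.
VGG19_CONV_CHANNELS = [64, 64, 128, 128,
                       256, 256, 256, 256,
                       512, 512, 512, 512,
                       512, 512, 512, 512]


def conv3x3_params(in_ch, out_ch):
    """Weights plus biases of one 3x3 convolution."""
    return in_ch * 3 * 3 * out_ch + out_ch


def vgg19_base_params(input_channels=3):
    """Total parameters in the convolutional base (pooling layers have none)."""
    total, in_ch = 0, input_channels
    for out_ch in VGG19_CONV_CHANNELS:
        total += conv3x3_params(in_ch, out_ch)
        in_ch = out_ch
    return total


def dense_params(in_features, out_features):
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features


frozen = vgg19_base_params()
trainable = dense_params(512, 5)
print(f"frozen: {frozen:,}  trainable: {trainable:,}")
```

With the base frozen, the optimizer updates only the few thousand weights of the new head while roughly twenty million pretrained weights stay fixed, so the model converges in far fewer epochs than the ResNet trained from scratch.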
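Code download
To get the Baidu Cloud link to the complete code for this article, reply "pokemon" in the official account's chat window. Copying the keyword and pasting it into the reply is recommended.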
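References
This article is based mainly on the NetEase Cloud Classroom course "深度学习与TensorFlow 2入门实战" (Deep Learning and TensorFlow 2: A Hands-On Introduction) by the instructor known as Long Long.
Course link: https://study.163.com/course/courseMain.htm?courseId=1209092816&share=1&shareId=1026182418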