3. 使用手写数据集自定义网络
3.1 数据的预处理
这一部分主要是数据加载,维度扩充和归一化处理。
# Load the MNIST handwritten-digit dataset as (images, labels) pairs for train and test.
(train_image, train_labels), (test_image, test_labels) = tf.keras.datasets.mnist.load_data()

# Append a trailing channel axis to the image arrays.
train_image = tf.expand_dims(train_image, -1)
print(train_image.shape)
test_image = tf.expand_dims(test_image, -1)
# Report the test set's shape here; the original printed train_image.shape a second time.
print(test_image.shape)

# Scale pixel values by 1/255 and cast images to float32; cast labels to int64.
train_image = tf.cast(train_image / 255, tf.float32)
train_labels = tf.cast(train_labels, tf.int64)
test_image = tf.cast(test_image / 255, tf.float32)
test_labels = tf.cast(test_labels, tf.int64)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11493376/11490434 [==============================] - 0s 0us/step (60000, 28, 28, 1) (60000, 28, 28, 1)
3.2 数据批量化和网络构建
将数据批量化,batch_size = 32
# Training pipeline: one element per (image, label) sample, shuffled with a
# 60000-element buffer and then grouped into batches of 32.
dataset = (
    tf.data.Dataset.from_tensor_slices((train_image, train_labels))
    .shuffle(60000)
    .batch(32)
)

# Test pipeline: same batching, but no shuffling.
test_dataset = tf.data.Dataset.from_tensor_slices((test_image, test_labels)).batch(32)

print(dataset)
<BatchDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int64)>
构建网络
# Small fully-convolutional classifier assembled layer by layer.
model = tf.keras.Sequential()

# Height and width are left as None, so single-channel images of any spatial
# size are accepted.
model.add(
    tf.keras.layers.Conv2D(
        16,
        [3, 3],
        activation="relu",
        input_shape=(None, None, 1),
    )
)
model.add(tf.keras.layers.Conv2D(32, [3, 3], activation="relu"))

# Global average pooling collapses the spatial axes, which is what makes the
# variable input size compatible with the Dense layer that follows.
model.add(tf.keras.layers.GlobalAveragePooling2D())

# No activation on the output layer: the model emits 10 raw logits.
model.add(tf.keras.layers.Dense(10))

model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d (Conv2D) (None, None, None, 16) 160 _________________________________________________________________ conv2d_1 (Conv2D) (None, None, None, 32) 4640 _________________________________________________________________ global_average_pooling2d (Gl (None, 32) 0 _________________________________________________________________ dense (Dense) (None, 10) 330 ================================================================= Total params: 5,130 Trainable params: 5,130 Non-trainable params: 0 _________________________________________________________________
我们可以通过 model.trainable_variables
查看并使用各层过滤器(卷积核)的变量。
optimizer = tf.keras.optimizers.Adam()

# Callable loss object. The model's last layer has no activation, so its
# outputs are logits and from_logits=True is needed.
loss_fuc = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Pull a single batch out of the dataset by wrapping it in an iterator.
feature, label = next(iter(dataset))


def loss(model, x, y):
    """Return the loss of ``model``'s predictions for ``x`` against labels ``y``."""
    return loss_fuc(y, model(x))


# Evaluate the loss on one batch with the current model weights.
loss(model, feature, label)
<tf.Tensor: shape=(), dtype=float32, numpy=2.3087442>
接下来我们定义损失,准确率并求取梯度。
# Running metrics, updated once per training batch.
train_loss = tf.keras.metrics.Mean("train_loss")
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy("train_accuracy")


def train_step(model, image, labels):
    """Apply one optimizer update for a batch and record loss and accuracy.

    The forward pass and loss are computed under a gradient tape, the
    gradients with respect to the model's trainable variables are applied by
    the module-level ``optimizer``, and the module-level ``train_loss`` and
    ``train_accuracy`` metrics are updated with this batch.
    """
    with tf.GradientTape() as tape:
        logits = model(image)
        batch_loss = loss_fuc(labels, logits)

    weights = model.trainable_variables
    gradients = tape.gradient(batch_loss, weights)
    optimizer.apply_gradients(zip(gradients, weights))

    train_loss(batch_loss)
    train_accuracy(labels, logits)
3.3 训练预测
def train():
    """Train for 10 epochs, printing the epoch's mean loss and accuracy."""
    for epoch in range(10):
        # One optimizer update per batch.
        for image, labels in dataset:
            train_step(model, image, labels)

        epoch_loss = train_loss.result()
        epoch_accuracy = train_accuracy.result()
        print("epoch{} loss is {};accuracy is {}".format(epoch,
                                                         epoch_loss,
                                                         epoch_accuracy))

        # Clear the accumulators so the next epoch is measured on its own.
        train_loss.reset_states()
        train_accuracy.reset_states()


train()
epoch0 loss is 0.47583284974098206;accuracy is 0.8527083396911621 epoch1 loss is 0.45875340700149536;accuracy is 0.862500011920929 epoch2 loss is 0.44017791748046875;accuracy is 0.8687833547592163 epoch3 loss is 0.4235962927341461;accuracy is 0.8733333349227905 epoch4 loss is 0.4048921465873718;accuracy is 0.8791000247001648 epoch5 loss is 0.3935568332672119;accuracy is 0.8831833600997925 epoch6 loss is 0.38044092059135437;accuracy is 0.8866000175476074 epoch7 loss is 0.370032399892807;accuracy is 0.8890500068664551 epoch8 loss is 0.3582034409046173;accuracy is 0.8931166529655457 epoch9 loss is 0.34430235624313354;accuracy is 0.8981166481971741
我们也可以在训练过程中打印出 test 数据的变化情况
# Running metrics for the training and test passes.
train_loss = tf.keras.metrics.Mean("train_loss")
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy("train_accuracy")
test_loss = tf.keras.metrics.Mean("test_loss")
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy("test_accuracy")


def test_step(model, image, labels):
    """Evaluate one batch without updating weights and record test metrics."""
    pred = model(image)
    loss_step = loss_fuc(labels, pred)
    test_loss(loss_step)
    test_accuracy(labels, pred)


def train():
    """Train for 10 epochs, evaluating on the test set after every epoch.

    After each epoch the mean training and test loss/accuracy are printed and
    all four metrics are reset, so every printed value describes that epoch
    only.
    """
    for epoch in range(10):
        # One optimizer update per training batch.
        for image, labels in dataset:
            train_step(model, image, labels)

        # Forward pass only over the test batches.
        for image, labels in test_dataset:
            test_step(model, image, labels)

        print("epoch{} train_loss is {};train_accuracy is {};test_loss is {};test_accuracy is {}".format(
            epoch,
            train_loss.result(),
            train_accuracy.result(),
            test_loss.result(),
            test_accuracy.result(),
        ))

        train_loss.reset_states()
        train_accuracy.reset_states()
        # The test metrics must be cleared as well; otherwise they keep
        # accumulating and each printed value is an average over all epochs
        # so far instead of the current epoch's result.
        test_loss.reset_states()
        test_accuracy.reset_states()


train()
epoch0 train_loss is 0.3299615681171417;train_accuracy is 0.9027583599090576;test_loss is 0.3066971004009247;test_accuracy is 0.907800018787384 epoch1 train_loss is 0.3169953227043152;train_accuracy is 0.9060500264167786;test_loss is 0.2917846441268921;test_accuracy is 0.9154999852180481 epoch2 train_loss is 0.3101600706577301;train_accuracy is 0.9079499840736389;test_loss is 0.2940264344215393;test_accuracy is 0.9137333035469055 epoch3 train_loss is 0.30038899183273315;train_accuracy is 0.9114000201225281;test_loss is 0.28789690136909485;test_accuracy is 0.9157249927520752 epoch4 train_loss is 0.29241883754730225;train_accuracy is 0.9130833148956299;test_loss is 0.2802391052246094;test_accuracy is 0.9186400175094604 epoch5 train_loss is 0.28577837347984314;train_accuracy is 0.9151166677474976;test_loss is 0.2763482332229614;test_accuracy is 0.9198833107948303 epoch6 train_loss is 0.27776893973350525;train_accuracy is 0.918666660785675;test_loss is 0.2713969349861145;test_accuracy is 0.9215571284294128 epoch7 train_loss is 0.2718273401260376;train_accuracy is 0.9201499819755554;test_loss is 0.2703363001346588;test_accuracy is 0.9218875169754028 epoch8 train_loss is 0.26651278138160706;train_accuracy is 0.9215333461761475;test_loss is 0.27081072330474854;test_accuracy is 0.9211888909339905 epoch9 train_loss is 0.2612370252609253;train_accuracy is 0.9223999977111816;test_loss is 0.26694610714912415;test_accuracy is 0.922569990158081