tebsorflow2.0 eager模式与自定义训练网络(下)

简介: 对比tensorflow1.x版本静态图模式,tensorflow2.x推荐使用的是eager模式,即动态计算模式,它的特点是运算可以立即得到结果。我们可以通过tf.executing_eagerly()来判断是不是eager模式,如果返回的为True,使用的则为eager模式。首先我们简答介绍一下在eager模式下的计算。

3. 使用手写数据集自定义网络

3.1 数据的预处理

这一部分主要是数据加载,维度扩充和归一化处理。

(train_image,train_labels),(test_image,test_labels) = tf.keras.datasets.mnist.load_data() 
#扩充维度,增加通道项
train_image = tf.expand_dims(train_image,-1)
print(train_image.shape)
test_image = tf.expand_dims(test_image,-1)
print(train_image.shape)
#对图像改变数据类型,归一化
train_image = tf.cast(train_image/255,tf.float32)
train_labels = tf.cast(train_labels,tf.int64)
test_image = tf.cast(test_image/255,tf.float32)
test_labels = tf.cast(test_labels,tf.int64)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
(60000, 28, 28, 1)
(60000, 28, 28, 1)

3.2 数据批量化和网络构建

将数据批量化,batch_size = 32

dataset = tf.data.Dataset.from_tensor_slices((train_image,train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_image,test_labels))
dataset = dataset.shuffle(60000).batch(32)
test_dataset = test_dataset.batch(32)
print(dataset)
<BatchDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int64)>

构建网络

model = tf.keras.Sequential([
     tf.keras.layers.Conv2D(16,[3,3],activation="relu",input_shape=(None,None,1)),#任意大小的channel都能输入进来
     tf.keras.layers.Conv2D(32,[3,3],activation="relu"),
     tf.keras.layers.GlobalAveragePooling2D(),
     tf.keras.layers.Dense(10),
     ]
)
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, None, None, 16)    160       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, None, None, 32)    4640      
_________________________________________________________________
global_average_pooling2d (Gl (None, 32)                0         
_________________________________________________________________
dense (Dense)                (None, 10)                330       
=================================================================
Total params: 5,130
Trainable params: 5,130
Non-trainable params: 0
_________________________________________________________________

我们可以利用model.trainable_variables利用和查看过滤器的变量。

optimizer = tf.keras.optimizers.Adam()
#自定义损失,Sparse是可调用的对象
loss_fuc = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)# 这是可调用的方法,因为我们没有加入激活函数,所以from_logits=true
feature,label = next(iter(dataset))# 可以封装成迭代器直接调用
def loss(model,x,y):
  y_ = model(x)
  return loss_fuc(y,y_)
loss(model,feature,label)
<tf.Tensor: shape=(), dtype=float32, numpy=2.3087442>

接下来我们定义损失,准确率并求取梯度。

train_loss = tf.keras.metrics.Mean("train_loss")
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy("train_accuracy") 
#求梯度 
def train_step(model,image,labels):
  with tf.GradientTape() as t:
    pred = model(image)
    loss_step = loss_fuc(labels,pred)
  grads = t.gradient(loss_step,model.trainable_variables)
  optimizer.apply_gradients(zip(grads,model.trainable_variables))
  train_loss(loss_step)
  train_accuracy(labels,pred)

3.3 训练预测

def train():
  for epoch in range(10):
    for (batch,(image,labels)) in enumerate(dataset):
      #进行异步训连
      train_step(model,image,labels)
    print("epoch{} loss is {};accuracy is {}".format(epoch,
                               train_loss.result(),
                               train_accuracy.result()))
    train_loss.reset_states()
    train_accuracy.reset_states()
train()
epoch0 loss is 0.47583284974098206;accuracy is 0.8527083396911621
epoch1 loss is 0.45875340700149536;accuracy is 0.862500011920929
epoch2 loss is 0.44017791748046875;accuracy is 0.8687833547592163
epoch3 loss is 0.4235962927341461;accuracy is 0.8733333349227905
epoch4 loss is 0.4048921465873718;accuracy is 0.8791000247001648
epoch5 loss is 0.3935568332672119;accuracy is 0.8831833600997925
epoch6 loss is 0.38044092059135437;accuracy is 0.8866000175476074
epoch7 loss is 0.370032399892807;accuracy is 0.8890500068664551
epoch8 loss is 0.3582034409046173;accuracy is 0.8931166529655457
epoch9 loss is 0.34430235624313354;accuracy is 0.8981166481971741

我们也可以在训练中打印出test的数变化情况

train_loss = tf.keras.metrics.Mean("train_loss")
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy("train_accuracy") 
test_loss = tf.keras.metrics.Mean("test_loss")
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy("test_accuracy") 
def test_step(model,image,labels):
  pred = model(image)
  loss_step = loss_fuc(labels,pred)
  test_loss(loss_step)
  test_accuracy(labels,pred)
def train():
  for epoch in range(10):
    for (batch,(image,labels)) in enumerate(dataset):
      #进行异步训连
      train_step(model,image,labels)
    for (batch,(image,labels)) in enumerate(test_dataset):
      test_step(model,image,labels)
    print("epoch{} train_loss is {};train_accuracy is {};test_loss is {};test_accuracy is {}".format(epoch,
                               train_loss.result(),
                               train_accuracy.result(),
                               test_loss.result(),
                               test_accuracy.result()
                               ))
    train_loss.reset_states()
    train_accuracy.reset_states()
train()
epoch0 train_loss is 0.3299615681171417;train_accuracy is 0.9027583599090576;test_loss is 0.3066971004009247;test_accuracy is 0.907800018787384
epoch1 train_loss is 0.3169953227043152;train_accuracy is 0.9060500264167786;test_loss is 0.2917846441268921;test_accuracy is 0.9154999852180481
epoch2 train_loss is 0.3101600706577301;train_accuracy is 0.9079499840736389;test_loss is 0.2940264344215393;test_accuracy is 0.9137333035469055
epoch3 train_loss is 0.30038899183273315;train_accuracy is 0.9114000201225281;test_loss is 0.28789690136909485;test_accuracy is 0.9157249927520752
epoch4 train_loss is 0.29241883754730225;train_accuracy is 0.9130833148956299;test_loss is 0.2802391052246094;test_accuracy is 0.9186400175094604
epoch5 train_loss is 0.28577837347984314;train_accuracy is 0.9151166677474976;test_loss is 0.2763482332229614;test_accuracy is 0.9198833107948303
epoch6 train_loss is 0.27776893973350525;train_accuracy is 0.918666660785675;test_loss is 0.2713969349861145;test_accuracy is 0.9215571284294128
epoch7 train_loss is 0.2718273401260376;train_accuracy is 0.9201499819755554;test_loss is 0.2703363001346588;test_accuracy is 0.9218875169754028
epoch8 train_loss is 0.26651278138160706;train_accuracy is 0.9215333461761475;test_loss is 0.27081072330474854;test_accuracy is 0.9211888909339905
epoch9 train_loss is 0.2612370252609253;train_accuracy is 0.9223999977111816;test_loss is 0.26694610714912415;test_accuracy is 0.922569990158081
相关文章
|
9天前
|
机器学习/深度学习 PyTorch 算法框架/工具
目标检测实战(一):CIFAR10结合神经网络加载、训练、测试完整步骤
这篇文章介绍了如何使用PyTorch框架,结合CIFAR-10数据集,通过定义神经网络、损失函数和优化器,进行模型的训练和测试。
29 2
目标检测实战(一):CIFAR10结合神经网络加载、训练、测试完整步骤
|
9天前
|
机器学习/深度学习 数据可视化 计算机视觉
目标检测笔记(五):详细介绍并实现可视化深度学习中每层特征层的网络训练情况
这篇文章详细介绍了如何通过可视化深度学习中每层特征层来理解网络的内部运作,并使用ResNet系列网络作为例子,展示了如何在训练过程中加入代码来绘制和保存特征图。
28 1
目标检测笔记(五):详细介绍并实现可视化深度学习中每层特征层的网络训练情况
|
7天前
|
安全 定位技术 数据安全/隐私保护
|
4天前
|
存储 前端开发 JavaScript
链动模式融合排队免单:扩散用户裂变网络、提高复购
将链动2+1与排队免单结合的模式及链动3+1模式转化为可运行代码涉及多个技术领域,包括后端开发、前端开发、数据库设计等。本文提供了一个简化的技术框架,涵盖用户管理、订单处理、奖励计算、团队结构等核心功能,并提供了示例代码。同时,强调了安全性、测试与部署的重要性,以确保系统的稳定性和合规性。
|
9天前
|
Docker 容器
docker中创建自定义网络
【10月更文挑战第7天】
17 6
|
11天前
|
Docker 容器
docker中自定义网络
【10月更文挑战第5天】
12 3
|
9天前
|
机器学习/深度学习 数据采集 算法
目标分类笔记(一): 利用包含多个网络多种训练策略的框架来完成多目标分类任务(从数据准备到训练测试部署的完整流程)
这篇博客文章介绍了如何使用包含多个网络和多种训练策略的框架来完成多目标分类任务,涵盖了从数据准备到训练、测试和部署的完整流程,并提供了相关代码和配置文件。
21 0
目标分类笔记(一): 利用包含多个网络多种训练策略的框架来完成多目标分类任务(从数据准备到训练测试部署的完整流程)
|
11天前
|
机器学习/深度学习 算法 TensorFlow
深度学习笔记(五):学习率过大过小对于网络训练有何影响以及如何解决
学习率是深度学习中的关键超参数,它影响模型的训练进度和收敛性,过大或过小的学习率都会对网络训练产生负面影响,需要通过适当的设置和调整策略来优化。
87 0
深度学习笔记(五):学习率过大过小对于网络训练有何影响以及如何解决
|
11天前
|
安全 网络安全 数据安全/隐私保护
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
【10月更文挑战第6天】在数字化时代,网络安全和信息安全已成为我们生活中不可或缺的一部分。本文将探讨网络安全漏洞、加密技术和安全意识等方面的内容,以帮助读者更好地了解这些主题,并采取适当的措施保护自己的信息安全。我们将通过代码示例来演示一些常见的安全漏洞,并提供解决方案。最后,我们将强调培养良好的安全意识对于维护个人和组织的信息安全的重要性。
|
8天前
|
安全 网络安全 数据安全/隐私保护
网络安全与信息安全:守护数字世界的坚盾
在数字化浪潮中,网络安全已成为维系现代社会正常运转的关键。本文旨在探讨网络安全漏洞的成因、加密技术的应用及安全意识的提升,以期为广大用户和技术人员提供实用的知识分享。通过对这些方面的深入剖析,我们期望能够共同构建一个更加安全可靠的数字环境。

热门文章

最新文章