tebsorflow2.0 eager模式与自定义训练网络(下)

简介: 对比tensorflow1.x版本静态图模式,tensorflow2.x推荐使用的是eager模式,即动态计算模式,它的特点是运算可以立即得到结果。我们可以通过tf.executing_eagerly()来判断是不是eager模式,如果返回的为True,使用的则为eager模式。首先我们简答介绍一下在eager模式下的计算。

3. 使用手写数据集自定义网络

3.1 数据的预处理

这一部分主要是数据加载,维度扩充和归一化处理。

(train_image,train_labels),(test_image,test_labels) = tf.keras.datasets.mnist.load_data() 
#扩充维度,增加通道项
train_image = tf.expand_dims(train_image,-1)
print(train_image.shape)
test_image = tf.expand_dims(test_image,-1)
print(train_image.shape)
#对图像改变数据类型,归一化
train_image = tf.cast(train_image/255,tf.float32)
train_labels = tf.cast(train_labels,tf.int64)
test_image = tf.cast(test_image/255,tf.float32)
test_labels = tf.cast(test_labels,tf.int64)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
(60000, 28, 28, 1)
(60000, 28, 28, 1)

3.2 数据批量化和网络构建

将数据批量化,batch_size = 32

dataset = tf.data.Dataset.from_tensor_slices((train_image,train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_image,test_labels))
dataset = dataset.shuffle(60000).batch(32)
test_dataset = test_dataset.batch(32)
print(dataset)
<BatchDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int64)>

构建网络

model = tf.keras.Sequential([
     tf.keras.layers.Conv2D(16,[3,3],activation="relu",input_shape=(None,None,1)),#任意大小的channel都能输入进来
     tf.keras.layers.Conv2D(32,[3,3],activation="relu"),
     tf.keras.layers.GlobalAveragePooling2D(),
     tf.keras.layers.Dense(10),
     ]
)
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, None, None, 16)    160       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, None, None, 32)    4640      
_________________________________________________________________
global_average_pooling2d (Gl (None, 32)                0         
_________________________________________________________________
dense (Dense)                (None, 10)                330       
=================================================================
Total params: 5,130
Trainable params: 5,130
Non-trainable params: 0
_________________________________________________________________

我们可以利用model.trainable_variables利用和查看过滤器的变量。

optimizer = tf.keras.optimizers.Adam()
#自定义损失,Sparse是可调用的对象
loss_fuc = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)# 这是可调用的方法,因为我们没有加入激活函数,所以from_logits=true
feature,label = next(iter(dataset))# 可以封装成迭代器直接调用
def loss(model,x,y):
  y_ = model(x)
  return loss_fuc(y,y_)
loss(model,feature,label)
<tf.Tensor: shape=(), dtype=float32, numpy=2.3087442>

接下来我们定义损失,准确率并求取梯度。

train_loss = tf.keras.metrics.Mean("train_loss")
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy("train_accuracy") 
#求梯度 
def train_step(model,image,labels):
  with tf.GradientTape() as t:
    pred = model(image)
    loss_step = loss_fuc(labels,pred)
  grads = t.gradient(loss_step,model.trainable_variables)
  optimizer.apply_gradients(zip(grads,model.trainable_variables))
  train_loss(loss_step)
  train_accuracy(labels,pred)

3.3 训练预测

def train():
  for epoch in range(10):
    for (batch,(image,labels)) in enumerate(dataset):
      #进行异步训连
      train_step(model,image,labels)
    print("epoch{} loss is {};accuracy is {}".format(epoch,
                               train_loss.result(),
                               train_accuracy.result()))
    train_loss.reset_states()
    train_accuracy.reset_states()
train()
epoch0 loss is 0.47583284974098206;accuracy is 0.8527083396911621
epoch1 loss is 0.45875340700149536;accuracy is 0.862500011920929
epoch2 loss is 0.44017791748046875;accuracy is 0.8687833547592163
epoch3 loss is 0.4235962927341461;accuracy is 0.8733333349227905
epoch4 loss is 0.4048921465873718;accuracy is 0.8791000247001648
epoch5 loss is 0.3935568332672119;accuracy is 0.8831833600997925
epoch6 loss is 0.38044092059135437;accuracy is 0.8866000175476074
epoch7 loss is 0.370032399892807;accuracy is 0.8890500068664551
epoch8 loss is 0.3582034409046173;accuracy is 0.8931166529655457
epoch9 loss is 0.34430235624313354;accuracy is 0.8981166481971741

我们也可以在训练中打印出test的数变化情况

train_loss = tf.keras.metrics.Mean("train_loss")
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy("train_accuracy") 
test_loss = tf.keras.metrics.Mean("test_loss")
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy("test_accuracy") 
def test_step(model,image,labels):
  pred = model(image)
  loss_step = loss_fuc(labels,pred)
  test_loss(loss_step)
  test_accuracy(labels,pred)
def train():
  for epoch in range(10):
    for (batch,(image,labels)) in enumerate(dataset):
      #进行异步训连
      train_step(model,image,labels)
    for (batch,(image,labels)) in enumerate(test_dataset):
      test_step(model,image,labels)
    print("epoch{} train_loss is {};train_accuracy is {};test_loss is {};test_accuracy is {}".format(epoch,
                               train_loss.result(),
                               train_accuracy.result(),
                               test_loss.result(),
                               test_accuracy.result()
                               ))
    train_loss.reset_states()
    train_accuracy.reset_states()
train()
epoch0 train_loss is 0.3299615681171417;train_accuracy is 0.9027583599090576;test_loss is 0.3066971004009247;test_accuracy is 0.907800018787384
epoch1 train_loss is 0.3169953227043152;train_accuracy is 0.9060500264167786;test_loss is 0.2917846441268921;test_accuracy is 0.9154999852180481
epoch2 train_loss is 0.3101600706577301;train_accuracy is 0.9079499840736389;test_loss is 0.2940264344215393;test_accuracy is 0.9137333035469055
epoch3 train_loss is 0.30038899183273315;train_accuracy is 0.9114000201225281;test_loss is 0.28789690136909485;test_accuracy is 0.9157249927520752
epoch4 train_loss is 0.29241883754730225;train_accuracy is 0.9130833148956299;test_loss is 0.2802391052246094;test_accuracy is 0.9186400175094604
epoch5 train_loss is 0.28577837347984314;train_accuracy is 0.9151166677474976;test_loss is 0.2763482332229614;test_accuracy is 0.9198833107948303
epoch6 train_loss is 0.27776893973350525;train_accuracy is 0.918666660785675;test_loss is 0.2713969349861145;test_accuracy is 0.9215571284294128
epoch7 train_loss is 0.2718273401260376;train_accuracy is 0.9201499819755554;test_loss is 0.2703363001346588;test_accuracy is 0.9218875169754028
epoch8 train_loss is 0.26651278138160706;train_accuracy is 0.9215333461761475;test_loss is 0.27081072330474854;test_accuracy is 0.9211888909339905
epoch9 train_loss is 0.2612370252609253;train_accuracy is 0.9223999977111816;test_loss is 0.26694610714912415;test_accuracy is 0.922569990158081
相关文章
|
1天前
|
NoSQL 关系型数据库 MySQL
《docker高级篇(大厂进阶):4.Docker网络》包括:是什么、常用基本命令、能干嘛、网络模式、docker平台架构图解
《docker高级篇(大厂进阶):4.Docker网络》包括:是什么、常用基本命令、能干嘛、网络模式、docker平台架构图解
76 56
《docker高级篇(大厂进阶):4.Docker网络》包括:是什么、常用基本命令、能干嘛、网络模式、docker平台架构图解
|
13天前
|
安全 Docker 容器
docker的默认网络模式有哪些
Docker 默认网络模式包括:1) bridge:默认模式,各容器分配独立IP,可通过名称或IP通信;2) host:容器与宿主机共享网络命名空间,性能最优但有安全风险;3) none:容器隔离无网络配置,适用于仅需本地通信的场景。
26 6
|
21天前
|
机器学习/深度学习 自然语言处理 语音技术
Python在深度学习领域的应用,重点讲解了神经网络的基础概念、基本结构、训练过程及优化技巧
本文介绍了Python在深度学习领域的应用,重点讲解了神经网络的基础概念、基本结构、训练过程及优化技巧,并通过TensorFlow和PyTorch等库展示了实现神经网络的具体示例,涵盖图像识别、语音识别等多个应用场景。
44 8
|
1月前
|
Docker 容器
【赵渝强老师】Docker的None网络模式
Docker容器在网络方面实现了逻辑隔离,提供了四种网络模式:bridge、container、host和none。其中,none模式下容器具有独立的网络命名空间,但不包含任何网络配置,仅能通过Local Loopback网卡(localhost或127.0.0.1)进行通信。适用于不希望容器接收任何网络流量或运行无需网络连接的特殊服务。
|
1月前
|
Docker 容器
【赵渝强老师】Docker的Host网络模式
Docker容器在网络环境中是隔离的,可通过配置不同网络模式(如bridge、container、host和none)实现容器间或与宿主机的网络通信。其中,host模式使容器与宿主机共享同一网络命名空间,提高性能但牺牲了网络隔离性。
|
1月前
|
Kubernetes Docker 容器
【赵渝强老师】Docker的Container网络模式
Docker容器在网络环境中彼此隔离,但可通过配置不同网络模式实现容器间通信。其中,container模式使容器共享同一网络命名空间,通过localhost或127.0.0.1互相访问,提高传输效率。本文介绍了container模式的特点及具体示例。
|
1天前
|
SQL 安全 网络安全
网络安全与信息安全:知识分享####
【10月更文挑战第21天】 随着数字化时代的快速发展,网络安全和信息安全已成为个人和企业不可忽视的关键问题。本文将探讨网络安全漏洞、加密技术以及安全意识的重要性,并提供一些实用的建议,帮助读者提高自身的网络安全防护能力。 ####
34 17
|
12天前
|
存储 SQL 安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
随着互联网的普及,网络安全问题日益突出。本文将介绍网络安全的重要性,分析常见的网络安全漏洞及其危害,探讨加密技术在保障网络安全中的作用,并强调提高安全意识的必要性。通过本文的学习,读者将了解网络安全的基本概念和应对策略,提升个人和组织的网络安全防护能力。
|
13天前
|
SQL 安全 网络安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
随着互联网的普及,网络安全问题日益突出。本文将从网络安全漏洞、加密技术和安全意识三个方面进行探讨,旨在提高读者对网络安全的认识和防范能力。通过分析常见的网络安全漏洞,介绍加密技术的基本原理和应用,以及强调安全意识的重要性,帮助读者更好地保护自己的网络信息安全。
36 10
|
14天前
|
SQL 安全 网络安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
在数字化时代,网络安全和信息安全已成为我们生活中不可或缺的一部分。本文将介绍网络安全漏洞、加密技术和安全意识等方面的内容,并提供一些实用的代码示例。通过阅读本文,您将了解到如何保护自己的网络安全,以及如何提高自己的信息安全意识。
43 10