TF之GD:基于tensorflow框架搭建GD算法利用Fashion-MNIST数据集实现多分类预测(92%)

简介: TF之GD:基于tensorflow框架搭建GD算法利用Fashion-MNIST数据集实现多分类预测(92%)

输出结果

image.png



Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.

Extracting data/fashion\train-images-idx3-ubyte.gz

Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.

Extracting data/fashion\train-labels-idx1-ubyte.gz

Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.

Extracting data/fashion\t10k-images-idx3-ubyte.gz

Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.

Extracting data/fashion\t10k-labels-idx1-ubyte.gz

(55000, 784)

(55000, 10)

Epoch: 0,acc: 0.7965

Epoch: 1,acc: 0.8118

Epoch: 2,acc: 0.8743

Epoch: 3,acc: 0.8997

Epoch: 4,acc: 0.9058

Epoch: 5,acc: 0.9083

Epoch: 6,acc: 0.9102

Epoch: 7,acc: 0.9117

Epoch: 8,acc: 0.9137

Epoch: 9,acc: 0.9147

Epoch: 10,acc: 0.9158

Epoch: 11,acc: 0.9166

Epoch: 12,acc: 0.9186

Epoch: 13,acc: 0.9191

Epoch: 14,acc: 0.9187

Epoch: 15,acc: 0.9195

Epoch: 16,acc: 0.9206

Epoch: 17,acc: 0.9207

Epoch: 18,acc: 0.9216

Epoch: 19,acc: 0.9215

Epoch: 20,acc: 0.9218


实现代码

 

#TF之GD:基于tensorflow框架搭建GD算法利用Fashion-MNIST数据集实现多分类预测(92%)

import  tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

fashion = input_data.read_data_sets('data/fashion', one_hot=True)

print(fashion.train.images.shape)

print(fashion.train.labels.shape)

batch_size = 100

batch_num = fashion.train.num_examples // batch_size

#定义X,Y参数

x = tf.placeholder(tf.float32, shape=[None, 784])

y = tf.placeholder(tf.float32, shape=[None, 10])

#定义W,B参数

W = tf.Variable(tf.truncated_normal([784, 10], stddev= 0.1))

b = tf.Variable(tf.zeros([10]) + 0.1)

#预测结果

prediction = tf.nn.softmax(tf.matmul(x, W) + b)

#使用交叉熵计算loss

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=prediction, labels=y))

#定义优化器

train_step = tf.train.GradientDescentOptimizer(0.2).minimize(cross_entropy)

#判断预测结果是否正确

correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))

#计算准确率,将bool值转为float32

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:

   sess.run(tf.global_variables_initializer())

   for epoch in range(21):

       for i in range(batch_num):

           batch_xs, batch_ys = fashion.train.next_batch(batch_size)

           sess.run(train_step, feed_dict={x: batch_xs, y:batch_ys})

       acc = sess.run(accuracy, feed_dict={x:fashion.test.images, y:fashion.test.labels})

       print('Epoch: '+str(epoch)+',acc: '+str(acc))



相关文章
|
4月前
|
数据采集 TensorFlow 算法框架/工具
【大作业-03】手把手教你用tensorflow2.3训练自己的分类数据集
本教程详细介绍了如何使用TensorFlow 2.3训练自定义图像分类数据集,涵盖数据集收集、整理、划分及模型训练与测试全过程。提供完整代码示例及图形界面应用开发指导,适合初学者快速上手。[教程链接](https://www.bilibili.com/video/BV1rX4y1A7N8/),配套视频更易理解。
102 0
【大作业-03】手把手教你用tensorflow2.3训练自己的分类数据集
|
1天前
|
机器学习/深度学习 PyTorch TensorFlow
深度学习工具和框架详细指南:PyTorch、TensorFlow、Keras
在深度学习的世界中,PyTorch、TensorFlow和Keras是最受欢迎的工具和框架,它们为研究者和开发者提供了强大且易于使用的接口。在本文中,我们将深入探索这三个框架,涵盖如何用它们实现经典深度学习模型,并通过代码实例详细讲解这些工具的使用方法。
|
3月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
利用Python和TensorFlow构建简单神经网络进行图像分类
利用Python和TensorFlow构建简单神经网络进行图像分类
91 3
|
6月前
|
机器学习/深度学习 人工智能 PyTorch
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
107 1
|
6月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
【Tensorflow+Keras】keras实现条件生成对抗网络DCGAN--以Minis和fashion_mnist数据集为例
如何使用TensorFlow和Keras实现条件生成对抗网络(CGAN)并以MNIST和Fashion MNIST数据集为例进行演示。
79 3
|
6月前
|
测试技术 数据库
探索JSF单元测试秘籍!如何让您的应用更稳固、更高效?揭秘成功背后的测试之道!
【8月更文挑战第31天】在 JavaServer Faces(JSF)应用开发中,确保代码质量和可维护性至关重要。本文详细介绍了如何通过单元测试实现这一目标。首先,阐述了单元测试的重要性及其对应用稳定性的影响;其次,提出了提高 JSF 应用可测试性的设计建议,如避免直接访问外部资源和使用依赖注入;最后,通过一个具体的 `UserBean` 示例,展示了如何利用 JUnit 和 Mockito 框架编写有效的单元测试。通过这些方法,不仅能够确保代码质量,还能提高开发效率和降低维护成本。
75 0
|
6月前
|
UED 开发者
哇塞!Uno Platform 数据绑定超全技巧大揭秘!从基础绑定到高级转换,优化性能让你的开发如虎添翼
【8月更文挑战第31天】在开发过程中,数据绑定是连接数据模型与用户界面的关键环节,可实现数据自动更新。Uno Platform 提供了简洁高效的数据绑定方式,使属性变化时 UI 自动同步更新。通过示例展示了基本绑定方法及使用 `Converter` 转换数据的高级技巧,如将年龄转换为格式化字符串。此外,还可利用 `BindingMode.OneTime` 提升性能。掌握这些技巧能显著提高开发效率并优化用户体验。
97 0
|
6月前
|
UED 存储 数据管理
深度解析 Uno Platform 离线状态处理技巧:从网络检测到本地存储同步,全方位提升跨平台应用在无网环境下的用户体验与数据管理策略
【8月更文挑战第31天】处理离线状态下的用户体验是现代应用开发的关键。本文通过在线笔记应用案例,介绍如何使用 Uno Platform 优雅地应对离线状态。首先,利用 `NetworkInformation` 类检测网络状态;其次,使用 SQLite 实现离线存储;然后,在网络恢复时同步数据;最后,通过 UI 反馈提升用户体验。
149 0
|
6月前
|
机器学习/深度学习 TensorFlow 数据处理
分布式训练在TensorFlow中的全面应用指南:掌握多机多卡配置与实践技巧,让大规模数据集训练变得轻而易举,大幅提升模型训练效率与性能
【8月更文挑战第31天】本文详细介绍了如何在Tensorflow中实现多机多卡的分布式训练,涵盖环境配置、模型定义、数据处理及训练执行等关键环节。通过具体示例代码,展示了使用`MultiWorkerMirroredStrategy`进行分布式训练的过程,帮助读者更好地应对大规模数据集与复杂模型带来的挑战,提升训练效率。
161 0
|
2月前
|
机器学习/深度学习 人工智能 算法
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
356 55

热门文章

最新文章