TF之GD:基于tensorflow框架搭建GD算法利用Fashion-MNIST数据集实现多分类预测(92%)

简介: TF之GD:基于tensorflow框架搭建GD算法利用Fashion-MNIST数据集实现多分类预测(92%)

输出结果

image.png



Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.

Extracting data/fashion\train-images-idx3-ubyte.gz

Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.

Extracting data/fashion\train-labels-idx1-ubyte.gz

Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.

Extracting data/fashion\t10k-images-idx3-ubyte.gz

Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.

Extracting data/fashion\t10k-labels-idx1-ubyte.gz

(55000, 784)

(55000, 10)

Epoch: 0,acc: 0.7965

Epoch: 1,acc: 0.8118

Epoch: 2,acc: 0.8743

Epoch: 3,acc: 0.8997

Epoch: 4,acc: 0.9058

Epoch: 5,acc: 0.9083

Epoch: 6,acc: 0.9102

Epoch: 7,acc: 0.9117

Epoch: 8,acc: 0.9137

Epoch: 9,acc: 0.9147

Epoch: 10,acc: 0.9158

Epoch: 11,acc: 0.9166

Epoch: 12,acc: 0.9186

Epoch: 13,acc: 0.9191

Epoch: 14,acc: 0.9187

Epoch: 15,acc: 0.9195

Epoch: 16,acc: 0.9206

Epoch: 17,acc: 0.9207

Epoch: 18,acc: 0.9216

Epoch: 19,acc: 0.9215

Epoch: 20,acc: 0.9218


实现代码

 

#TF之GD:基于tensorflow框架搭建GD算法利用Fashion-MNIST数据集实现多分类预测(92%)

import  tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

fashion = input_data.read_data_sets('data/fashion', one_hot=True)

print(fashion.train.images.shape)

print(fashion.train.labels.shape)

batch_size = 100

batch_num = fashion.train.num_examples // batch_size

#定义X,Y参数

x = tf.placeholder(tf.float32, shape=[None, 784])

y = tf.placeholder(tf.float32, shape=[None, 10])

#定义W,B参数

W = tf.Variable(tf.truncated_normal([784, 10], stddev= 0.1))

b = tf.Variable(tf.zeros([10]) + 0.1)

#预测结果

prediction = tf.nn.softmax(tf.matmul(x, W) + b)

#使用交叉熵计算loss

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=prediction, labels=y))

#定义优化器

train_step = tf.train.GradientDescentOptimizer(0.2).minimize(cross_entropy)

#判断预测结果是否正确

correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))

#计算准确率,将bool值转为float32

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:

   sess.run(tf.global_variables_initializer())

   for epoch in range(21):

       for i in range(batch_num):

           batch_xs, batch_ys = fashion.train.next_batch(batch_size)

           sess.run(train_step, feed_dict={x: batch_xs, y:batch_ys})

       acc = sess.run(accuracy, feed_dict={x:fashion.test.images, y:fashion.test.labels})

       print('Epoch: '+str(epoch)+',acc: '+str(acc))



相关文章
|
5月前
|
机器学习/深度学习 Dragonfly 人工智能
基于蜻蜓算法优化支持向量机(DA-SVM)的数据多特征分类预测研究(Matlab代码实现)
基于蜻蜓算法优化支持向量机(DA-SVM)的数据多特征分类预测研究(Matlab代码实现)
141 1
|
4月前
|
机器学习/深度学习 算法 调度
14种智能算法优化BP神经网络(14种方法)实现数据预测分类研究(Matlab代码实现)
14种智能算法优化BP神经网络(14种方法)实现数据预测分类研究(Matlab代码实现)
426 0
|
3月前
|
机器学习/深度学习 PyTorch TensorFlow
66_框架选择:PyTorch vs TensorFlow
在2025年的大语言模型(LLM)开发领域,框架选择已成为项目成功的关键决定因素。随着模型规模的不断扩大和应用场景的日益复杂,选择一个既适合研究探索又能支持高效部署的框架变得尤为重要。PyTorch和TensorFlow作为目前市场上最主流的两大深度学习框架,各自拥有独特的优势和生态系统,也因此成为开发者面临的经典选择难题。
|
4月前
|
机器学习/深度学习 算法 PyTorch
【Pytorch框架搭建神经网络】基于DQN算法、优先级采样的DQN算法、DQN + 人工势场的避障控制研究(Python代码实现)
【Pytorch框架搭建神经网络】基于DQN算法、优先级采样的DQN算法、DQN + 人工势场的避障控制研究(Python代码实现)
124 1
|
4月前
|
机器学习/深度学习 算法 PyTorch
【DQN实现避障控制】使用Pytorch框架搭建神经网络,基于DQN算法、优先级采样的DQN算法、DQN + 人工势场实现避障控制研究(Matlab、Python实现)
【DQN实现避障控制】使用Pytorch框架搭建神经网络,基于DQN算法、优先级采样的DQN算法、DQN + 人工势场实现避障控制研究(Matlab、Python实现)
207 0
|
5月前
|
机器学习/深度学习 传感器 数据采集
【23年新算法】基于鱼鹰算法OOA-Transformer-BiLSTM多特征分类预测附Matlab代码 (多输入单输出)(Matlab代码实现)
【23年新算法】基于鱼鹰算法OOA-Transformer-BiLSTM多特征分类预测附Matlab代码 (多输入单输出)(Matlab代码实现)
386 0
|
6月前
|
机器学习/深度学习 人工智能 算法
AP聚类算法实现三维数据点分类
AP聚类算法实现三维数据点分类
212 0
|
机器学习/深度学习 人工智能 算法
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
1097 55
|
机器学习/深度学习 数据采集 数据可视化
TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤
本文介绍了 TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤,包括数据准备、模型定义、损失函数与优化器选择、模型训练与评估、模型保存与部署,并展示了构建全连接神经网络的具体示例。此外,还探讨了 TensorFlow 的高级特性,如自动微分、模型可视化和分布式训练,以及其在未来的发展前景。
1051 5
|
机器学习/深度学习 人工智能 TensorFlow
基于TensorFlow的深度学习模型训练与优化实战
基于TensorFlow的深度学习模型训练与优化实战
595 3