TF之GD:基于tensorflow框架搭建GD算法利用Fashion-MNIST数据集实现多分类预测(92%)

简介: TF之GD:基于tensorflow框架搭建GD算法利用Fashion-MNIST数据集实现多分类预测(92%)

输出结果

image.png



Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.

Extracting data/fashion\train-images-idx3-ubyte.gz

Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.

Extracting data/fashion\train-labels-idx1-ubyte.gz

Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.

Extracting data/fashion\t10k-images-idx3-ubyte.gz

Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.

Extracting data/fashion\t10k-labels-idx1-ubyte.gz

(55000, 784)

(55000, 10)

Epoch: 0,acc: 0.7965

Epoch: 1,acc: 0.8118

Epoch: 2,acc: 0.8743

Epoch: 3,acc: 0.8997

Epoch: 4,acc: 0.9058

Epoch: 5,acc: 0.9083

Epoch: 6,acc: 0.9102

Epoch: 7,acc: 0.9117

Epoch: 8,acc: 0.9137

Epoch: 9,acc: 0.9147

Epoch: 10,acc: 0.9158

Epoch: 11,acc: 0.9166

Epoch: 12,acc: 0.9186

Epoch: 13,acc: 0.9191

Epoch: 14,acc: 0.9187

Epoch: 15,acc: 0.9195

Epoch: 16,acc: 0.9206

Epoch: 17,acc: 0.9207

Epoch: 18,acc: 0.9216

Epoch: 19,acc: 0.9215

Epoch: 20,acc: 0.9218


实现代码

 

#TF之GD:基于tensorflow框架搭建GD算法利用Fashion-MNIST数据集实现多分类预测(92%)

import  tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

fashion = input_data.read_data_sets('data/fashion', one_hot=True)

print(fashion.train.images.shape)

print(fashion.train.labels.shape)

batch_size = 100

batch_num = fashion.train.num_examples // batch_size

#定义X,Y参数

x = tf.placeholder(tf.float32, shape=[None, 784])

y = tf.placeholder(tf.float32, shape=[None, 10])

#定义W,B参数

W = tf.Variable(tf.truncated_normal([784, 10], stddev= 0.1))

b = tf.Variable(tf.zeros([10]) + 0.1)

#预测结果

prediction = tf.nn.softmax(tf.matmul(x, W) + b)

#使用交叉熵计算loss

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=prediction, labels=y))

#定义优化器

train_step = tf.train.GradientDescentOptimizer(0.2).minimize(cross_entropy)

#判断预测结果是否正确

correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))

#计算准确率,将bool值转为float32

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:

   sess.run(tf.global_variables_initializer())

   for epoch in range(21):

       for i in range(batch_num):

           batch_xs, batch_ys = fashion.train.next_batch(batch_size)

           sess.run(train_step, feed_dict={x: batch_xs, y:batch_ys})

       acc = sess.run(accuracy, feed_dict={x:fashion.test.images, y:fashion.test.labels})

       print('Epoch: '+str(epoch)+',acc: '+str(acc))



相关文章
|
5天前
|
机器学习/深度学习 人工智能 算法
海洋生物识别系统+图像识别+Python+人工智能课设+深度学习+卷积神经网络算法+TensorFlow
海洋生物识别系统。以Python作为主要编程语言,通过TensorFlow搭建ResNet50卷积神经网络算法,通过对22种常见的海洋生物('蛤蜊', '珊瑚', '螃蟹', '海豚', '鳗鱼', '水母', '龙虾', '海蛞蝓', '章鱼', '水獭', '企鹅', '河豚', '魔鬼鱼', '海胆', '海马', '海豹', '鲨鱼', '虾', '鱿鱼', '海星', '海龟', '鲸鱼')数据集进行训练,得到一个识别精度较高的模型文件,然后使用Django开发一个Web网页平台操作界面,实现用户上传一张海洋生物图片识别其名称。
84 7
海洋生物识别系统+图像识别+Python+人工智能课设+深度学习+卷积神经网络算法+TensorFlow
|
5天前
|
机器学习/深度学习 人工智能 算法
【昆虫识别系统】图像识别Python+卷积神经网络算法+人工智能+深度学习+机器学习+TensorFlow+ResNet50
昆虫识别系统,使用Python作为主要开发语言。通过TensorFlow搭建ResNet50卷积神经网络算法(CNN)模型。通过对10种常见的昆虫图片数据集('蜜蜂', '甲虫', '蝴蝶', '蝉', '蜻蜓', '蚱蜢', '蛾', '蝎子', '蜗牛', '蜘蛛')进行训练,得到一个识别精度较高的H5格式模型文件,然后使用Django搭建Web网页端可视化操作界面,实现用户上传一张昆虫图片识别其名称。
118 7
【昆虫识别系统】图像识别Python+卷积神经网络算法+人工智能+深度学习+机器学习+TensorFlow+ResNet50
|
6天前
|
机器学习/深度学习 人工智能 算法
【球类识别系统】图像识别Python+卷积神经网络算法+人工智能+深度学习+TensorFlow
球类识别系统,本系统使用Python作为主要编程语言,基于TensorFlow搭建ResNet50卷积神经网络算法模型,通过收集 '美式足球', '棒球', '篮球', '台球', '保龄球', '板球', '足球', '高尔夫球', '曲棍球', '冰球', '橄榄球', '羽毛球', '乒乓球', '网球', '排球'等15种常见的球类图像作为数据集,然后进行训练,最终得到一个识别精度较高的模型文件。再使用Django开发Web网页端可视化界面平台,实现用户上传一张球类图片识别其名称。
101 7
【球类识别系统】图像识别Python+卷积神经网络算法+人工智能+深度学习+TensorFlow
|
6天前
|
机器学习/深度学习 算法
基于鲸鱼优化的knn分类特征选择算法matlab仿真
**基于WOA的KNN特征选择算法摘要** 该研究提出了一种融合鲸鱼优化算法(WOA)与K近邻(KNN)分类器的特征选择方法,旨在提升KNN的分类精度。在MATLAB2022a中实现,WOA负责优化特征子集,通过模拟鲸鱼捕食行为的螺旋式和包围策略搜索最佳特征。KNN则用于评估特征子集的性能。算法流程包括WOA参数初始化、特征二进制编码、适应度函数定义(以分类准确率为基准)、WOA迭代搜索及最优解输出。该方法有效地结合了启发式搜索与机器学习,优化特征选择,提高分类性能。
|
15天前
|
算法
【经典LeetCode算法题目专栏分类】【第10期】排序问题、股票问题与TOP K问题:翻转对、买卖股票最佳时机、数组中第K个最大/最小元素
【经典LeetCode算法题目专栏分类】【第10期】排序问题、股票问题与TOP K问题:翻转对、买卖股票最佳时机、数组中第K个最大/最小元素
|
1天前
|
存储 算法 安全
加密算法概述:分类与常见算法
加密算法概述:分类与常见算法
|
8天前
|
算法
基于蝗虫优化的KNN分类特征选择算法的matlab仿真
摘要: - 功能:使用蝗虫优化算法增强KNN分类器的特征选择,提高分类准确性 - 软件版本:MATLAB2022a - 核心算法:通过GOA选择KNN的最优特征以改善性能 - 算法原理: - KNN基于最近邻原则进行分类 - 特征选择能去除冗余,提高效率 - GOA模仿蝗虫行为寻找最佳特征子集,以最大化KNN的验证集准确率 - 运行流程:初始化、评估、更新,直到达到停止标准,输出最佳特征组合
|
9天前
|
机器学习/深度学习 算法 数据挖掘
【机器学习】Voting集成学习算法:分类任务中的新利器
【机器学习】Voting集成学习算法:分类任务中的新利器
15 0
|
9天前
|
机器学习/深度学习 算法 数据可视化
【机器学习】分类与预测算法的评价与优化
【机器学习】分类与预测算法的评价与优化
20 0
|
15天前
|
算法
【经典LeetCode算法题目专栏分类】【第11期】递归问题:字母大小写全排列、括号生成
【经典LeetCode算法题目专栏分类】【第11期】递归问题:字母大小写全排列、括号生成