TensorFlow高层次机器学习API (tf.contrib.learn)

简介:

TensorFlow高层次机器学习API (tf.contrib.learn)

1.tf.contrib.learn.datasets.base.load_csv_with_header 加载csv格式数据

2.tf.contrib.learn.DNNClassifier 建立DNN模型(classifier)

3.classifer.fit 训练模型

4.classifier.evaluate 评价模型

5.classifier.predict 预测新样本

完整代码:

复制代码
复制代码
 1 from __future__ import absolute_import
 2 from __future__ import division  3 from __future__ import print_function  4  5 import tensorflow as tf  6 import numpy as np  7  8 # Data sets  9 IRIS_TRAINING = "iris_training.csv" 10 IRIS_TEST = "iris_test.csv" 11 12 # Load datasets. 13 training_set = tf.contrib.learn.datasets.base.load_csv_with_header( 14 filename=IRIS_TRAINING, 15 target_dtype=np.int, 16 features_dtype=np.float32) 17 test_set = tf.contrib.learn.datasets.base.load_csv_with_header( 18 filename=IRIS_TEST, 19 target_dtype=np.int, 20 features_dtype=np.float32) 21 22 # Specify that all features have real-value data 23 feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)] 24 25 # Build 3 layer DNN with 10, 20, 10 units respectively. 26 classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns, 27 hidden_units=[10, 20, 10], 28 n_classes=3, 29 model_dir="/tmp/iris_model") 30 31 # Fit model. 32 classifier.fit(x=training_set.data, 33 y=training_set.target, 34 steps=2000) 35 36 # Evaluate accuracy. 37 accuracy_score = classifier.evaluate(x=test_set.data, 38 y=test_set.target)["accuracy"] 39 print('Accuracy: {0:f}'.format(accuracy_score)) 40 41 # Classify two new flower samples. 42 new_samples = np.array( 43 [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float) 44 y = list(classifier.predict(new_samples, as_iterable=True)) 45 print('Predictions: {}'.format(str(y)))
复制代码
复制代码

 结果:

Accuracy:0.966667


















本文转自张昺华-sky博客园博客,原文链接:http://www.cnblogs.com/bonelee/p/7903436.html,如需转载请自行联系原作者




相关文章
|
6月前
|
人工智能 物联网 API
又又又上新啦!魔搭免费模型推理API支持DeepSeek-R1,Qwen2.5-VL,Flux.1 dev及Lora等
又又又上新啦!魔搭免费模型推理API支持DeepSeek-R1,Qwen2.5-VL,Flux.1 dev及Lora等
334 7
|
9月前
|
人工智能 API 语音技术
开发者福利,魔搭推出免费模型推理API,注册就送每日2000次调用!
今天,魔搭社区开放了免费的开源模型推理API,仅需使用魔搭的SDK Token,就可以通过简单的API请求探索各种强大的开源模型的使用。
929 9
|
10月前
|
机器学习/深度学习 人工智能 TensorFlow
基于TensorFlow的深度学习模型训练与优化实战
基于TensorFlow的深度学习模型训练与优化实战
424 3
|
10月前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
362 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
10月前
|
机器学习/深度学习 TensorFlow API
机器学习实战:TensorFlow在图像识别中的应用探索
【10月更文挑战第28天】随着深度学习技术的发展,图像识别取得了显著进步。TensorFlow作为Google开源的机器学习框架,凭借其强大的功能和灵活的API,在图像识别任务中广泛应用。本文通过实战案例,探讨TensorFlow在图像识别中的优势与挑战,展示如何使用TensorFlow构建和训练卷积神经网络(CNN),并评估模型的性能。尽管面临学习曲线和资源消耗等挑战,TensorFlow仍展现出广阔的应用前景。
287 5
|
12月前
|
机器学习/深度学习 算法 TensorFlow
交通标志识别系统Python+卷积神经网络算法+深度学习人工智能+TensorFlow模型训练+计算机课设项目+Django网页界面
交通标志识别系统。本系统使用Python作为主要编程语言,在交通标志图像识别功能实现中,基于TensorFlow搭建卷积神经网络算法模型,通过对收集到的58种常见的交通标志图像作为数据集,进行迭代训练最后得到一个识别精度较高的模型文件,然后保存为本地的h5格式文件。再使用Django开发Web网页端操作界面,实现用户上传一张交通标志图片,识别其名称。
447 6
交通标志识别系统Python+卷积神经网络算法+深度学习人工智能+TensorFlow模型训练+计算机课设项目+Django网页界面
|
11月前
|
机器学习/深度学习 人工智能 算法
【玉米病害识别】Python+卷积神经网络算法+人工智能+深度学习+计算机课设项目+TensorFlow+模型训练
玉米病害识别系统,本系统使用Python作为主要开发语言,通过收集了8种常见的玉米叶部病害图片数据集('矮花叶病', '健康', '灰斑病一般', '灰斑病严重', '锈病一般', '锈病严重', '叶斑病一般', '叶斑病严重'),然后基于TensorFlow搭建卷积神经网络算法模型,通过对数据集进行多轮迭代训练,最后得到一个识别精度较高的模型文件。再使用Django搭建Web网页操作平台,实现用户上传一张玉米病害图片识别其名称。
177 0
【玉米病害识别】Python+卷积神经网络算法+人工智能+深度学习+计算机课设项目+TensorFlow+模型训练
|
11月前
|
机器学习/深度学习 算法 API
机器学习入门(五):KNN概述 | K 近邻算法 API,K值选择问题
机器学习入门(五):KNN概述 | K 近邻算法 API,K值选择问题
|
11月前
|
机器学习/深度学习 算法 数据可视化
【机器学习】决策树------迅速了解其基本思想,Sklearn的决策树API及构建决策树的步骤!!!
【机器学习】决策树------迅速了解其基本思想,Sklearn的决策树API及构建决策树的步骤!!!
|
机器学习/深度学习 存储 搜索推荐
利用机器学习算法改善电商推荐系统的效率
电商行业日益竞争激烈,提升用户体验成为关键。本文将探讨如何利用机器学习算法优化电商推荐系统,通过分析用户行为数据和商品信息,实现个性化推荐,从而提高推荐效率和准确性。
428 14

热门文章

最新文章