鸢尾花数据集分类问题(2)

简介: 鸢尾花数据集分类问题

鸢尾花数据集分类问题(1)https://developer.aliyun.com/article/1540968

2.数据集乱序

tf.random.set_seed(116)     # 设计随机种子,以生成同样的标签序列,使x_data和y_data一一对应
x_data = tf.random.shuffle(x_data,116)   # 数据集乱序
y_data = tf.random.shuffle(y_data,116)
print(x_data)
print(y_data)
tf.Tensor(
[[5.6 3.  4.5 1.5]
 [4.9 2.4 3.3 1. ]
 [6.7 3.  5.2 2.3]
 [4.6 3.2 1.4 0.2]
 [5.1 3.5 1.4 0.2]
 [5.4 3.7 1.5 0.2]
 [6.8 3.2 5.9 2.3]
 [5.3 3.7 1.5 0.2]
 [4.6 3.4 1.4 0.3]
 [4.8 3.  1.4 0.1]
 [6.5 3.  5.2 2. ]
 [4.9 2.5 4.5 1.7]
 [6.8 3.  5.5 2.1]
 [4.8 3.  1.4 0.3]
 [5.7 2.5 5.  2. ]
 [7.9 3.8 6.4 2. ]
 [5.  3.4 1.6 0.4]
 [5.9 3.2 4.8 1.8]
 [6.  2.2 5.  1.5]
 [7.7 2.8 6.7 2. ]
 [5.5 2.3 4.  1.3]
 [4.9 3.  1.4 0.2]
 [5.8 2.8 5.1 2.4]
 [5.4 3.4 1.5 0.4]
 [6.4 2.8 5.6 2.2]
 [5.8 2.6 4.  1.2]
 [5.6 2.7 4.2 1.3]
 [4.9 3.1 1.5 0.2]
 [6.7 3.1 4.7 1.5]
 [5.5 4.2 1.4 0.2]
 [7.2 3.  5.8 1.6]
 [6.3 2.9 5.6 1.8]
 [6.3 2.5 5.  1.9]
 [6.9 3.1 5.1 2.3]
 [6.4 3.2 5.3 2.3]
 [6.7 3.1 4.4 1.4]
 [5.  3.2 1.2 0.2]
 [6.3 3.3 6.  2.5]
 [6.  2.2 4.  1. ]
 [6.3 2.5 4.9 1.5]
 [6.5 3.  5.8 2.2]
 [7.2 3.6 6.1 2.5]
 [6.7 3.  5.  1.7]
 [6.  3.  4.8 1.8]
 [5.7 2.8 4.1 1.3]
 [5.7 4.4 1.5 0.4]
 [5.4 3.9 1.7 0.4]
 [5.6 2.5 3.9 1.1]
 [7.2 3.2 6.  1.8]
 [4.7 3.2 1.6 0.2]
 [5.1 3.4 1.5 0.2]
 [5.  3.6 1.4 0.2]
 [5.5 2.4 3.8 1.1]
 [5.7 2.8 4.5 1.3]
 [5.2 3.4 1.4 0.2]
 [7.6 3.  6.6 2.1]
 [6.3 3.3 4.7 1.6]
 [5.  2.3 3.3 1. ]
 [5.5 2.6 4.4 1.2]
 [4.8 3.4 1.9 0.2]
 [5.5 3.5 1.3 0.2]
 [5.1 3.5 1.4 0.3]
 [5.8 2.7 5.1 1.9]
 [4.6 3.6 1.  0.2]
 [5.2 2.7 3.9 1.4]
 [5.7 3.  4.2 1.2]
 [6.3 2.7 4.9 1.8]
 [6.  2.9 4.5 1.5]
 [6.7 3.3 5.7 2.5]
 [6.7 3.3 5.7 2.1]
 [4.6 3.1 1.5 0.2]
 [6.9 3.1 5.4 2.1]
 [6.2 2.8 4.8 1.8]
 [5.5 2.5 4.  1.3]
 [6.1 2.9 4.7 1.4]
 [5.7 3.8 1.7 0.3]
 [5.6 2.8 4.9 2. ]
 [6.6 3.  4.4 1.4]
 [4.7 3.2 1.3 0.2]
 [6.6 2.9 4.6 1.3]
 [6.1 3.  4.9 1.8]
 [6.3 2.8 5.1 1.5]
 [6.5 2.8 4.6 1.5]
 [5.1 3.7 1.5 0.4]
 [7.4 2.8 6.1 1.9]
 [4.9 3.1 1.5 0.1]
 [4.8 3.1 1.6 0.2]
 [5.5 2.4 3.7 1. ]
 [5.2 4.1 1.5 0.1]
 [5.4 3.  4.5 1.5]
 [6.7 3.1 5.6 2.4]
 [5.1 3.8 1.5 0.3]
 [4.4 3.2 1.3 0.2]
 [7.1 3.  5.9 2.1]
 [5.  3.3 1.4 0.2]
 [5.  3.5 1.6 0.6]
 [6.4 3.2 4.5 1.5]
 [6.1 3.  4.6 1.4]
 [6.3 2.3 4.4 1.3]
 [6.5 3.  5.5 1.8]
 [7.7 3.8 6.7 2.2]
 [6.3 3.4 5.6 2.4]
 [6.4 2.9 4.3 1.3]
 [5.  3.  1.6 0.2]
 [5.8 2.7 4.1 1. ]
 [5.8 2.7 3.9 1.2]
 [4.8 3.4 1.6 0.2]
 [6.9 3.1 4.9 1.5]
 [5.8 4.  1.2 0.2]
 [6.  3.4 4.5 1.6]
 [5.4 3.9 1.3 0.4]
 [6.1 2.8 4.  1.3]
 [5.4 3.4 1.7 0.2]
 [7.7 3.  6.1 2.3]
 [5.1 2.5 3.  1.1]
 [5.6 3.  4.1 1.3]
 [6.1 2.8 4.7 1.2]
 [6.2 3.4 5.4 2.3]
 [6.4 2.8 5.6 2.1]
 [5.7 2.6 3.5 1. ]
 [5.1 3.8 1.9 0.4]
 [5.2 3.5 1.5 0.2]
 [6.9 3.2 5.7 2.3]
 [4.5 2.3 1.3 0.3]
 [4.4 3.  1.3 0.2]
 [6.5 3.2 5.1 2. ]
 [4.3 3.  1.1 0.1]
 [5.7 2.9 4.2 1.3]
 [5.8 2.7 5.1 1.9]
 [6.4 3.1 5.5 1.8]
 [6.2 2.9 4.3 1.3]
 [5.6 2.9 3.6 1.3]
 [4.9 3.6 1.4 0.1]
 [6.2 2.2 4.5 1.5]
 [4.4 2.9 1.4 0.2]
 [5.  3.5 1.3 0.3]
 [5.9 3.  5.1 1.8]
 [6.1 2.6 5.6 1.4]
 [5.1 3.8 1.6 0.2]
 [6.8 2.8 4.8 1.4]
 [5.9 3.  4.2 1.5]
 [6.7 2.5 5.8 1.8]
 [7.  3.2 4.7 1.4]
 [7.7 2.6 6.9 2.3]
 [6.  2.7 5.1 1.6]
 [5.1 3.3 1.7 0.5]
 [5.  2.  3.5 1. ]
 [6.4 2.7 5.3 1.9]
 [5.  3.4 1.5 0.2]
 [7.3 2.9 6.3 1.8]], shape=(150, 4), dtype=float64)
tf.Tensor(
[1 1 2 0 0 0 2 0 0 0 2 2 2 0 2 2 0 1 2 2 1 0 2 0 2 1 1 0 1 0 2 2 2 2 2 1 0
 2 1 1 2 2 1 2 1 0 0 1 2 0 0 0 1 1 0 2 1 1 1 0 0 0 2 0 1 1 2 1 2 2 0 2 2 1
 1 0 2 1 0 1 2 2 1 0 2 0 0 1 0 1 2 0 0 2 0 0 1 1 1 2 2 2 1 0 1 1 0 1 0 1 0
 1 0 2 1 1 1 2 2 1 0 0 2 0 0 2 0 1 2 2 1 1 0 1 0 0 2 2 0 1 1 2 1 2 1 0 1 2
 0 2], shape=(150,), dtype=int32)

3.划分训练集/测试集

x_train = x_data[:-30]
y_train = y_data[:-30]
x_test = x_data[-30:]
y_test = y_data[-30:]
print(y_train)
tf.Tensor(
[1 1 2 0 0 0 2 0 0 0 2 2 2 0 2 2 0 1 2 2 1 0 2 0 2 1 1 0 1 0 2 2 2 2 2 1 0
 2 1 1 2 2 1 2 1 0 0 1 2 0 0 0 1 1 0 2 1 1 1 0 0 0 2 0 1 1 2 1 2 2 0 2 2 1
 1 0 2 1 0 1 2 2 1 0 2 0 0 1 0 1 2 0 0 2 0 0 1 1 1 2 2 2 1 0 1 1 0 1 0 1 0
 1 0 2 1 1 1 2 2 1], shape=(120,), dtype=int32)

4.配对成[特征值,标签]对

划分batch,之后每次喂入一个batch训练

train_db = tf.data.Dataset.from_tensor_slices((x_train,y_train)).batch(32)    # 打包成30个batch
test_db = tf.data.Dataset.from_tensor_slices((x_test,y_test)).batch(32)
print(train_db)
print(test_db)
<BatchDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float64, name=None), TensorSpec(shape=(None,), dtype=tf.int32, name=None))>
<BatchDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float64, name=None), TensorSpec(shape=(None,), dtype=tf.int32, name=None))>
list(train_db.as_numpy_iterator())
[(array([[5.6, 3. , 4.5, 1.5],
         [4.9, 2.4, 3.3, 1. ],
         [6.7, 3. , 5.2, 2.3],
         [4.6, 3.2, 1.4, 0.2],
         [5.1, 3.5, 1.4, 0.2],
         [5.4, 3.7, 1.5, 0.2],
         [6.8, 3.2, 5.9, 2.3],
         [5.3, 3.7, 1.5, 0.2],
         [4.6, 3.4, 1.4, 0.3],
         [4.8, 3. , 1.4, 0.1],
         [6.5, 3. , 5.2, 2. ],
         [4.9, 2.5, 4.5, 1.7],
         [6.8, 3. , 5.5, 2.1],
         [4.8, 3. , 1.4, 0.3],
         [5.7, 2.5, 5. , 2. ],
         [7.9, 3.8, 6.4, 2. ],
         [5. , 3.4, 1.6, 0.4],
         [5.9, 3.2, 4.8, 1.8],
         [6. , 2.2, 5. , 1.5],
         [7.7, 2.8, 6.7, 2. ],
         [5.5, 2.3, 4. , 1.3],
         [4.9, 3. , 1.4, 0.2],
         [5.8, 2.8, 5.1, 2.4],
         [5.4, 3.4, 1.5, 0.4],
         [6.4, 2.8, 5.6, 2.2],
         [5.8, 2.6, 4. , 1.2],
         [5.6, 2.7, 4.2, 1.3],
         [4.9, 3.1, 1.5, 0.2],
         [6.7, 3.1, 4.7, 1.5],
         [5.5, 4.2, 1.4, 0.2],
         [7.2, 3. , 5.8, 1.6],
         [6.3, 2.9, 5.6, 1.8]]),
  array([1, 1, 2, 0, 0, 0, 2, 0, 0, 0, 2, 2, 2, 0, 2, 2, 0, 1, 2, 2, 1, 0,
         2, 0, 2, 1, 1, 0, 1, 0, 2, 2])),
 (array([[6.3, 2.5, 5. , 1.9],
         [6.9, 3.1, 5.1, 2.3],
         [6.4, 3.2, 5.3, 2.3],
         [6.7, 3.1, 4.4, 1.4],
         [5. , 3.2, 1.2, 0.2],
         [6.3, 3.3, 6. , 2.5],
         [6. , 2.2, 4. , 1. ],
         [6.3, 2.5, 4.9, 1.5],
         [6.5, 3. , 5.8, 2.2],
         [7.2, 3.6, 6.1, 2.5],
         [6.7, 3. , 5. , 1.7],
         [6. , 3. , 4.8, 1.8],
         [5.7, 2.8, 4.1, 1.3],
         [5.7, 4.4, 1.5, 0.4],
         [5.4, 3.9, 1.7, 0.4],
         [5.6, 2.5, 3.9, 1.1],
         [7.2, 3.2, 6. , 1.8],
         [4.7, 3.2, 1.6, 0.2],
         [5.1, 3.4, 1.5, 0.2],
         [5. , 3.6, 1.4, 0.2],
         [5.5, 2.4, 3.8, 1.1],
         [5.7, 2.8, 4.5, 1.3],
         [5.2, 3.4, 1.4, 0.2],
         [7.6, 3. , 6.6, 2.1],
         [6.3, 3.3, 4.7, 1.6],
         [5. , 2.3, 3.3, 1. ],
         [5.5, 2.6, 4.4, 1.2],
         [4.8, 3.4, 1.9, 0.2],
         [5.5, 3.5, 1.3, 0.2],
         [5.1, 3.5, 1.4, 0.3],
         [5.8, 2.7, 5.1, 1.9],
         [4.6, 3.6, 1. , 0.2]]),
  array([2, 2, 2, 1, 0, 2, 1, 1, 2, 2, 1, 2, 1, 0, 0, 1, 2, 0, 0, 0, 1, 1,
         0, 2, 1, 1, 1, 0, 0, 0, 2, 0])),
 (array([[5.2, 2.7, 3.9, 1.4],
         [5.7, 3. , 4.2, 1.2],
         [6.3, 2.7, 4.9, 1.8],
         [6. , 2.9, 4.5, 1.5],
         [6.7, 3.3, 5.7, 2.5],
         [6.7, 3.3, 5.7, 2.1],
         [4.6, 3.1, 1.5, 0.2],
         [6.9, 3.1, 5.4, 2.1],
         [6.2, 2.8, 4.8, 1.8],
         [5.5, 2.5, 4. , 1.3],
         [6.1, 2.9, 4.7, 1.4],
         [5.7, 3.8, 1.7, 0.3],
         [5.6, 2.8, 4.9, 2. ],
         [6.6, 3. , 4.4, 1.4],
         [4.7, 3.2, 1.3, 0.2],
         [6.6, 2.9, 4.6, 1.3],
         [6.1, 3. , 4.9, 1.8],
         [6.3, 2.8, 5.1, 1.5],
         [6.5, 2.8, 4.6, 1.5],
         [5.1, 3.7, 1.5, 0.4],
         [7.4, 2.8, 6.1, 1.9],
         [4.9, 3.1, 1.5, 0.1],
         [4.8, 3.1, 1.6, 0.2],
         [5.5, 2.4, 3.7, 1. ],
         [5.2, 4.1, 1.5, 0.1],
         [5.4, 3. , 4.5, 1.5],
         [6.7, 3.1, 5.6, 2.4],
         [5.1, 3.8, 1.5, 0.3],
         [4.4, 3.2, 1.3, 0.2],
         [7.1, 3. , 5.9, 2.1],
         [5. , 3.3, 1.4, 0.2],
         [5. , 3.5, 1.6, 0.6]]),
  array([1, 1, 2, 1, 2, 2, 0, 2, 2, 1, 1, 0, 2, 1, 0, 1, 2, 2, 1, 0, 2, 0,
         0, 1, 0, 1, 2, 0, 0, 2, 0, 0])),
 (array([[6.4, 3.2, 4.5, 1.5],
         [6.1, 3. , 4.6, 1.4],
         [6.3, 2.3, 4.4, 1.3],
         [6.5, 3. , 5.5, 1.8],
         [7.7, 3.8, 6.7, 2.2],
         [6.3, 3.4, 5.6, 2.4],
         [6.4, 2.9, 4.3, 1.3],
         [5. , 3. , 1.6, 0.2],
         [5.8, 2.7, 4.1, 1. ],
         [5.8, 2.7, 3.9, 1.2],
         [4.8, 3.4, 1.6, 0.2],
         [6.9, 3.1, 4.9, 1.5],
         [5.8, 4. , 1.2, 0.2],
         [6. , 3.4, 4.5, 1.6],
         [5.4, 3.9, 1.3, 0.4],
         [6.1, 2.8, 4. , 1.3],
         [5.4, 3.4, 1.7, 0.2],
         [7.7, 3. , 6.1, 2.3],
         [5.1, 2.5, 3. , 1.1],
         [5.6, 3. , 4.1, 1.3],
         [6.1, 2.8, 4.7, 1.2],
         [6.2, 3.4, 5.4, 2.3],
         [6.4, 2.8, 5.6, 2.1],
         [5.7, 2.6, 3.5, 1. ]]),
  array([1, 1, 1, 2, 2, 2, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 2, 1, 1, 1, 2,
         2, 1]))]

5.定义神经网络中所有可训练的参数

# 搭一个1层的神经网络
w1 = tf.Variable(tf.random.truncated_normal([4,3],stddev=0.1,seed=1))
b1 = tf.Variable(tf.random.truncated_normal([3],stddev=0.1,seed=1))
# w2 = tf.Variable(tf.random.truncated_normal([5,3],stddev=0.1,seed=1))
# b2 = tf.Variable(tf.random.truncated_normal([3],stddev=0.1,seed=1))


鸢尾花数据集分类问题(3)https://developer.aliyun.com/article/1540970

目录
相关文章
|
4月前
|
机器学习/深度学习 自然语言处理 算法
什么是数据集的分类?
【7月更文挑战第10天】什么是数据集的分类?
474 1
|
5月前
鸢尾花数据集分类问题(3)
鸢尾花数据集分类问题
31 2
|
5月前
鸢尾花数据集分类问题(1)
鸢尾花数据集分类问题
35 1
|
5月前
鸢尾花数据集分类问题(4)
鸢尾花数据集分类问题
24 0
|
6月前
|
机器学习/深度学习 数据可视化 数据库
R语言对MNIST数据集分析:探索手写数字分类
R语言对MNIST数据集分析:探索手写数字分类
|
6月前
|
数据可视化 算法 数据挖掘
R语言鸢尾花iris数据集的层次聚类分析
R语言鸢尾花iris数据集的层次聚类分析
|
机器学习/深度学习 Python
【统计学习方法】K近邻对鸢尾花(iris)数据集进行多分类
【统计学习方法】K近邻对鸢尾花(iris)数据集进行多分类
234 0
|
机器学习/深度学习 自然语言处理
(imdb数据集)电影评论分类实战:二分类问题
(imdb数据集)电影评论分类实战:二分类问题
|
数据采集 机器学习/深度学习 Python
【统计学习方法】朴素贝叶斯对鸢尾花(iris)数据集进行训练预测
【统计学习方法】朴素贝叶斯对鸢尾花(iris)数据集进行训练预测
392 0
【统计学习方法】朴素贝叶斯对鸢尾花(iris)数据集进行训练预测
|
机器学习/深度学习 Python
【统计学习方法】感知机对鸢尾花(iris)数据集进行二分类
【统计学习方法】感知机对鸢尾花(iris)数据集进行二分类
723 0
【统计学习方法】感知机对鸢尾花(iris)数据集进行二分类