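Building a Neural Network
A neural network is a machine learning algorithm based on bionic principles, inspired by the nervous system of the human brain. It consists of many neurons (also called nodes) joined by weighted connections into a network that learns and extracts features from input data for tasks such as classification, regression, and clustering.
Note: this code trains a neural network to fit y = x^2 - 0.5 + noise. The network has an input layer with one neuron, a hidden layer with ten neurons, and an output layer with one neuron.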
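1. Import the libraries
# Import the required libraries
import tensorflow as tf  # builds and trains the neural network (TensorFlow 1.x API)
import numpy as np  # builds arrays and processes the data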
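This code uses two Python modules:
tensorflow: Google's open-source machine learning framework, used here to build the network and train the model. The code in this article uses the TensorFlow 1.x graph API (tf.placeholder, tf.Session).
numpy: the basic Python library for matrix and array computation, used here to build and process the data.
Specifically:
import tensorflow as tf imports TensorFlow under the alias tf.
import numpy as np imports NumPy under the alias np.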
2. Define a layer
# Define a layer
def add_layer(inputs, in_size, out_size, activation_function=None):
    # inputs: input tensor; in_size: number of neurons in the previous layer;
    # out_size: number of neurons in this layer; activation_function: activation to apply
    Weights = tf.Variable(tf.random_normal([in_size, out_size]))  # random initial weights, shape [in_size, out_size]
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)  # biases, initialised to 0.1
    Wx_plus_b = tf.matmul(inputs, Weights) + biases  # matmul is matrix multiplication
    if activation_function is None:
        outputs = Wx_plus_b  # no activation: pass the linear result through
    else:
        outputs = activation_function(Wx_plus_b)  # apply the activation
    return outputs  # the layer's output
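This code defines a function, add_layer, that adds one layer to the network.
Its parameters are:
inputs: the input to the layer.
in_size: the input dimension of the layer, that is, the number of neurons in the previous layer.
out_size: the output dimension of the layer, that is, the number of neurons in this layer.
activation_function: the activation function used by the layer; it may be None.
Inside the function:
Weights = tf.Variable(tf.random_normal([in_size, out_size])): defines the layer's weights, drawn from a normal distribution, with shape [in_size, out_size].
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1): defines the layer's biases as an all-zero matrix of shape [1, out_size] plus 0.1, so that the biases do not start at exactly 0.
Wx_plus_b = tf.matmul(inputs, Weights) + biases: computes the weighted sum by matrix multiplication and adds the biases.
if activation_function is None: when no activation function is given, the weighted sum plus biases is the layer's output.
else: otherwise the activation function is applied to the weighted sum plus biases, and the result is the layer's output.
return outputs: returns the layer's output.
In short, the function creates one network layer: it computes the weighted sum of the inputs, adds the biases, and applies the activation function if one is given.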
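The same forward computation can be sketched in plain NumPy, which makes the shapes easy to inspect. This is only an illustration: dense_forward and relu are hypothetical helper names, the three-row input is made up, and the seed is an assumption for repeatability. It is not the TensorFlow code itself.

```python
import numpy as np

def dense_forward(inputs, weights, biases, activation=None):
    """NumPy analogue of the forward pass in add_layer (illustrative only)."""
    z = inputs @ weights + biases  # biases of shape (1, out_size) broadcast over the rows
    return z if activation is None else activation(z)

def relu(z):
    """Element-wise max(z, 0)."""
    return np.maximum(z, 0.0)

rng = np.random.default_rng(0)               # fixed seed: an assumption, for repeatability
inputs = np.array([[-1.0], [0.0], [1.0]])    # 3 samples, 1 feature each
weights = rng.normal(size=(1, 10))           # in_size=1, out_size=10
biases = np.zeros((1, 10)) + 0.1             # same initial value as in add_layer

hidden = dense_forward(inputs, weights, biases, activation=relu)
linear = dense_forward(inputs, weights, biases)  # no activation
print(hidden.shape, linear.shape)
```

For the sample whose input is 0, the linear output equals the bias (0.1) in every column, which is a quick way to check the computation.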
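3. Build the dataset
# Build some samples for training the neural network
x_data = np.linspace(-1, 1, 300)[:, np.newaxis]  # 300 evenly spaced values in [-1, 1], as a column
noise = np.random.normal(0, 0.05, x_data.shape)  # Gaussian noise: mean 0, standard deviation 0.05
x_data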
array([[-1. ], [-0.99331104], [-0.98662207], [-0.97993311], [-0.97324415], [-0.96655518], [-0.95986622], [-0.95317726], [-0.94648829], [-0.93979933], [-0.93311037], [-0.9264214 ], [-0.91973244], [-0.91304348], [-0.90635452], [-0.89966555], [-0.89297659], [-0.88628763], [-0.87959866], [-0.8729097 ], [-0.86622074], [-0.85953177], [-0.85284281], [-0.84615385], [-0.83946488], [-0.83277592], [-0.82608696], [-0.81939799], [-0.81270903], [-0.80602007], [-0.7993311 ], [-0.79264214], [-0.78595318], [-0.77926421], [-0.77257525], [-0.76588629], [-0.75919732], [-0.75250836], [-0.7458194 ], [-0.73913043], [-0.73244147], [-0.72575251], [-0.71906355], [-0.71237458], [-0.70568562], [-0.69899666], [-0.69230769], [-0.68561873], [-0.67892977], [-0.6722408 ], [-0.66555184], [-0.65886288], [-0.65217391], [-0.64548495], [-0.63879599], [-0.63210702], [-0.62541806], [-0.6187291 ], [-0.61204013], [-0.60535117], [-0.59866221], [-0.59197324], [-0.58528428], [-0.57859532], [-0.57190635], [-0.56521739], [-0.55852843], [-0.55183946], [-0.5451505 ], [-0.53846154], [-0.53177258], [-0.52508361], [-0.51839465], [-0.51170569], [-0.50501672], [-0.49832776], [-0.4916388 ], [-0.48494983], [-0.47826087], [-0.47157191], [-0.46488294], [-0.45819398], [-0.45150502], [-0.44481605], [-0.43812709], [-0.43143813], [-0.42474916], [-0.4180602 ], [-0.41137124], [-0.40468227], [-0.39799331], [-0.39130435], [-0.38461538], [-0.37792642], [-0.37123746], [-0.36454849], [-0.35785953], [-0.35117057], [-0.34448161], [-0.33779264], [-0.33110368], [-0.32441472], [-0.31772575], [-0.31103679], [-0.30434783], [-0.29765886], [-0.2909699 ], [-0.28428094], [-0.27759197], [-0.27090301], [-0.26421405], [-0.25752508], [-0.25083612], [-0.24414716], [-0.23745819], [-0.23076923], [-0.22408027], [-0.2173913 ], [-0.21070234], [-0.20401338], [-0.19732441], [-0.19063545], [-0.18394649], [-0.17725753], [-0.17056856], [-0.1638796 ], [-0.15719064], [-0.15050167], [-0.14381271], [-0.13712375], [-0.13043478], [-0.12374582], [-0.11705686], 
[-0.11036789], [-0.10367893], [-0.09698997], [-0.090301 ], [-0.08361204], [-0.07692308], [-0.07023411], [-0.06354515], [-0.05685619], [-0.05016722], [-0.04347826], [-0.0367893 ], [-0.03010033], [-0.02341137], [-0.01672241], [-0.01003344], [-0.00334448], [ 0.00334448], [ 0.01003344], [ 0.01672241], [ 0.02341137], [ 0.03010033], [ 0.0367893 ], [ 0.04347826], [ 0.05016722], [ 0.05685619], [ 0.06354515], [ 0.07023411], [ 0.07692308], [ 0.08361204], [ 0.090301 ], [ 0.09698997], [ 0.10367893], [ 0.11036789], [ 0.11705686], [ 0.12374582], [ 0.13043478], [ 0.13712375], [ 0.14381271], [ 0.15050167], [ 0.15719064], [ 0.1638796 ], [ 0.17056856], [ 0.17725753], [ 0.18394649], [ 0.19063545], [ 0.19732441], [ 0.20401338], [ 0.21070234], [ 0.2173913 ], [ 0.22408027], [ 0.23076923], [ 0.23745819], [ 0.24414716], [ 0.25083612], [ 0.25752508], [ 0.26421405], [ 0.27090301], [ 0.27759197], [ 0.28428094], [ 0.2909699 ], [ 0.29765886], [ 0.30434783], [ 0.31103679], [ 0.31772575], [ 0.32441472], [ 0.33110368], [ 0.33779264], [ 0.34448161], [ 0.35117057], [ 0.35785953], [ 0.36454849], [ 0.37123746], [ 0.37792642], [ 0.38461538], [ 0.39130435], [ 0.39799331], [ 0.40468227], [ 0.41137124], [ 0.4180602 ], [ 0.42474916], [ 0.43143813], [ 0.43812709], [ 0.44481605], [ 0.45150502], [ 0.45819398], [ 0.46488294], [ 0.47157191], [ 0.47826087], [ 0.48494983], [ 0.4916388 ], [ 0.49832776], [ 0.50501672], [ 0.51170569], [ 0.51839465], [ 0.52508361], [ 0.53177258], [ 0.53846154], [ 0.5451505 ], [ 0.55183946], [ 0.55852843], [ 0.56521739], [ 0.57190635], [ 0.57859532], [ 0.58528428], [ 0.59197324], [ 0.59866221], [ 0.60535117], [ 0.61204013], [ 0.6187291 ], [ 0.62541806], [ 0.63210702], [ 0.63879599], [ 0.64548495], [ 0.65217391], [ 0.65886288], [ 0.66555184], [ 0.6722408 ], [ 0.67892977], [ 0.68561873], [ 0.69230769], [ 0.69899666], [ 0.70568562], [ 0.71237458], [ 0.71906355], [ 0.72575251], [ 0.73244147], [ 0.73913043], [ 0.7458194 ], [ 0.75250836], [ 0.75919732], [ 0.76588629], [ 0.77257525], [ 
0.77926421], [ 0.78595318], [ 0.79264214], [ 0.7993311 ], [ 0.80602007], [ 0.81270903], [ 0.81939799], [ 0.82608696], [ 0.83277592], [ 0.83946488], [ 0.84615385], [ 0.85284281], [ 0.85953177], [ 0.86622074], [ 0.8729097 ], [ 0.87959866], [ 0.88628763], [ 0.89297659], [ 0.89966555], [ 0.90635452], [ 0.91304348], [ 0.91973244], [ 0.9264214 ], [ 0.93311037], [ 0.93979933], [ 0.94648829], [ 0.95317726], [ 0.95986622], [ 0.96655518], [ 0.97324415], [ 0.97993311], [ 0.98662207], [ 0.99331104], [ 1. ]])
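This code uses NumPy to create x_data, a two-dimensional array with a single column.
The arguments are:
-1: the smallest value in the array.
1: the largest value in the array.
300: the number of values.
[:, np.newaxis]: adds a new axis, turning the one-dimensional array into a two-dimensional one (this is not a transpose).
Step by step:
np.linspace(-1, 1, 300): returns an arithmetic sequence of 300 evenly spaced values from -1 to 1, both endpoints included, as an ndarray.
[:, np.newaxis]: turns the one-dimensional array into a column vector, changing its shape from (300,) to (300, 1).
The resulting x_data is a two-dimensional array of shape (300, 1): 300 samples with one feature each.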
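A smaller array makes the effect of np.newaxis easy to see. The sketch below uses 5 points instead of 300 (an arbitrary choice for readability) and also shows that transposing a one-dimensional array does not produce a column.

```python
import numpy as np

row = np.linspace(-1, 1, 5)      # 1-D array with 5 evenly spaced values
col = row[:, np.newaxis]         # same values, one per row

print(row.shape)                 # (5,)
print(col.shape)                 # (5, 1)
print(row.T.shape)               # (5,)  transposing a 1-D array changes nothing
```

Each row of col is one sample, which is the layout the placeholders of shape [None, 1] expect later on.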
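# Adding noise makes the data closer to real measurements; the noise is Gaussian
# (mean 0, standard deviation 0.05) and has the same shape as x_data
y_data = np.square(x_data) - 0.5 + noise  # target values
y_data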
array([[ 0.59535036], [ 0.46017998], [ 0.47144478], [ 0.45083795], [ 0.58438217], [ 0.38570118], [ 0.43550029], [ 0.40597571], [ 0.3357524 ], [ 0.35784864], [ 0.34530231], [ 0.32509701], [ 0.25554733], [ 0.32300801], [ 0.2299959 ], [ 0.35472568], [ 0.31227671], [ 0.30385068], [ 0.29413844], [ 0.18437787], [ 0.28132819], [ 0.25605309], [ 0.23126361], [ 0.23492797], [ 0.18381621], [ 0.10392937], [ 0.13415913], [ 0.14043649], [ 0.11756826], [ 0.12142749], [ 0.12400694], [ 0.08926307], [ 0.15581832], [ 0.16541106], [-0.02582895], [ 0.05924725], [-0.04037454], [ 0.03799003], [ 0.09030832], [ 0.05984324], [-0.06569464], [ 0.07973773], [ 0.04297837], [ 0.05169557], [-0.00096191], [-0.02049573], [-0.03125322], [-0.04545588], [-0.02168901], [ 0.01657517], [-0.04315181], [-0.09123519], [-0.03292835], [-0.1110189 ], [-0.08212792], [-0.10089535], [-0.17406672], [-0.10380731], [-0.10774072], [-0.21283138], [-0.09788435], [-0.10196452], [-0.16439081], [-0.15431978], [-0.17778307], [-0.18428537], [-0.17874028], [-0.10490738], [-0.25076832], [-0.16078044], [-0.21572183], [-0.15624353], [-0.19591988], [-0.31560742], [-0.29593726], [-0.26686787], [-0.2999804 ], [-0.30631065], [-0.35305224], [-0.31295125], [-0.22996255], [-0.22837061], [-0.27266253], [-0.31290802], [-0.37188479], [-0.20765034], [-0.33860431], [-0.31135236], [-0.25249981], [-0.26041048], [-0.31486205], [-0.30253306], [-0.41624795], [-0.40053837], [-0.29939676], [-0.32615377], [-0.37377787], [-0.32222027], [-0.3158838 ], [-0.43880087], [-0.37510637], [-0.46702321], [-0.27058091], [-0.52885151], [-0.4061462 ], [-0.4486374 ], [-0.37819628], [-0.34701947], [-0.32454364], [-0.3901839 ], [-0.43293107], [-0.47881173], [-0.45280819], [-0.49676541], [-0.48955669], [-0.45898691], [-0.37473462], [-0.43801531], [-0.44793655], [-0.57343047], [-0.45262969], [-0.40719677], [-0.45423461], [-0.45053051], [-0.51046881], [-0.41584096], [-0.53328545], [-0.44766406], [-0.50158463], [-0.42676031], [-0.50552613], [-0.36832989], 
[-0.48699296], [-0.41614151], [-0.6175621 ], [-0.48304532], [-0.46115021], [-0.40948908], [-0.42017024], [-0.50411757], [-0.44530626], [-0.46895275], [-0.52127771], [-0.50064585], [-0.42210169], [-0.58582837], [-0.52049198], [-0.45332091], [-0.53465815], [-0.5385712 ], [-0.5654201 ], [-0.54471377], [-0.48109194], [-0.44565732], [-0.48112022], [-0.46471786], [-0.5452149 ], [-0.52115601], [-0.50234928], [-0.54885558], [-0.5279981 ], [-0.53893795], [-0.44286416], [-0.45371406], [-0.44633111], [-0.57535678], [-0.62918947], [-0.41877124], [-0.56263956], [-0.51201705], [-0.35016007], [-0.49188897], [-0.55766056], [-0.38963378], [-0.5038024 ], [-0.51949984], [-0.45229896], [-0.49193029], [-0.53472883], [-0.48957523], [-0.35561181], [-0.4622668 ], [-0.39177781], [-0.43448445], [-0.49854629], [-0.49843105], [-0.47704375], [-0.36618194], [-0.45177012], [-0.41497222], [-0.42152064], [-0.48996608], [-0.43010878], [-0.42599962], [-0.2841197 ], [-0.38992082], [-0.43802592], [-0.42448799], [-0.29514676], [-0.37154091], [-0.25426219], [-0.44610678], [-0.37120566], [-0.3531599 ], [-0.34606119], [-0.29637877], [-0.3693284 ], [-0.36651142], [-0.30025118], [-0.31443603], [-0.40824064], [-0.31734053], [-0.40807378], [-0.33792031], [-0.22414921], [-0.37707072], [-0.26776417], [-0.29152204], [-0.34066934], [-0.19037511], [-0.23552614], [-0.2144995 ], [-0.27628531], [-0.27329725], [-0.23910513], [-0.30009859], [-0.30192088], [-0.16403744], [-0.32546893], [-0.25686912], [-0.12515146], [-0.21483097], [-0.12779443], [-0.28748063], [-0.23782354], [-0.16024807], [-0.19062672], [-0.15066097], [-0.19043274], [-0.16583211], [-0.11201314], [-0.05612149], [-0.00847256], [-0.1429705 ], [-0.09595988], [-0.09583441], [-0.01372838], [-0.04818834], [-0.11840653], [ 0.02184166], [-0.07153294], [-0.11556547], [-0.04731049], [-0.10774914], [-0.014642 ], [-0.01470962], [-0.03259555], [-0.04194347], [ 0.08987345], [-0.02027899], [ 0.02418433], [ 0.04298611], [ 0.04130101], [ 0.18010436], [ 0.15480307], [ 
0.02719993], [ 0.11508363], [ 0.04309794], [ 0.14060578], [ 0.09377926], [ 0.13887198], [ 0.16148276], [ 0.11398259], [ 0.27887578], [ 0.22775177], [ 0.20749998], [ 0.22107721], [ 0.20854961], [ 0.25411644], [ 0.26561906], [ 0.27540788], [ 0.26946028], [ 0.2390275 ], [ 0.26051795], [ 0.34424064], [ 0.3240088 ], [ 0.38040554], [ 0.35717078], [ 0.31357911], [ 0.43825368], [ 0.35709739], [ 0.48101049], [ 0.36024364], [ 0.43253108], [ 0.39268334], [ 0.41942572], [ 0.41196584], [ 0.54435941], [ 0.49840622], [ 0.51627957]])
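This code generates y_data from x_data and adds some noise.
np.square(x_data) squares every element of x_data, the constant 0.5 is subtracted, and the noise is added last.
Because every operation is element-wise and noise has the same shape as x_data, y_data is also a two-dimensional array of shape (300, 1): one target value for each of the 300 samples.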
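The two arguments of np.random.normal are a mean and a standard deviation, not a range, so individual noise values can fall outside ±0.05. A minimal check, assuming a fixed seed (42, chosen arbitrarily) and NumPy's Generator API instead of the global np.random functions used in the article:

```python
import numpy as np

rng = np.random.default_rng(42)                 # seed is an assumption, for repeatability
x = np.linspace(-1, 1, 300)[:, np.newaxis]
noise = rng.normal(0, 0.05, x.shape)            # mean 0, standard deviation 0.05
y = np.square(x) - 0.5 + noise

print(y.shape)                                  # same shape as x
print(noise.mean(), noise.std())                # close to 0 and 0.05, not exact
print((np.abs(noise) > 0.05).mean())            # fraction of samples beyond one standard deviation
```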
4. Define the model
# Placeholders feed data into the network; the 1 means each sample has a single feature
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)
# Cost function: reduce_sum adds the squared errors along axis 1 (reduction_indices selects the axis),
# reduce_mean then averages over the samples
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), reduction_indices=[1]))
# Minimise the cost with gradient descent, learning rate 0.1; each run updates the weights and biases
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
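This code defines the neural network model.
First, tf.placeholder creates two placeholders, xs and ys, which receive the training inputs and the true labels. None means the number of samples is not fixed; the 1 means each sample has exactly one feature.
Next, add_layer is called to add a hidden layer that takes xs as input, has 10 neurons, and uses ReLU as its activation function. A second call to add_layer adds the output layer, which takes the hidden layer's output as input, has 1 neuron, and uses no activation function (None).
Then a cost function, loss, measures the gap between the predictions and the true labels; mean squared error is used here. tf.square computes the squared difference between each prediction and its label, tf.reduce_sum adds those values along axis 1 for each sample (with a single output this leaves each value unchanged), and tf.reduce_mean averages the result over all samples.
Finally, tf.train.GradientDescentOptimizer creates an optimizer with a learning rate of 0.1, and its minimize method builds the operation that reduces loss. Each run of that operation updates the network's weights and biases so that the predictions move closer to the true labels.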
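The size of this network follows directly from the two add_layer calls: each layer has in_size × out_size weights plus out_size biases. A small worked count, where dense_param_count is a hypothetical helper introduced only for this illustration:

```python
def dense_param_count(in_size, out_size):
    """Trainable values in one fully connected layer: a weight matrix plus a bias row."""
    return in_size * out_size + out_size

hidden_params = dense_param_count(1, 10)   # 10 weights + 10 biases
output_params = dense_param_count(10, 1)   # 10 weights + 1 bias
total_params = hidden_params + output_params

print(hidden_params, output_params, total_params)  # 20 11 31
```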
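The cost computation can be reproduced on a few hand-picked numbers. The sketch below is a NumPy analogue with made-up targets and predictions; mse_loss is a hypothetical name, not a TensorFlow function.

```python
import numpy as np

def mse_loss(targets, predictions):
    """Mean over samples of the per-sample sum of squared errors."""
    per_sample = np.sum(np.square(targets - predictions), axis=1)  # plays the role of reduce_sum over axis 1
    return float(np.mean(per_sample))                              # plays the role of reduce_mean

targets = np.array([[1.0], [2.0], [3.0]])
predictions = np.array([[1.5], [2.0], [2.0]])

example_loss = mse_loss(targets, predictions)
print(example_loss)  # (0.25 + 0.0 + 1.0) / 3
```

With one output per sample the sum over axis 1 has a single term, so the result is the ordinary mean squared error.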
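The update rule behind gradient descent is w ← w − learning_rate × gradient. As a minimal sketch, here it is applied to a toy one-parameter cost f(w) = (w − 3)^2, which is an assumption made for illustration and is unrelated to the network's actual loss surface:

```python
def gradient_descent(grad_fn, start, learning_rate=0.1, steps=100):
    """Repeatedly step against the gradient and return the final parameter."""
    w = start
    for _ in range(steps):
        w = w - learning_rate * grad_fn(w)
    return w

def toy_grad(w):
    """Gradient of the toy cost f(w) = (w - 3)^2."""
    return 2.0 * (w - 3.0)

first_step = gradient_descent(toy_grad, start=0.0, steps=1)
w_final = gradient_descent(toy_grad, start=0.0, steps=100)
print(first_step, w_final)  # moves from 0 toward the minimum at w = 3
```

With a learning rate of 0.1 the distance to the minimum shrinks by a factor of 0.8 per step on this toy cost; the optimizer in the article applies the same rule to every weight and bias at once.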
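5. Initialize the variables
# important step
# tf.initialize_all_variables() is no longer valid from
# 2017-03-02 if using tensorflow >= 0.12
# Variable initialisation
# Compare major and minor together; the minor number alone would misclassify e.g. 1.4
major, minor = (int(part) for part in tf.__version__.split('.')[:2])
if (major, minor) < (0, 12):
    init = tf.initialize_all_variables()
else:
    init = tf.global_variables_initializer()
sess = tf.Session()  # create a session to run the graph
sess.run(init)  # run the variable initialisation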
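This code initializes TensorFlow's variables. From version 0.12 onward the older tf.initialize_all_variables() is deprecated in favour of tf.global_variables_initializer(), so the code picks the call that matches the installed version. It then creates a tf.Session object, sess, which executes the operations defined in the TensorFlow graph.
Finally, running init initializes the variables. Every variable defined earlier receives the initial value specified when it was created: the weights get random values drawn from a normal distribution, and the biases get 0.1. Training starts from these values.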
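A version check like the one described above is most robust when the major and minor numbers are compared together as a tuple. A minimal sketch in plain Python; major_minor and uses_legacy_initializer are hypothetical helpers and the version strings are only examples:

```python
def major_minor(version):
    """Return (major, minor) as integers from a string such as '1.15.0'."""
    parts = version.split(".")[:2]
    return int(parts[0]), int(parts[1])

def uses_legacy_initializer(version):
    """True when the version predates 0.12, where the older initializer applies."""
    return major_minor(version) < (0, 12)

checks = {v: uses_legacy_initializer(v) for v in ("0.11.0", "0.12.1", "1.4.0", "1.15.0")}
print(checks)
```

Comparing only the minor number would treat 1.4.0 as older than 0.12, because 4 < 12; the tuple comparison avoids that.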
6. Start training
for i in range(1000):  # 1000 gradient-descent iterations
    # training
    sess.run(train_step, feed_dict={xs: x_data, ys: y_data})  # one gradient-descent step on the full sample set
    if i % 50 == 0:  # print the cost every 50 iterations
        print(sess.run(loss, feed_dict={xs: x_data, ys: y_data}))
0.18214862
0.010138167
0.0071248626
0.0069830194
0.0068635535
0.0067452225
0.006626569
0.0065121166
0.0064035906
0.006295418
0.0061897114
0.0060903295
0.005990808
0.0058959606
0.0058057955
0.0057200184
0.005637601
0.0055605737
0.0054863705
0.005413457
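This code trains the model by gradient descent. The loop runs 1000 times and prints the value of the cost function loss every 50 iterations. In each iteration, sess.run(train_step, feed_dict={xs: x_data, ys: y_data}) performs one gradient-descent step, feeding the samples to the placeholders xs and ys.
The printed values show how the cost changes during training. A cost that keeps falling as the iterations proceed, as it does in the output above, indicates that the model is fitting the training data better and better.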
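For readers without a TensorFlow 1.x environment, the same experiment can be approximated in plain NumPy with hand-written backpropagation. This is a sketch under stated assumptions, not the article's implementation: the train function, the seed, and the Generator-based random numbers are choices made here, the gradients are derived by hand for the 1-10-1 ReLU network, and the cost is recorded before each update rather than after it. The printed values will differ from the output above.

```python
import numpy as np

def train(steps=1000, learning_rate=0.1, seed=0):
    """Full-batch gradient descent on a 1-10-1 ReLU network; returns the cost every 50 steps."""
    rng = np.random.default_rng(seed)
    x = np.linspace(-1, 1, 300)[:, np.newaxis]
    y = np.square(x) - 0.5 + rng.normal(0, 0.05, x.shape)

    # Same initial values as add_layer: normal weights, biases of 0.1
    w1, b1 = rng.normal(size=(1, 10)), np.full((1, 10), 0.1)
    w2, b2 = rng.normal(size=(10, 1)), np.full((1, 1), 0.1)

    n = x.shape[0]
    losses = []
    for i in range(steps):
        # Forward pass
        z1 = x @ w1 + b1
        a1 = np.maximum(z1, 0.0)          # ReLU
        pred = a1 @ w2 + b2
        if i % 50 == 0:
            losses.append(float(np.mean(np.sum(np.square(y - pred), axis=1))))

        # Backward pass for the mean squared error
        d_pred = 2.0 * (pred - y) / n
        d_w2 = a1.T @ d_pred
        d_b2 = d_pred.sum(axis=0, keepdims=True)
        d_z1 = (d_pred @ w2.T) * (z1 > 0)  # ReLU passes gradient only where z1 > 0
        d_w1 = x.T @ d_z1
        d_b1 = d_z1.sum(axis=0, keepdims=True)

        # Gradient-descent update
        w1 -= learning_rate * d_w1
        b1 -= learning_rate * d_b1
        w2 -= learning_rate * d_w2
        b2 -= learning_rate * d_b2
    return losses

losses = train()
for value in losses:
    print(value)
```

As in the article, the thing to look for is a cost that falls from the first printed value to the last.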