Linear Regression with Keras
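Keras is a Python-based deep learning framework that runs on top of a TensorFlow, Theano, or CNTK backend. It was created and is maintained by François Chollet, with the goal of making deep learning models fast and simple to implement. Its design emphasizes user-friendliness, extensibility, and ease of debugging and experimentation.
Keras provides high-level APIs and convenience tools that let users build and train deep learning models quickly without dealing with low-level details. It supports many network architectures, including convolutional neural networks, recurrent neural networks, and autoencoders, and models can easily be trained and tested on different datasets.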
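The main features of Keras are:
- Easy to use and quick to learn: the API is simple, and a few lines of code are enough to define a fairly complex deep learning model.
- Multiple backends: Keras can use TensorFlow, Theano, or CNTK as its backend, so users can pick whichever suits their needs.
- Highly extensible: the API is modular, so users can add custom layers and functions and modify existing code as needed.
- Convenient debugging and experimentation: Keras provides tools for visualizing training progress and test results in real time, and supports callbacks such as early stopping and learning-rate adjustment.
- GPU acceleration: Keras can run computations on a GPU to speed up both training and inference.
In short, Keras makes building and training deep learning models simpler and faster, which lets users concentrate on model design and application.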
1. Import the Keras libraries
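# Silence library warnings to keep the output readable
import warnings
warnings.filterwarnings("ignore")

import numpy as np
# Fix the NumPy seed so the shuffle and noise below are reproducible
np.random.seed(1337)

from keras.models import Sequential
from keras.layers import Dense
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt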
Using TensorFlow backend.
2. Create the dataset
# Create the dataset:
# 200 evenly spaced samples in the interval [-1, 1]
X = np.linspace(-1, 1, 200)
X
array([-1. , -0.98994975, -0.9798995 , -0.96984925, -0.95979899, -0.94974874, -0.93969849, -0.92964824, -0.91959799, -0.90954774, -0.89949749, -0.88944724, -0.87939698, -0.86934673, -0.85929648, -0.84924623, -0.83919598, -0.82914573, -0.81909548, -0.80904523, -0.79899497, -0.78894472, -0.77889447, -0.76884422, -0.75879397, -0.74874372, -0.73869347, -0.72864322, -0.71859296, -0.70854271, -0.69849246, -0.68844221, -0.67839196, -0.66834171, -0.65829146, -0.64824121, -0.63819095, -0.6281407 , -0.61809045, -0.6080402 , -0.59798995, -0.5879397 , -0.57788945, -0.5678392 , -0.55778894, -0.54773869, -0.53768844, -0.52763819, -0.51758794, -0.50753769, -0.49748744, -0.48743719, -0.47738693, -0.46733668, -0.45728643, -0.44723618, -0.43718593, -0.42713568, -0.41708543, -0.40703518, -0.39698492, -0.38693467, -0.37688442, -0.36683417, -0.35678392, -0.34673367, -0.33668342, -0.32663317, -0.31658291, -0.30653266, -0.29648241, -0.28643216, -0.27638191, -0.26633166, -0.25628141, -0.24623116, -0.2361809 , -0.22613065, -0.2160804 , -0.20603015, -0.1959799 , -0.18592965, -0.1758794 , -0.16582915, -0.15577889, -0.14572864, -0.13567839, -0.12562814, -0.11557789, -0.10552764, -0.09547739, -0.08542714, -0.07537688, -0.06532663, -0.05527638, -0.04522613, -0.03517588, -0.02512563, -0.01507538, -0.00502513, 0.00502513, 0.01507538, 0.02512563, 0.03517588, 0.04522613, 0.05527638, 0.06532663, 0.07537688, 0.08542714, 0.09547739, 0.10552764, 0.11557789, 0.12562814, 0.13567839, 0.14572864, 0.15577889, 0.16582915, 0.1758794 , 0.18592965, 0.1959799 , 0.20603015, 0.2160804 , 0.22613065, 0.2361809 , 0.24623116, 0.25628141, 0.26633166, 0.27638191, 0.28643216, 0.29648241, 0.30653266, 0.31658291, 0.32663317, 0.33668342, 0.34673367, 0.35678392, 0.36683417, 0.37688442, 0.38693467, 0.39698492, 0.40703518, 0.41708543, 0.42713568, 0.43718593, 0.44723618, 0.45728643, 0.46733668, 0.47738693, 0.48743719, 0.49748744, 0.50753769, 0.51758794, 0.52763819, 0.53768844, 0.54773869, 0.55778894, 0.5678392 , 0.57788945, 
0.5879397 , 0.59798995, 0.6080402 , 0.61809045, 0.6281407 , 0.63819095, 0.64824121, 0.65829146, 0.66834171, 0.67839196, 0.68844221, 0.69849246, 0.70854271, 0.71859296, 0.72864322, 0.73869347, 0.74874372, 0.75879397, 0.76884422, 0.77889447, 0.78894472, 0.79899497, 0.80904523, 0.81909548, 0.82914573, 0.83919598, 0.84924623, 0.85929648, 0.86934673, 0.87939698, 0.88944724, 0.89949749, 0.90954774, 0.91959799, 0.92964824, 0.93969849, 0.94974874, 0.95979899, 0.96984925, 0.9798995 , 0.98994975, 1. ])
# Shuffle the dataset in place
np.random.shuffle(X)
X
array([-0.70854271, 0.1758794 , -0.30653266, 0.74874372, -0.02512563, 0.33668342, -0.85929648, 0.01507538, -0.13567839, 0.72864322, 0.24623116, -0.74874372, -0.78894472, 0.50753769, 0.03517588, 0.35678392, -0.55778894, 0.2361809 , -0.25628141, -0.44723618, 0.2160804 , -0.43718593, -0.64824121, 0.69849246, -0.03517588, -0.45728643, 0.86934673, 0.73869347, 0.53768844, -0.67839196, -0.75879397, 0.55778894, 0.28643216, -0.05527638, -0.86934673, 0.1959799 , -0.57788945, -0.9798995 , -0.6080402 , -0.63819095, 0.84924623, 0.41708543, 0.13567839, 0.79899497, -0.47738693, 0.46733668, 0.59798995, -0.80904523, -0.98994975, -0.36683417, -0.5678392 , -0.00502513, -0.53768844, -0.37688442, -0.65829146, -0.1959799 , 0.06532663, 0.44723618, -0.01507538, -0.6281407 , 0.02512563, -0.71859296, -0.14572864, -0.46733668, 0.07537688, 0.85929648, 0.76884422, 0.40703518, -0.68844221, 0.68844221, -0.29648241, 0.66834171, -0.95979899, -0.33668342, 0.26633166, -0.82914573, 1. , -0.5879397 , -0.69849246, -0.20603015, 0.63819095, -0.88944724, -0.40703518, -0.32663317, 0.15577889, -0.41708543, 0.10552764, 0.20603015, -0.04522613, 0.00502513, -0.31658291, 0.43718593, 0.42713568, 0.45728643, -0.59798995, -0.66834171, 0.83919598, 0.75879397, -0.24623116, 0.71859296, -0.92964824, 0.39698492, 0.61809045, -0.84924623, -0.87939698, -0.96984925, 0.87939698, 0.6281407 , 0.25628141, 0.27638191, 0.12562814, 0.09547739, -0.89949749, 0.80904523, -0.16582915, -0.12562814, 0.30653266, 0.49748744, 0.5879397 , -0.51758794, -0.10552764, 0.54773869, -0.94974874, 0.92964824, 0.16582915, -0.83919598, -0.35678392, -0.48743719, 0.08542714, -0.61809045, 0.18592965, 0.57788945, 0.65829146, 0.38693467, 0.91959799, -0.26633166, -0.50753769, -1. 
, -0.54773869, 0.6080402 , -0.49748744, -0.22613065, 0.9798995 , 0.98994975, 0.5678392 , 0.32663317, 0.64824121, -0.52763819, 0.36683417, 0.81909548, -0.11557789, 0.31658291, -0.2160804 , 0.95979899, 0.77889447, -0.73869347, -0.81909548, -0.79899497, 0.78894472, 0.88944724, -0.2361809 , 0.37688442, 0.70854271, 0.22613065, -0.28643216, -0.38693467, 0.90954774, -0.91959799, 0.48743719, -0.42713568, -0.08542714, 0.11557789, -0.18592965, 0.47738693, -0.39698492, -0.34673367, 0.04522613, 0.05527638, 0.93969849, -0.77889447, -0.93969849, -0.06532663, -0.72864322, 0.29648241, 0.52763819, -0.76884422, 0.94974874, 0.82914573, 0.34673367, -0.90954774, -0.27638191, -0.15577889, -0.1758794 , 0.14572864, -0.09547739, 0.96984925, 0.67839196, -0.07537688, 0.89949749, 0.51758794])
# Assume the true model is Y = 0.5X + 2, plus Gaussian noise (mean 0, std 0.05)
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200,))
Y
array([1.66851812, 2.12220988, 1.91611873, 2.38979647, 1.96473269, 2.11662688, 1.58217043, 2.05326658, 1.95885373, 2.4277956 , 2.13544689, 1.68732448, 1.66384243, 2.2702853 , 2.03148986, 2.14968674, 1.76442495, 2.10802586, 1.93269542, 1.81936289, 2.15190248, 1.83941395, 1.71399197, 2.21820555, 1.97918099, 1.79781646, 2.43645587, 2.31211201, 2.21764353, 1.71912829, 1.64285239, 2.2663785 , 2.11081029, 2.09338152, 1.5614153 , 2.19655545, 1.72824772, 1.56444412, 1.72673075, 1.67311017, 2.39817488, 2.12624087, 2.07791136, 2.40515644, 1.80701389, 2.16050089, 2.30373845, 1.57656517, 1.52482139, 1.7639545 , 1.76787463, 2.01204511, 1.74877623, 1.86751173, 1.67509082, 1.95941218, 2.0126989 , 2.31574759, 2.04672223, 1.73762178, 1.97249596, 1.65257838, 1.98435822, 1.74193776, 2.05272917, 2.41693508, 2.37609913, 2.24686996, 1.61790402, 2.37607665, 1.82677368, 2.29512653, 1.52756173, 1.79404414, 2.08314 , 1.5209276 , 2.48034115, 1.7821867 , 1.60377021, 1.82345627, 2.23840132, 1.50174227, 1.85127905, 1.92372432, 1.95433662, 1.8146093 , 1.96513404, 2.0227501 , 1.97564664, 2.09893966, 1.95392005, 2.2089975 , 2.26074219, 2.24742979, 1.75936195, 1.69145596, 2.46801952, 2.40938521, 1.98369075, 2.37509171, 1.53026033, 2.24305926, 2.33309562, 1.49913881, 1.48743005, 1.54075518, 2.33130062, 2.37463005, 2.19387461, 2.20970603, 2.04719149, 2.04105128, 1.48410805, 2.34714158, 1.95061571, 1.89473245, 2.26596278, 2.22430597, 2.29984983, 1.7894671 , 1.85995514, 2.31688729, 1.53417344, 2.39777465, 2.12853793, 1.47736812, 1.90180229, 1.73086567, 2.03772387, 1.67243511, 2.10115733, 2.26944612, 2.37404859, 2.22042332, 2.4948031 , 1.80153666, 1.72069013, 1.44829544, 1.77678155, 2.24291992, 1.73557503, 1.79249737, 2.52580388, 2.46810975, 2.34211232, 2.22144569, 2.31945172, 1.72814133, 2.17318812, 2.43560932, 1.9662451 , 2.14319385, 1.83150682, 2.48805089, 2.28374904, 1.63645718, 1.57901687, 1.61041853, 2.40884706, 2.37339631, 1.90728817, 2.09065413, 2.36836694, 2.05400262, 1.87764304, 1.83547711, 
2.45064964, 1.46324772, 2.2429919 , 1.75954149, 1.97326923, 2.08379661, 2.04616096, 2.3161197 , 1.81470671, 1.8188581 , 2.11349671, 2.05477704, 2.39622142, 1.61281075, 1.56914576, 1.96947616, 1.56645219, 2.08002605, 2.2185357 , 1.54079134, 2.42384819, 2.41198434, 2.0570266 , 1.55142224, 1.83396657, 1.92648666, 1.9143498 , 1.9372014 , 1.92794208, 2.42698754, 2.29871021, 2.03266023, 2.42413239, 2.28286632])
# Plot the dataset (X, Y)
plt.scatter(X, Y)
plt.show()
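Before training a network on this data, it can help to check that the generating line is recoverable at all. The sketch below regenerates a dataset of the same form and fits a straight line with ordinary least squares via `np.polyfit`. It is a sanity check under stated assumptions, not part of the original walkthrough: the local generator `rng` and its seed are chosen here for illustration, so the numbers will differ from the arrays printed above.

```python
import numpy as np

# Assumption: a local generator with its own seed stands in for the
# global np.random.seed(1337) used in the walkthrough.
rng = np.random.default_rng(0)

# Same recipe as above: 200 shuffled points on [-1, 1], Y = 0.5X + 2 + noise
X = np.linspace(-1, 1, 200)
rng.shuffle(X)
Y = 0.5 * X + 2 + rng.normal(0, 0.05, size=200)

# A degree-1 polynomial fit is ordinary least squares for a line;
# polyfit returns the highest-order coefficient first.
slope, intercept = np.polyfit(X, Y, deg=1)
print("slope:", slope, "intercept:", intercept)
```

The fitted slope and intercept should land close to 0.5 and 2, which gives a reference point for the parameters the network learns later.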
3. Split the dataset
# Split into training and test sets: the first 160 samples and the last 40
X_train, Y_train = X[:160], Y[:160]
X_test, Y_test = X[160:], Y[160:]
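Slicing at index 160 only yields a random split because X was shuffled before Y was computed from it, so each (X, Y) pair stays aligned. The helper below is a minimal sketch of the same idea with the cut point expressed as a fraction. The name `split_train_test` and its `train_fraction` parameter are hypothetical and introduced here for illustration only.

```python
import numpy as np

def split_train_test(x, y, train_fraction=0.8):
    """Hypothetical helper: cut two aligned arrays at the same index."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    cut = int(round(len(x) * train_fraction))
    return (x[:cut], y[:cut]), (x[cut:], y[cut:])

# Toy aligned arrays, only to show the resulting sizes
x = np.arange(200, dtype=float)
y = 2 * x
(x_tr, y_tr), (x_te, y_te) = split_train_test(x, y)
print(len(x_tr), len(x_te))  # 160 40
```

Note that the helper does not shuffle. If the data were still in sorted order, the test set would contain only the largest X values.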
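4. Build the neural network model
# Define a model.
# Keras has two kinds of models: Sequential and functional.
# Sequential is the more commonly used one; it has a single input and a single output.
model = Sequential()
# Add layers one at a time with add().
# Dense is a fully connected layer; the first layer must declare its input size.
# units=1 gives one output, input_dim=1 declares one input feature.
model.add(Dense(units=1, input_dim=1))
# Before training, set the training configuration with compile():
# mean squared error as the loss and stochastic gradient descent as the optimizer.
model.compile(loss='mse', optimizer='sgd')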
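To make the choice of loss and optimizer concrete, the derivation below sketches what a single Dense layer with one input and one output minimizes under `mse` and plain gradient descent. The symbols $N$ (batch size), $\eta$ (learning rate), and $\hat{y}_i$ (prediction for sample $i$) are notation introduced here, not names from the code.

```latex
\begin{aligned}
\hat{y}_i &= w x_i + b \\
L(w, b) &= \frac{1}{N} \sum_{i=1}^{N} (\hat{y}_i - y_i)^2 \\
\frac{\partial L}{\partial w} &= \frac{2}{N} \sum_{i=1}^{N} (\hat{y}_i - y_i)\, x_i \\
\frac{\partial L}{\partial b} &= \frac{2}{N} \sum_{i=1}^{N} (\hat{y}_i - y_i) \\
w &\leftarrow w - \eta \frac{\partial L}{\partial w}, \qquad
b \leftarrow b - \eta \frac{\partial L}{\partial b}
\end{aligned}
```

Each training step evaluates the two gradients on the current batch and moves $w$ and $b$ a small step against them.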
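5. Train the model
# Start training
print('Training ----------')
# Keras offers several ways to train; here we use train_on_batch(),
# which performs one gradient update on the batch it is given
# (the whole training set in this case)
for step in range(301):
    cost = model.train_on_batch(X_train, Y_train)
    if step % 100 == 0:
        print('train cost: ', cost)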
Training ----------
WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.
train cost: 4.0225005
train cost: 0.073238626
train cost: 0.00386274
train cost: 0.002643449
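The training loop above can be mimicked in plain NumPy to see what each `train_on_batch()` call does on the full training set. This is a sketch under assumptions, not Keras' actual implementation: the weights start at zero rather than at Keras' random initialization, the learning rate of 0.01 is assumed to match the default of the `sgd` optimizer, and the data comes from a local generator, so the printed costs will not match the log above exactly.

```python
import numpy as np

# Assumption: local seed, not the article's np.random.seed(1337)
rng = np.random.default_rng(0)
X = np.linspace(-1, 1, 200)
rng.shuffle(X)
Y = 0.5 * X + 2 + rng.normal(0, 0.05, size=200)
X_train, Y_train = X[:160], Y[:160]

w, b = 0.0, 0.0        # assumption: zero init instead of Keras' random init
learning_rate = 0.01   # assumption: intended to match the Keras SGD default
losses = []

for step in range(301):
    # Forward pass and mean squared error on the whole training set
    error = (w * X_train + b) - Y_train
    loss = np.mean(error ** 2)
    losses.append(loss)
    # Gradients of the MSE with respect to w and b
    grad_w = 2.0 * np.mean(error * X_train)
    grad_b = 2.0 * np.mean(error)
    # One gradient descent update
    w -= learning_rate * grad_w
    b -= learning_rate * grad_b
    if step % 100 == 0:
        print("step", step, "train cost:", loss)

print("w =", w, "b =", b)
```

In this sketch the bias converges faster than the weight, because the inputs lie in [-1, 1] and the gradient with respect to `w` is scaled by them.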
6. Test the model
# Evaluate the trained model on the test set
print('Testing ----------')
cost = model.evaluate(X_test, Y_test, batch_size=40)
print('test cost: ', cost)
Testing ----------
40/40 [==============================] - 0s 508us/step
test cost: 0.0031367032788693905
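Because the model was compiled with `loss='mse'` and no extra metrics, the test cost reported above is a mean squared error. The sketch below computes that quantity by hand on toy values. The function name `mean_squared_error` and the sample numbers are made up for illustration. One detail worth noting is the shape handling: predictions typically arrive as a column of shape (n, 1) while targets have shape (n,), and subtracting them directly would broadcast to an (n, n) matrix.

```python
import numpy as np

def mean_squared_error(y_true, y_pred):
    """Illustrative helper: MSE after flattening both inputs to 1-D."""
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    return float(np.mean((y_true - y_pred) ** 2))

# Toy values: targets are 1-D, predictions are a column vector
y_true = np.array([2.0, 2.5, 1.5])
y_pred = np.array([[2.1], [2.4], [1.5]])

print(mean_squared_error(y_true, y_pred))
# Without flattening, broadcasting silently produces a 3x3 matrix
print((y_true - y_pred).shape)  # (3, 3)
```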
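7. Analyze the model
# Inspect the learned parameters.
# The network has a single layer with one input and one output,
# so that layer learns the model Y = WX + b, where W and b are the trained parameters.
W, b = model.layers[0].get_weights()
print('Weights = ', W, '\nbiases = ', b)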
Weights = [[0.4922711]]
biases = [1.9995022]
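The learned values are close to the generating slope 0.5 and intercept 2. As a quick check, the sketch below applies the printed parameters by hand, which is all a single Dense layer without an activation does. The input points in `x_new` are chosen here for illustration.

```python
import numpy as np

# Parameters copied from the output printed above
W = np.array([[0.4922711]])   # shape (input_dim, units) = (1, 1)
b = np.array([1.9995022])     # shape (units,) = (1,)

# Illustrative inputs, reshaped to a column as the layer expects
x_new = np.array([-1.0, 0.0, 1.0])
y_hat = x_new.reshape(-1, 1) @ W + b

print(y_hat.ravel())
# Values from the noise-free line Y = 0.5X + 2, for comparison
print(0.5 * x_new + 2)
```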
# Plot the predictions against the test data
Y_pred = model.predict(X_test)
plt.scatter(X_test, Y_test)
plt.plot(X_test, Y_pred)
plt.show()
# Evaluate goodness of fit with the R² score
pred_acc = r2_score(Y_test, Y_pred)
print('pred_acc', pred_acc)
pred_acc 0.9591211310535933
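Although the variable is named `pred_acc`, the R² score is not a classification accuracy. It is the fraction of the variance in the targets that the predictions explain. The sketch below computes it from its definition, 1 - SS_res / SS_tot. The helper name `r2` and the toy numbers are made up for illustration.

```python
import numpy as np

def r2(y_true, y_pred):
    """Illustrative helper: coefficient of determination from its definition."""
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    ss_res = np.sum((y_true - y_pred) ** 2)          # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)   # total sum of squares
    return 1.0 - ss_res / ss_tot

y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.1, 1.9, 3.2, 3.8])
print(r2(y_true, y_pred))  # approximately 0.98
```

A perfect prediction gives 1, and always predicting the mean of the targets gives 0.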
# Save the model
model.save('keras_linear.h5')