【深度学习】实验10 使用Keras完成逻辑回归

简介: 【深度学习】实验10 使用Keras完成逻辑回归

使用Keras完成逻辑回归

Keras是一个开源的深度学习框架,能够高效地实现神经网络和深度学习模型。它由纽约大学的Francois Chollet开发,旨在提供一个简单易用的高层次API,以便开发人员能够快速搭建模型,从而节省时间和精力。Keras能够兼容各种底层深度学习框架,如TensorFlow、Theano和CNTK等。它已经成为深度学习领域中最受欢迎的框架之一,因为它既容易上手又具有灵活性。


Keras的设计初衷是让深度学习变得更容易,更快速地实现从数据到模型的过程。在使用Keras进行深度学习时,您无需编写多行代码来定义神经层、激活函数、优化器和损失函数等超参数,只需一行代码即可。此外,Keras还提供了丰富的预训练模型,可用于处理图像分类、自然语言处理、文本分类和序列分析等任务,从而大大减少了深度学习模型的开发和训练时间。


Keras还具有以下特点:

  1. 简单易用:Keras使用Python编写,提供了简单的API接口,让用户更加关注模型的设计和调整。
  2. 易于扩展:Keras可以兼容多种深度学习框架,如TensorFlow、Theano和CNTK等,能够利用它们的计算能力进行高效的训练和推理。
  3. 快速实现:Keras提供了多种预训练模型,无需从头开始开发模型,快速构建出高质量的深度学习模型。
  4. 支持多种语言:Keras不仅支持Python编程语言,还支持R和Java等其他编程语言。
  5. 开源社区:Keras在GitHub上有庞大的开源社区,拥有丰富的教程和示例,以便开发人员更好地学习和使用。

总之,Keras是一个简单易用、高效实现深度学习模型的框架,能够大大提升深度学习模型的开发和实现效率。

1. 导入Keras库

# 导入相关库
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from sklearn import datasets
import numpy as np
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
Using TensorFlow backend.

2. 生成数据集

# 生成样本数据集,两个特征列,两个分类二分类不需要onehot编码,直接将类别转换为0和1,分别代表正样本的概率。
X, y = datasets.make_classification(n_samples = 200, n_features = 2, n_informative = 2, n_redundant = 0,
                                   n_repeated = 0, n_classes = 2, n_clusters_per_class = 1)
X, y
   (array([[ 0.26364611,  0.77250816],
           [ 0.91698377,  0.9802208 ],
           [ 0.82634329,  0.9821341 ],
           [-0.83833456,  0.88223515],
           [ 1.11509338,  0.98632275],
           [ 1.04196821,  0.97892474],
           [ 0.77695264,  1.06320914],
           [-2.16804253,  0.15267335],
           [-1.96973867,  0.99244728],
           [-1.35368845,  1.25840447],
           [-0.52455148,  2.2351536 ],
           [ 1.08554563,  1.03795405],
           [ 0.88261697,  0.97793289],
           [-1.03718795,  0.53830131],
           [ 0.94628633,  0.96289949],
           [ 1.16190683,  1.01806263],
           [-2.07795249,  0.32376505],
           [ 0.9370119 ,  1.01060097],
           [ 0.92750449,  0.98713143],
           [-0.35800128,  1.4498587 ],
           [-0.96709704,  1.77632874],
           [-0.55995817,  1.58782776],
           [ 0.88919948,  1.00133032],
           [ 1.16465115,  1.05117935],
           [-1.6969619 ,  1.80088135],
           [ 1.06292602,  1.04594288],
           [-0.07792111,  0.98391779],
           [-1.05188451,  1.26871626],
           [-0.83494005,  0.93958161],
           [ 1.10371115,  1.03558148],
           [ 0.98674372,  1.04567265],
           [-1.08345028,  1.18601788],
           [-2.06487683,  0.17118219],
           [ 1.02734931,  0.99326938],
           [-0.11345441,  1.08515199],
           [ 0.97705823,  1.01751506],
           [-0.10872522,  0.91580496],
           [-1.27087508, -0.19146954],
           [ 0.87616438,  0.97685435],
           [ 0.89526079,  0.98651642],
           [ 0.96521071,  1.0206381 ],
           [ 1.0530243 ,  0.93365071],
           [ 0.994778  ,  0.99724912],
           [ 0.98176246,  1.03168734],
           [ 0.74458014,  0.97066564],
           [ 0.91748012,  0.9524803 ],
           [-1.92749946,  0.07784549],
           [ 0.7790389 ,  0.95517882],
           [ 0.11824333,  1.81065221],
           [ 0.97490265,  0.95326328],
           [ 1.00355225,  0.96521073],
           [ 1.08398178,  0.97814922],
           [-1.0749128 ,  1.77825305],
           [ 0.74886096,  1.39448605],
           [-0.1950267 ,  1.57178284],
           [ 1.069671  ,  0.97202065],
           [ 0.85757149,  1.01910676],
           [-1.02014343,  1.14016873],
           [-1.25252256,  0.02906454],
           [ 0.93948239,  1.44153932],
           [ 1.28777891,  1.00133477],
           [-1.7010408 ,  0.0821629 ],
           [ 0.8390028 ,  0.97712472],
           [ 0.99480479,  1.05717262],
           [ 1.20707509,  0.97462669],
           [-2.18786288,  1.4515569 ],
           [ 1.16027197,  1.09086817],
           [ 1.02771087,  0.9907291 ],
           [ 0.71829704,  0.98817911],
           [ 0.88605935,  0.99158972],
           [ 1.03589316,  0.99557438],
           [ 1.15489923,  0.95378093],
           [ 1.0668616 ,  0.99316509],
           [ 1.04848333,  1.09471239],
           [-1.05108888, -0.071106  ],
           [-1.19977682,  1.49257613],
           [ 1.12232276,  0.99293853],
           [-0.36977293,  1.59581   ],
           [-0.27363841,  1.46272407],
           [ 1.18075342,  0.95907983],
           [ 1.01486256,  0.97501177],
           [-0.41533403,  1.72366429],
           [-0.18337732,  2.26674615],
           [ 1.06777804,  1.00982417],
           [ 1.17411206,  0.98088369],
           [ 0.95355889,  1.05238272],
           [-0.39459255,  1.97600217],
           [ 0.90103447,  0.94080238],
           [ 0.87268023,  1.00348657],
           [-1.93323667,  1.04826094],
           [ 0.10460058,  1.16348717],
           [-1.85815599,  1.32669461],
           [ 0.90426972,  0.97521677],
           [-0.58409513,  0.9870014 ],
           [-1.74011619, -0.21416096],
           [-1.51931589,  0.34938829],
           [ 1.02631005,  0.99378866],
           [ 1.02869184,  0.99995857],
           [ 0.79862419,  1.00291807],
           [-1.34714457,  0.78937109],
           [-2.54273315,  0.96748855],
           [-1.86729291,  0.37250653],
           [-0.89843699,  0.43898384],
           [-1.83077543,  0.43636701],
           [-0.89141966,  1.57275938],
           [-0.96662858,  0.8196104 ],
           [ 0.87417528,  1.00989496],
           [ 0.93997582,  0.95616278],
           [-1.85338565,  1.00940185],
           [ 0.89565224,  0.95460192],
           [-0.76327569,  0.93526008],
           [-1.78345269,  1.53378105],
           [ 0.77408528,  1.01387371],
           [-1.47669576,  1.43472266],
           [ 1.19417792,  1.0440538 ],
           [ 1.15595665,  0.96823244],
           [ 0.84068935,  1.01792225],
           [ 1.11747629,  1.05722511],
           [ 0.23722569,  1.54396395],
           [-1.24609914,  0.30094681],
           [-0.18745572,  1.04657197],
           [ 0.90607352,  0.96120285],
           [-2.02612   ,  0.44082817],
           [ 0.8762596 ,  1.00607109],
           [ 0.98791921,  1.02441508],
           [-0.65307666,  1.22493946],
           [ 0.94162298,  1.28044258],
           [ 0.8622878 ,  0.99707326],
           [-0.27590245,  1.1547649 ],
           [ 0.99268975,  1.02885589],
           [ 1.0635428 ,  1.03445117],
           [-2.1378345 ,  0.62797163],
           [-1.40559883,  0.26079323],
           [ 1.07732353,  1.01373432],
           [-1.74785838,  1.25425571],
           [-0.51461996,  1.2583831 ],
           [ 1.02632384,  1.00203908],
           [ 0.84413823,  2.99872324],
           [ 1.10319604,  0.9615482 ],
           [ 0.95870127,  1.0461775 ],
           [-1.61872726,  0.55348188],
           [ 1.22219183,  1.00893646],
           [-0.04807925,  1.69061295],
           [-3.86851327, -0.36829707],
           [-0.84318558,  0.71791949],
           [ 0.95549697,  1.02457587],
           [ 0.15484069,  0.80992914],
           [ 1.1947279 ,  1.02301068],
           [-0.88323476,  1.52212056],
           [ 0.82715121,  0.99856576],
           [-0.97808876,  2.01262021],
           [-1.66906556,  0.70668215],
           [ 1.29672679,  0.64929896],
           [-0.45096669,  1.88364922],
           [-2.70110985,  0.36698604],
           [ 1.0795718 ,  1.02443886],
           [ 0.99150574,  0.98348741],
           [-0.65205587,  1.86131659],
           [-0.56754302,  1.87827013],
           [ 1.12356817,  1.06645171],
           [-2.72752499,  0.43018586],
           [-2.74061782, -0.08021407],
           [-0.3200331 ,  1.09683115],
           [ 1.0768664 ,  1.0085724 ],
           [-3.6325113 ,  0.67221516],
           [ 0.25830215,  0.79172286],
           [ 1.07796662,  1.00493526],
           [ 0.89606453,  0.98028498],
           [-0.94518278,  1.52377526],
           [ 0.90935946,  0.90695147],
           [ 1.0148515 ,  1.06783713],
           [ 1.16686534,  0.99312304],
           [-1.31640844,  0.32636521],
           [-1.39485695,  0.47605367],
           [-0.50763796,  2.04039346],
           [-0.58489137,  1.16215935],
           [-1.21643673,  1.16555051],
           [-2.9813908 , -0.02123246],
           [ 1.05056765,  1.0129612 ],
           [ 1.01961575,  1.03539024],
           [ 1.01227271,  0.96751672],
           [ 0.12444867,  1.38342266],
           [ 0.99713663,  0.96095512],
           [ 0.98185855,  0.9941474 ],
           [ 0.92998157,  1.03644759],
           [-0.18646788,  2.02399395],
           [-1.79776907,  0.97067984],
           [-3.23433111,  0.54897531],
           [-2.18617596,  0.33414794],
           [ 1.16844027,  1.01821873],
           [ 1.0428281 ,  1.01154471],
           [ 0.9159169 ,  1.02463567],
           [-1.3578118 ,  0.67183832],
           [-0.58824562,  1.08975919],
           [ 1.01775857,  1.00733938],
           [ 1.14847576,  1.01783862],
           [-1.1115874 ,  0.42278247],
           [ 0.84772713,  0.99733494],
           [ 1.00417018,  0.93763177],
           [ 0.56134549,  1.20390517]]),
    array([1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1,
           0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0,
           0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1,
           0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0,
           0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0,
           1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1,
           1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1,
           1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1,
           1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0,
           0, 1]))

3. 构造神经网络模型

# 构建神经网络模型
model = Sequential()
model.add(Dense(input_dim = 2, units = 1))
model.add(Activation('sigmoid'))
# 选定loss函数和优化器
model.compile(loss = 'binary_crossentropy', optimizer = 'sgd')
WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

4. 训练模型

# 训练过程
print("Training ----------")
for step in range(501):
    cost = model.train_on_batch(X, y)
    if step % 50 == 0:
        print("After %d trainings, the cost: %f" % (step, cost))
Training ----------
After 0 trainings, the cost: 0.370295
After 50 trainings, the cost: 0.349558
After 100 trainings, the cost: 0.331982
After 150 trainings, the cost: 0.316872
After 200 trainings, the cost: 0.303725
After 250 trainings, the cost: 0.292170
After 300 trainings, the cost: 0.281923
After 350 trainings, the cost: 0.272767
After 400 trainings, the cost: 0.264532
After 450 trainings, the cost: 0.257079
After 500 trainings, the cost: 0.250299

5. 测试模型

# 测试过程
print("Testing ----------")
cost = model.evaluate(X, y, batch_size = 40)
print("test cost:", cost)
W, b = model.layers[0].get_weights()
print('Weights = ', W, '\nbiases = ', b)
Testing ----------
200/200 [==============================] - 0s 53us/step
test cost: 0.25016908943653104
Weights =  [[-1.7198342 ]
 [-0.18482684]] 
biases =  [0.47288144]

6. 分析模型

# 将训练结果绘出
Y_pred = model.predict(X)
# 将概率转化为类标号,概率在0-0.5时,转为0,概率在0.5-1时转为1
Y_pred = (Y_pred*2).astype('int')  
# 绘制散点图 参数:x横轴 y纵轴
plt.subplot(2,1,1).scatter(X[:,0], X[:,1], c=Y_pred[:,0])
plt.subplot(2,1,2).scatter(X[:,0], X[:,1], c=y)
plt.show()

168620ef98674f06b86f8c3ff09d9919.png


目录
相关文章
|
2月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
将Keras训练好的.hdf5模型转换为TensorFlow的.pb模型,然后再转换为TensorRT支持的.uff格式,并提供了转换代码和测试步骤。
104 3
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
|
6月前
|
机器学习/深度学习 TensorFlow API
TensorFlow与Keras实战:构建深度学习模型
本文探讨了TensorFlow和其高级API Keras在深度学习中的应用。TensorFlow是Google开发的高性能开源框架,支持分布式计算,而Keras以其用户友好和模块化设计简化了神经网络构建。通过一个手写数字识别的实战案例,展示了如何使用Keras加载MNIST数据集、构建CNN模型、训练及评估模型,并进行预测。案例详述了数据预处理、模型构建、训练过程和预测新图像的步骤,为读者提供TensorFlow和Keras的基础实践指导。
460 59
|
2月前
|
机器学习/深度学习 监控 数据可视化
深度学习中实验、观察与思考的方法与技巧
在深度学习中,实验、观察与思考是理解和改进模型性能的关键环节。
55 5
|
2月前
|
机器学习/深度学习 数据挖掘 知识图谱
深度学习之材料科学中的自动化实验设计
基于深度学习的材料科学中的自动化实验设计是一个新兴领域,旨在通过机器学习模型,尤其是深度学习模型,来优化和自动化材料实验的设计流程。
50 1
|
2月前
|
机器学习/深度学习 移动开发 TensorFlow
深度学习之格式转换笔记(四):Keras(.h5)模型转化为TensorFlow(.pb)模型
本文介绍了如何使用Python脚本将Keras模型转换为TensorFlow的.pb格式模型,包括加载模型、重命名输出节点和量化等步骤,以便在TensorFlow中进行部署和推理。
116 0
|
4月前
|
机器学习/深度学习 存储 算法框架/工具
【深度学习】猫狗识别TensorFlow2实验报告
本文介绍了使用TensorFlow 2进行猫狗识别的实验报告,包括实验目的、采用卷积神经网络(CNN)进行训练的过程,以及如何使用交叉熵作为损失函数来识别猫狗图像数据集。
184 1
|
4月前
|
机器学习/深度学习 算法 测试技术
【深度学习】手写数字识别Tensorflow2实验报告
文章介绍了使用TensorFlow 2进行手写数字识别的实验报告,包括实验目的、采用全连接神经网络模型进行训练的过程、以及如何使用交叉熵作为损失函数来识别MNIST数据集的手写数字。
160 0
|
5月前
|
机器学习/深度学习 数据采集 TensorFlow
深度学习与传统模型的桥梁:Sklearn与Keras的集成应用
【7月更文第24天】在机器学习领域,Scikit-learn(Sklearn)作为经典的传统机器学习库,以其丰富的预处理工具、模型选择和评估方法而闻名;而Keras作为深度学习领域的明星框架,以其简洁易用的API,支持快速构建和实验复杂的神经网络模型。将这两者结合起来,可以实现从传统机器学习到深度学习的无缝过渡,充分发挥各自的优势,打造更强大、更灵活的解决方案。本文将探讨Sklearn与Keras的集成应用,通过实例展示如何在Sklearn的生态系统中嵌入Keras模型,实现模型的训练、评估与优化。
134 0
|
6月前
|
机器学习/深度学习 API TensorFlow
Keras深度学习框架入门与实践
**Keras**是Python的高级神经网络API,支持TensorFlow、Theano和CNTK后端。因其用户友好、模块化和可扩展性受到深度学习开发者欢迎。本文概述了Keras的基础,包括**模型构建**(Sequential和Functional API)、**编译与训练**(选择优化器、损失函数和评估指标)以及**评估与预测**。还提供了一个**代码示例**,展示如何使用Keras构建和训练简单的卷积神经网络(CNN)进行MNIST手写数字分类。最后,强调Keras简化了复杂神经网络的构建和训练过程。【6月更文挑战第7天】
68 7
|
7月前
|
机器学习/深度学习 算法 TensorFlow
TensorFlow 2keras开发深度学习模型实例:多层感知器(MLP),卷积神经网络(CNN)和递归神经网络(RNN)
TensorFlow 2keras开发深度学习模型实例:多层感知器(MLP),卷积神经网络(CNN)和递归神经网络(RNN)