【Python-Keras】keras.fit()和keras.fit_generator()的解析与使用

本文涉及的产品
云解析 DNS,旗舰版 1个月
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
简介: 本文解析了Keras中的`fit()`和`fit_generator()`方法,解释了两者在训练神经网络模型时的区别和使用场景,其中`fit()`适用于数据集较小且无需数据增强时,而`fit_generator()`适用于大数据集或需要数据增强的情况。

1 作用与区别

作用: 用于训练神经网络模型,两者可以完成相同的任务

区别:
.fit()时使用的整个训练数据集可以放入内存,并没有应用数据增强,就是.fit()无需使用Keras生成器(即无需数据参数)

当我们有一个巨大的数据集可容纳到我们的内存中或需要应用数据扩充时,将使用.fit_generator()。就是需要使用Keras生成器去扩充数据等等操作。

2 解析与使用

2.1 keras.fit()

(1)参数介绍

keras.fit(object, #要训练的模型
    x = NULL, #训练数据。可以是向量,数组或矩阵
    y = NULL, #训练标签。可以是向量,数组或矩阵
    batch_size = NULL, #它可以接受任何整数值或NULL,默认情况下,它将设置为32。它指定否。每个梯度样本数
    epochs = 10,#一个整数,我们要训练模型epochs的数量
      verbose = getOption("keras.fit_verbose", default = 1),#指定详细模式(0 =静音,1 =进度栏,2 = 1每行记录)      
      callbacks = NULL, 
      view_metrics = getOption("keras.view_metrics",
      default = "auto"), 
      validation_split = 0, 
      validation_data = NULL,
      shuffle = TRUE, 
      class_weight = NULL, 
      sample_weight = NULL,
      initial_epoch = 0, 
      steps_per_epoch = NULL, #它指定之前执行的步骤总数
      validation_steps = NULL,
  ...)

(2)举例使用

我们首先输入训练数据(Xtrain)和训练标签(Ytrain)。然后,我们使用Keras允许我们的模型以batch_size为32训练100个epoch。

model.fit(Xtrain, Ytrain, batch_size = 32, epochs = 100)

(3)原理讲解

当我们调用.fit()函数时,它会做一些假设:

  • 整个训练集可以放入计算机的随机存取存储器(RAM)中。
  • 调用模型。fit方法第二次不会重新初始化我们已经训练好的权重,这意味着如果需要,我们实际上可连续调用fit以进行调整。
  • 无需使用Keras生成器(即无需数据参数)
  • 原始数据本身就是用于训练我们的网络的,而我们的原始数据只能放入内存中

2.2 keras.fit_generator()

(1)参数介绍

fit_generator(object, #Keras对象模型
    generator, #生成器,其输出必须是以下形式的列表:
            #  - (inputs, targets)    
            #  - (input, targets, sample_weights)
            # 生成器的单个输出进行单个批处理,因此列表中的所有数组 长度必须等于批次的大小。生成器是期望的
            # 遍历其数据无限。有时,它永远不会返回或退出。
    steps_per_epoch, #它指定从生成器采取的步骤总数
    epochs = 1,
      verbose = getOption("keras.fit_verbose", default = 1),
      callbacks = NULL, 
      view_metrics = getOption("keras.view_metrics",
      default = "auto"), 
      validation_data = NULL, #可以是以下的其中一种
              # - 一个 inputs 和 targets 列表
            # - 一个发生器
              # - inputs, targets, 和sample_weights 列表,可用于在任何时期结束后评估任何模型的损失和度量。
      validation_steps = NULL,#仅当validation_data是生成器时,才此参数可以使用。它指定生成器之前从生成器采取的步骤总数,在每个epoch停止,其值=在数据集中验证数据点的总数/验证batch大小。
      class_weight = NULL, 
      max_queue_size = 10, 
      workers = 1,
      initial_epoch = 0)

(2)举例使用

数据增强 是一种从现有训练数据集中人为创建新数据集进行训练的方法,以利用可用数据量来提高深度学习神经网络的性能。这是一种正则化形式,使我们的模型比以前更好地推广。
在这里,我们使用Keras ImageDataGenerator对象将数据增强应用于图像的随机平移,调整大小,旋转等。每一批新数据都会根据提供给ImageDataGenerator的参数进行随机调整。

#通过训练图像生成器执行数据论证
dataAugmentaion = ImageDataGenerator(rotation_range = 30,
                                    zoom_range = 0.20, 
                                    fill_mode =“ nearest”,
                                    shear_range= 0.20,
                                    horizo​​ntal_flip = True, 
                                    width_shift_range = 0.1,
                                    height_shift_range = 0.1)

#训练模型
model.fit_generator(dataAugmentaion.flow(trainX,trainY,batch_size = 32),
                     validate_data =(testX,testY),
                     steps_per_epoch = len(trainX)// 32,
                     epoch= 10)

网络训练10个epoch,默认batch大小为32。
对于较小和较不复杂的数据集,建议使用keras.fit函数,而在处理实际数据集时,并不是那么简单,因为实际数据集的大小很大,很难放入计算机内存中。
处理这些数据集更具挑战性,处理这些数据集的重要步骤是执行数据扩充,以避免模型的过拟合,并提高模型的泛化能力。

(3)原理解析

当调用.fit_generator()函数时,它会做一些假设:

  • Keras首先调用了生成器函数(dataAugmentaion)
  • 生成器函数为.fit_generator()函数提供了32的batch_size。
  • .fit_generator()函数首先接受一批数据集,然后对其进行反向传播,然后更新模型中的权重。
  • 对于指定的epoch数(在本例中为10),将重复此过程。
目录
相关文章
|
2天前
|
数据采集 存储 JSON
从零到一构建网络爬虫帝国:HTTP协议+Python requests库深度解析
在网络数据的海洋中,网络爬虫遵循HTTP协议,穿梭于互联网各处,收集宝贵信息。本文将从零开始,使用Python的requests库,深入解析HTTP协议,助你构建自己的网络爬虫帝国。首先介绍HTTP协议基础,包括请求与响应结构;然后详细介绍requests库的安装与使用,演示如何发送GET和POST请求并处理响应;最后概述爬虫构建流程及挑战,帮助你逐步掌握核心技术,畅游数据海洋。
15 3
|
9天前
|
机器学习/深度学习 人工智能 TensorFlow
深入骨髓的解析:Python中神经网络如何学会‘思考’,解锁AI新纪元
【9月更文挑战第11天】随着科技的发展,人工智能(AI)成为推动社会进步的关键力量,而神经网络作为AI的核心,正以其强大的学习和模式识别能力开启AI新纪元。本文将探讨Python中神经网络的工作原理,并通过示例代码展示其“思考”过程。神经网络模仿生物神经系统,通过加权连接传递信息并优化输出。Python凭借其丰富的科学计算库如TensorFlow和PyTorch,成为神经网络研究的首选语言。
12 1
|
12天前
|
存储 JSON API
Python编程:解析HTTP请求返回的JSON数据
使用Python处理HTTP请求和解析JSON数据既直接又高效。`requests`库的简洁性和强大功能使得发送请求、接收和解析响应变得异常简单。以上步骤和示例提供了一个基础的框架,可以根据你的具体需求进行调整和扩展。通过合适的异常处理,你的代码将更加健壮和可靠,为用户提供更加流畅的体验。
36 0
|
19天前
|
Java 缓存 数据库连接
揭秘!Struts 2性能翻倍的秘诀:不可思议的优化技巧大公开
【8月更文挑战第31天】《Struts 2性能优化技巧》介绍了提升Struts 2 Web应用响应速度的关键策略,包括减少配置开销、优化Action处理、合理使用拦截器、精简标签库使用、改进数据访问方式、利用缓存机制以及浏览器与网络层面的优化。通过实施这些技巧,如懒加载配置、异步请求处理、高效数据库连接管理和启用GZIP压缩等,可显著提高应用性能,为用户提供更快的体验。性能优化需根据实际场景持续调整。
44 0
|
19天前
|
数据采集 存储 数据库
Python中实现简单爬虫与数据解析
【8月更文挑战第31天】在数字化时代的浪潮中,数据成为了新的石油。本文将带领读者通过Python编程语言,从零开始构建一个简单的网络爬虫,并展示如何对爬取的数据进行解析和处理。我们将一起探索请求网站、解析HTML以及存储数据的基础知识,让每个人都能成为自己数据故事的讲述者。
|
20天前
|
数据采集 JavaScript 前端开发
Python 爬虫实战:抓取和解析网页数据
【8月更文挑战第31天】本文将引导你通过Python编写一个简单的网络爬虫,从网页中抓取并解析数据。我们将使用requests库获取网页内容,然后利用BeautifulSoup进行解析。通过本教程,你不仅能够学习到如何自动化地从网站收集信息,还能理解数据处理的基本概念。无论你是编程新手还是希望扩展你的技术工具箱,这篇文章都将为你提供有价值的见解。
|
20天前
|
数据采集 存储 JavaScript
构建你的首个Python网络爬虫:抓取、解析与存储数据
【8月更文挑战第31天】在数字时代的浪潮中,数据成为了新的石油。了解如何从互联网的海洋中提取有价值的信息,是每个技术爱好者的必备技能。本文将引导你通过Python编程语言,利用其强大的库支持,一步步构建出你自己的网络爬虫。我们将探索网页请求、内容解析和数据存储等关键环节,并附上代码示例,让你轻松入门网络数据采集的世界。
|
20天前
|
JSON API 数据库
探索FastAPI:不仅仅是一个Python Web框架,更是助力开发者高效构建现代化RESTful API服务的神器——从环境搭建到CRUD应用实战全面解析
【8月更文挑战第31天】FastAPI 是一个基于 Python 3.6+ 类型提示标准的现代 Web 框架,以其高性能、易用性和现代化设计而备受青睐。本文通过示例介绍了 FastAPI 的优势及其在构建高效 Web 应用中的强大功能。首先,通过安装 FastAPI 和 Uvicorn 并创建简单的“Hello, World!”应用入门;接着展示了如何处理路径参数和查询参数,并利用类型提示进行数据验证和转换。
34 0
|
21天前
|
机器学习/深度学习 数据采集 自然语言处理
Python中实现简单的文本情感分析未来触手可及:新技术趋势与应用深度解析
【8月更文挑战第30天】在数字化的今天,理解和分析用户生成的内容对许多行业至关重要。本文将引导读者通过Python编程语言,使用自然语言处理(NLP)技术,构建一个简单的文本情感分析工具。我们将探索如何利用机器学习模型来识别和分类文本数据中的情感倾向,从而为数据分析和决策提供支持。文章将涵盖从数据预处理到模型训练和评估的全过程,旨在为初学者提供一个易于理解且实用的入门指南。
|
22天前
|
机器学习/深度学习 计算机视觉 Python
深度学习项目中在yaml文件中定义配置,以及使用的python的PyYAML库包读取解析yaml配置文件
深度学习项目中在yaml文件中定义配置,以及使用的python的PyYAML库包读取解析yaml配置文件
31 0