【Python-Keras】keras.fit()和keras.fit_generator()的解析与使用

本文涉及的产品
全局流量管理 GTM,标准版 1个月
云解析 DNS,旗舰版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
简介: 本文解析了Keras中的`fit()`和`fit_generator()`方法,解释了两者在训练神经网络模型时的区别和使用场景,其中`fit()`适用于数据集较小且无需数据增强时,而`fit_generator()`适用于大数据集或需要数据增强的情况。

1 作用与区别

作用: 用于训练神经网络模型,两者可以完成相同的任务

区别:
.fit()时使用的整个训练数据集可以放入内存,并没有应用数据增强,就是.fit()无需使用Keras生成器(即无需数据参数)

当我们有一个巨大的数据集可容纳到我们的内存中或需要应用数据扩充时,将使用.fit_generator()。就是需要使用Keras生成器去扩充数据等等操作。

2 解析与使用

2.1 keras.fit()

(1)参数介绍

keras.fit(object, #要训练的模型
    x = NULL, #训练数据。可以是向量,数组或矩阵
    y = NULL, #训练标签。可以是向量,数组或矩阵
    batch_size = NULL, #它可以接受任何整数值或NULL,默认情况下,它将设置为32。它指定否。每个梯度样本数
    epochs = 10,#一个整数,我们要训练模型epochs的数量
      verbose = getOption("keras.fit_verbose", default = 1),#指定详细模式(0 =静音,1 =进度栏,2 = 1每行记录)      
      callbacks = NULL, 
      view_metrics = getOption("keras.view_metrics",
      default = "auto"), 
      validation_split = 0, 
      validation_data = NULL,
      shuffle = TRUE, 
      class_weight = NULL, 
      sample_weight = NULL,
      initial_epoch = 0, 
      steps_per_epoch = NULL, #它指定之前执行的步骤总数
      validation_steps = NULL,
  ...)

(2)举例使用

我们首先输入训练数据(Xtrain)和训练标签(Ytrain)。然后,我们使用Keras允许我们的模型以batch_size为32训练100个epoch。

model.fit(Xtrain, Ytrain, batch_size = 32, epochs = 100)

(3)原理讲解

当我们调用.fit()函数时,它会做一些假设:

  • 整个训练集可以放入计算机的随机存取存储器(RAM)中。
  • 调用模型。fit方法第二次不会重新初始化我们已经训练好的权重,这意味着如果需要,我们实际上可连续调用fit以进行调整。
  • 无需使用Keras生成器(即无需数据参数)
  • 原始数据本身就是用于训练我们的网络的,而我们的原始数据只能放入内存中

2.2 keras.fit_generator()

(1)参数介绍

fit_generator(object, #Keras对象模型
    generator, #生成器,其输出必须是以下形式的列表:
            #  - (inputs, targets)    
            #  - (input, targets, sample_weights)
            # 生成器的单个输出进行单个批处理,因此列表中的所有数组 长度必须等于批次的大小。生成器是期望的
            # 遍历其数据无限。有时,它永远不会返回或退出。
    steps_per_epoch, #它指定从生成器采取的步骤总数
    epochs = 1,
      verbose = getOption("keras.fit_verbose", default = 1),
      callbacks = NULL, 
      view_metrics = getOption("keras.view_metrics",
      default = "auto"), 
      validation_data = NULL, #可以是以下的其中一种
              # - 一个 inputs 和 targets 列表
            # - 一个发生器
              # - inputs, targets, 和sample_weights 列表,可用于在任何时期结束后评估任何模型的损失和度量。
      validation_steps = NULL,#仅当validation_data是生成器时,才此参数可以使用。它指定生成器之前从生成器采取的步骤总数,在每个epoch停止,其值=在数据集中验证数据点的总数/验证batch大小。
      class_weight = NULL, 
      max_queue_size = 10, 
      workers = 1,
      initial_epoch = 0)

(2)举例使用

数据增强 是一种从现有训练数据集中人为创建新数据集进行训练的方法,以利用可用数据量来提高深度学习神经网络的性能。这是一种正则化形式,使我们的模型比以前更好地推广。
在这里,我们使用Keras ImageDataGenerator对象将数据增强应用于图像的随机平移,调整大小,旋转等。每一批新数据都会根据提供给ImageDataGenerator的参数进行随机调整。

#通过训练图像生成器执行数据论证
dataAugmentaion = ImageDataGenerator(rotation_range = 30,
                                    zoom_range = 0.20, 
                                    fill_mode =“ nearest”,
                                    shear_range= 0.20,
                                    horizo​​ntal_flip = True, 
                                    width_shift_range = 0.1,
                                    height_shift_range = 0.1)

#训练模型
model.fit_generator(dataAugmentaion.flow(trainX,trainY,batch_size = 32),
                     validate_data =(testX,testY),
                     steps_per_epoch = len(trainX)// 32,
                     epoch= 10)

网络训练10个epoch,默认batch大小为32。
对于较小和较不复杂的数据集,建议使用keras.fit函数,而在处理实际数据集时,并不是那么简单,因为实际数据集的大小很大,很难放入计算机内存中。
处理这些数据集更具挑战性,处理这些数据集的重要步骤是执行数据扩充,以避免模型的过拟合,并提高模型的泛化能力。

(3)原理解析

当调用.fit_generator()函数时,它会做一些假设:

  • Keras首先调用了生成器函数(dataAugmentaion)
  • 生成器函数为.fit_generator()函数提供了32的batch_size。
  • .fit_generator()函数首先接受一批数据集,然后对其进行反向传播,然后更新模型中的权重。
  • 对于指定的epoch数(在本例中为10),将重复此过程。
目录
相关文章
|
6天前
|
数据采集 供应链 API
Python爬虫与1688图片搜索API接口:深度解析与显著收益
在电子商务领域,数据是驱动业务决策的核心。阿里巴巴旗下的1688平台作为全球领先的B2B市场,提供了丰富的API接口,特别是图片搜索API(`item_search_img`),允许开发者通过上传图片搜索相似商品。本文介绍如何结合Python爬虫技术高效利用该接口,提升搜索效率和用户体验,助力企业实现自动化商品搜索、库存管理优化、竞品监控与定价策略调整等,显著提高运营效率和市场竞争力。
30 3
|
27天前
|
数据采集 JSON API
如何利用Python爬虫淘宝商品详情高级版(item_get_pro)API接口及返回值解析说明
本文介绍了如何利用Python爬虫技术调用淘宝商品详情高级版API接口(item_get_pro),获取商品的详细信息,包括标题、价格、销量等。文章涵盖了环境准备、API权限申请、请求构建和返回值解析等内容,强调了数据获取的合规性和安全性。
|
25天前
|
数据挖掘 vr&ar C++
让UE自动运行Python脚本:实现与实例解析
本文介绍如何配置Unreal Engine(UE)以自动运行Python脚本,提高开发效率。通过安装Python、配置UE环境及使用第三方插件,实现Python与UE的集成。结合蓝图和C++示例,展示自动化任务处理、关卡生成及数据分析等应用场景。
100 5
|
1月前
|
存储 缓存 Python
Python中的装饰器深度解析与实践
在Python的世界里,装饰器如同一位神秘的魔法师,它拥有改变函数行为的能力。本文将揭开装饰器的神秘面纱,通过直观的代码示例,引导你理解其工作原理,并掌握如何在实际项目中灵活运用这一强大的工具。从基础到进阶,我们将一起探索装饰器的魅力所在。
|
1月前
|
Android开发 开发者 Python
通过标签清理微信好友:Python自动化脚本解析
微信已成为日常生活中的重要社交工具,但随着使用时间增长,好友列表可能变得臃肿。本文介绍了一个基于 Python 的自动化脚本,利用 `uiautomator2` 库,通过模拟用户操作实现根据标签批量清理微信好友的功能。脚本包括环境准备、类定义、方法实现等部分,详细解析了如何通过标签筛选并删除好友,适合需要批量管理微信好友的用户。
60 7
|
2月前
|
XML 数据采集 数据格式
Python 爬虫必备杀器,xpath 解析 HTML
【11月更文挑战第17天】XPath 是一种用于在 XML 和 HTML 文档中定位节点的语言,通过路径表达式选取节点或节点集。它不仅适用于 XML,也广泛应用于 HTML 解析。基本语法包括标签名、属性、层级关系等的选择,如 `//p` 选择所有段落标签,`//a[@href='example.com']` 选择特定链接。在 Python 中,常用 lxml 库结合 XPath 进行网页数据抓取,支持高效解析与复杂信息提取。高级技巧涵盖轴的使用和函数应用,如 `contains()` 用于模糊匹配。
|
2月前
|
测试技术 开发者 Python
使用Python解析和分析源代码
本文介绍了如何使用Python的`ast`模块解析和分析Python源代码,包括安装准备、解析源代码、分析抽象语法树(AST)等步骤,展示了通过自定义`NodeVisitor`类遍历AST并提取信息的方法,为代码质量提升和自动化工具开发提供基础。
91 8
|
2月前
|
数据可视化 图形学 Python
在圆的外面画一个正方形:Python实现与技术解析
本文介绍了如何使用Python的`matplotlib`库绘制一个圆,并在其外部绘制一个正方形。通过计算正方形的边长和顶点坐标,实现了圆和正方形的精确对齐。代码示例详细展示了绘制过程,适合初学者学习和实践。
53 9
|
2月前
|
存储 缓存 开发者
Python编程中的装饰器深度解析
本文将深入探讨Python语言的装饰器概念,通过实际代码示例展示如何创建和应用装饰器,并分析其背后的原理和作用。我们将从基础定义出发,逐步引导读者理解装饰器的高级用法,包括带参数的装饰器、多层装饰器以及装饰器与类方法的结合使用。文章旨在帮助初学者掌握这一强大工具,同时为有经验的开发者提供更深层次的理解和应用。
44 7
|
2月前
|
机器学习/深度学习 数据采集 数据挖掘
Python编程语言的魅力:从入门到进阶的全方位解析
Python编程语言的魅力:从入门到进阶的全方位解析