如何用 Python 和 fast.ai 做图像深度迁移学习?

简介: 本文带你认识一个优秀的新深度学习框架,了解深度学习中最重要的3件事。框架看到这个题目,你可能会疑惑:老师,你不是讲过如何用深度学习做图像分类了吗?迁移学习好像也讲过了啊!说得对!我要感谢你对我专栏的持续关注。
img_5f718a6d2b7b4c396c40304feb71e27e.jpe

本文带你认识一个优秀的新深度学习框架,了解深度学习中最重要的3件事。

框架

看到这个题目,你可能会疑惑:

老师,你不是讲过如何用深度学习做图像分类了吗?迁移学习好像也讲过了啊!

说得对!我要感谢你对我专栏的持续关注。我确实讲过深度学习做图像分类,以及迁移学习这两项内容。

写这篇文章,是因为最近因为科研的关系,发现了 fast.ai 这款框架。我希望把它介绍给你。

你可能会不解,之前介绍过的 TuriCreate, Tensorflow, tflearn 和 Keras 好像都挺好用的啊!

我想问问,你在实际的科研工作里,用过哪一个呢?

大多数的读者,只怕基本上都没真正用它们跑过实际的任务。

为什么呢?

因为对普通用户(例如我经常提到的“文科生”),这些框架要么用起来很简单,但是功能不够强大;要么功能很强大,但是不够易用。

例如苹果的 TuriCreate ,我给你演示过,直接零基础上手都没问题。但当你希望对模型进行构造调整的时候,马上就会发现困难重重。因为其专长在于快速产生模型,并且部署到苹果移动设备,因此文档里面底层细节的介绍是有欠缺的。而且有些模型,非苹果平台目前还不能兼容。

img_bb6ea5642e4a451687140f67566050ce.jpe

至于某著名框架,直到推出3年后,在各方压力下,不得已才把好用的 Eager Execution 作为主要使用模式。其间充分体现了那种技术人员独有的傲慢和固执。另外,就连程序员和数据科学家们都把吐槽“看不懂”它的官方文档当作了家常便饭。这些轶事,由于公开发布会招致口水仗,所以我只写在了知识星球专属语雀团队《发现了一套非常棒的(该框架名称)视频教程》一文中。感兴趣的话,不妨去看看。

原本我认为, Keras 已经是把功能和易用性做到了最佳平衡了。直到我看到了 Jeremy Howard,也就是 fast.ai 创始人提出的评判标准——如果一个深度学习框架需要写个教程给你,那它的易用性还不够好。

我看了之后,可以用感动来形容。

Jeremy 说这话,不是为了夸自己——因为他甚至做了个 MOOC 出来。他自己评价,说目前 fast.ai 的易用性依然不算成功。但在我看来, fast.ai 是目前把易用性和功能都做到了极致的深度学习框架。

它的门槛极低。如同 TuriCreate 一样,你可以很轻易用几句话写个图片分类模型出来,人人都能立即上手。

它的天花板又很高。因为它只是个包裹了 Pytorch 的代码库。

你可能也听说了,在过去的一年里,Pytorch 在学术界大放异彩,就是因为它的门槛对于科研人员来说,已经足够友好了。如果你有需求,可以非常方便地通过代码的修改和复用,敏捷构造自己的深度学习模型。

这种积木式的组合方式,使得许多新论文中的模型,可以第一时间被复现验证。如果你在这个过程中有了自己的灵感和心得,可以马上实践。

且慢,fast.ai 的作者不是已经做了自己的 MOOC 了吗?那写这篇文章,岂不是多此一举?

不是的。

首先,作者每年迭代一个 MOOC 的版本,因为 MOOC 一共包括三门课程,分别是:

但现在你能看到的深度学习基础课,还是去年录的。今年10月,伴随着 Pytorch 1.0 的推出, fast.ai 做了一次显著的大版本(1.0)更新。如果你去看去年的课程,会发现和目前的 fast.ai 代码有很多区别。在完成同一个功能时,你愿意再跑去学旧的过时内容吗?特别是,如果搞混了,还很容易出错。

可是,想看到这个版本课程的免费视频,你至少得等到明年1月。因为目前正式学员们也才刚刚开课。

img_25e91cf80982d267ce8a0daaaa5b043a.jpe

而且,那视频,也是英文的。

正因如此,我觉得有必要给你讲讲,如何用最新的 fast.ai 1.0 版本,来完成图像深度迁移学习。

数据

Jeremy 在 MOOC 中提到,如果你打算让机器通过数据来学习,你需要提供3样东西给它,分别是:

  • 数据(Data)
  • 模型结构(Architecture)
  • 损失度量(Loss Metrics)

模型结构,是根据你的具体问题走的。例如说,你需要让机器做图片分类,那么就需要使用卷积神经网络(Convolutional Neural Network)来表征图片上的像素信息构成的特征。如果你需要做自然语言处理,那么就可以使用循环神经网络(Recurrent Neural Network)来捕捉文本或者字符的顺序关联信息。

损失衡量,是指你提供一个标准,衡量机器对某项任务的处理水平。例如说对于分类效果如何,你可以使用交叉熵(Binary Cross Entropy)来评判。这样,机器会尝试最小化损失结果,从而让分类表现越来越好。

至于数据,因为我们这里的任务是做分类。因此需要有标注的训练数据。

我已经把本文需要用到的数据放到了这个 github 项目上。

img_dd655fbf7c615744e6c47634511474e2.jpe

打开其中的 imgs 文件夹,你会看见3个子文件夹,分别对应训练(train),验证(valid)和测试(test)。

打开 train 文件夹看看。

你没猜错,我们用的图片还是哆啦A梦(doraemon)和瓦力(walle)。

img_1f9133148c6b917afd3c94ab5bc0d07e.jpe

因为这样不仅可以保持教程的一惯性,而且也可以保证结果对比的公平。

打开哆啦A梦的目录看看:

img_e0cab07439f46d2f8a4adf1841832234.jpe

展示其中第一个文件内容。

img_6d295bd780aa37e0f687ce591b86f4fa.jpe

好熟悉,是不是?

你可以浏览一下其他的哆啦A梦照片,然后别忘了去瓦力的文件夹里面扫上一眼。

img_253beb44a02b57c4c557a3a85cca3ac1.jpe

这就是我们的数据集了。

环境

为了运行深度学习代码,你需要一个 GPU 。但是你不需要去买一个,租就好了。最方便的租用方法,就是云平台。

fast.ai 官方,给出了以下5种云计算平台使用选项:

其中,我推荐你使用的,是 Google Compute Platform 。原因很简单,首先它成本低,每小时只需要 0.38 美元。更重要的是,如果你是新用户, Google 会先送给你300美金,1年内有效。算算看,这够你运行多久深度学习?

img_713e83a6aeee7a24cb7c17bb7ad458c4.jpe

原先,fast.ai 上面的设置 Google Compute Platform 教程写得很简略。于是我写了个一步步的教程,请使用这个链接访问。

img_2056fc1e0e50d4d20991d3c617db70c9.jpe

不过,我发现 fast.ai 的迭代速度简直惊人,短短几天时间,新的教程就出来了,而且详尽许多。因此你也可以点击这里查看官方的教程。其中如果有跳步,你可以回看我的教程,作为补充。

因此,Google Compute Platform 中间步骤,咱们就不赘述了。当你的终端里面出现这样的提示的时候,就证明一切准备工作都就绪了。

img_c876e73f8738063b9755df2aac5d1e54.jpe

下面,你需要下载刚刚在 github 上面的代码和数据集。

git clone https://github.com/wshuyi/demo-image-classification-fastai.git
img_7aec53e7252a73ac967a709eec1a0ad1.jpe

之后,就可以呼叫 jupyter 出场了。

jupyter lab
img_c02fb43a5a0100ac72dd735662a4f1d7.jpe

注意因为你是在 Google Compute Platform 云端执行 jupyter ,因此浏览器不会自动弹出。

你需要打开 Firefox 或者 Chrome,在其中输入这个链接http://localhost:8080/lab?)。

img_8fc18a73dc35a7aa0de65b48917ab0bc.jpe

打开左侧边栏里面的 demo.ipynb

img_970469979afbee530ee5a47b60f49b84.jpe

本教程全部的代码都在这里了。当然,你如果比较心急,可以选择执行Run->Run All Cells,查看全部运行结果。

img_822e80b2b91bd0e0cc6b6f8c2bf87c0f.jpe

但是,跟之前一样,我还是建议你跟着教程的说明,一步步执行它们。以便更加深刻体会每一条语句的含义。

载入

我们先要载入数据。第一步是从 fast.ai 读入一些相关的功能模块。

from fastai import *
from fastai.vision import *
from fastai.core import *

接着,我们需要设置数据所在文件夹的位置,为 imgs 目录。

img_d04924c0c6fbcf229fa9506abffaa1e3.jpe

执行:

path = Path('imgs')

下面,我们让 fast.ai 帮我们载入全部的数据。这时我们调用 ImageDataBunch 类的 from_folder 函数,结果存储到 data 中:

data = ImageDataBunch.from_folder(path, test='test', ds_tfms=get_transforms(), size=224)

注意这里,我们不仅读入了数据,还顺手做了2件事:

  • 我们进行了数据增强(augmentation),也就是对数据进行了翻转、拉伸、旋转,弄出了很多“新”训练数据。这样做的目的,是因为数据越多,越不容易出现过拟合(over-fitting),也就是模型死记硬背,蒙混考试,却没有抓住真正的规律。
  • 我们把图片大小进行了统一,设置成了 224 x 224 ,这样做的原因,是我们需要使用迁移学习,要用到预训练模型。预训练模型是在这样大小的图片上面训练出来的,因此保持大小一致,效果更好。

下面,检查一下数据载入是否正常:

data.show_batch(rows=3, figsize=(10,10))
img_d1bf126730de7333ac9b4de981d3cbd9.jpe

没问题。图片和标记都是正确的。

训练

用下面这一条语句,我们把“数据”、“模型结构”和“损失度量”三样信息,一起喂给机器。

learn = ConvLearner(data, models.resnet34, metrics=accuracy)

数据就不说了,模型我们采用的是 resnet34 这样一个预训练模型作为基础架构。至于损失度量,我们用的是准确率(accuracy)。

你可能会纳闷,这就完了?不对呀!

没有告诉模型类别有几个啊,没有指定任务迁移之后接续的几个层次的数量、大小、激活函数……

对,不需要。

因为 fast.ai 根据你输入的上述“数据”、“模型结构”和“损失度量”信息,自动帮你把这些闲七杂八的事情默默搞定了。

下面,你需要用一条指令来训练它:

learn.fit_one_cycle(1)

注意,这里我们要求 fast.ai 使用 one cycle policy 。如果你对细节感兴趣,可以点击这个链接了解具体内容。

img_a54d5aecf9ea11f4a771037625fa251c.jpe

5秒钟之后,训练结束。

验证集准确率是,100%。

注意,你“拿来”的这个 resnet34 模型当初做训练的时候,可从来没有见识过哆啦A梦或者瓦力。

看了100多张形态各异,包含各种背景噪声的图片,它居然就能 100% 准确分辨了。

之前我们讲过机器学习的可解释性很重要。没错,fast.ai 也帮我们考虑到了这点。

preds,y = learn.get_preds()
interp = ClassificationInterpretation(data, preds, y, loss_class=nn.CrossEntropyLoss)

执行上面这两行语句,不会有什么输出。但是你手里有了个解释工具。

我们来看看,机器判断得最不好的9张图片都有哪些?

interp.plot_top_losses(9, figsize=(10,10))
img_9c3e801e7ec5d8b165d774f3acd79e15.jpe

因为准确率已经 100% 了,所以单看数值,你根本无法了解机器判断不同照片的时候,遇到了哪些问题。但是这个解释器却可以立即让你明白,哪些图片,机器处理起来,底气(信心)最为不足。

我们还能让解释器做个混淆矩阵出来:

interp.plot_confusion_matrix()
img_eab2cc3e364fae4fa47d14b24d7dd09d.jpe

不过这个混淆矩阵好像没有什么意思。反正全都判断对了。

评估

我们的模型,是不是已经完美了?

不好说。

因为我们刚才展示的,只是验证集的结果。这个验证集,机器在迭代模型参数的时候每一回都拿来尝试。所以要检验最为真实的效能,我们需要让机器看从来没有看到过的图片。

你可以到 test 目录下面,看看都有什么。

img_5590089341cd16dc58ee1da62929f77c.jpe

注意这里一共6张图片,3张哆啦A梦的,3张瓦力的。

这次,我们还会使用刚才用过的 get_preds 函数。不过区别是,我们把 is_test 标记设置为 True,这样机器就不会再去验证集里面取数据了,而是看测试集的。

preds,y = learn.get_preds(is_test=True)

注意目录下面看到的文件顺序,是依据名称排列的。但是 fast.ai 读取数据的时候,其实是做了随机洗牌(randomized shuffling)。我们得看看实际测试集里面的文件顺序。

data.test_dl.dl.dataset.ds.x
img_5d82d76dd6f0dbe84100f21a468eb011.jpe

好了,我们自己心里有数了。下面就看看机器能不能都判断正确了。

preds
img_232309c5fb9e9226559d4d116cd53d0c.jpe

这都啥玩意儿啊?

别着急,这是模型预测时候,根据两个不同的分类,分别给出的倾向数值。数值越大,倾向程度越高。

左侧一列,是哆啦A梦;右侧一列,是瓦力。

我们用 np.argmax 函数,把它简化一些。

np.argmax(preds, axis=1)
img_9285d13b7ed1a58b88217f85fb8e6034.jpe

这样一来,看着就清爽多了。

我们来检查一下啊:瓦力,瓦力,哆啦A梦,哆啦A梦,哆啦A梦,哆啦A梦……

不对呀!

最后这一张,walle.113.jpg,不应该判断成瓦力吗?

打开看看。

img_e4a4a985b9d7b7bb86dcfa5d023c07aa.jpe

哦,难怪。另一个机器人也出现在图片中,圆头圆脑的,确实跟哆啦A梦有相似之处。

要不,就这样了?

微调

那哪儿行?!

我们做任务,要讲究精益求精啊。

遇到错误不要紧,我们尝试改进模型。

用的方法,叫做微调(fine-tuning)。

我们刚刚,不过是移花接木,用了 resnet34 的身体,换上了一个我们自定义的头部层次,用来做哆啦A梦和瓦力的分辨。

img_d00824159fab150bce9f7aa11ccf05dd.jpe

这个训练结果,其实已经很好了。但是既然锁定了“身体”部分的全部参数,只训练头部,依然会遇到判断失误。那我们自然想到的,就应该是连同“身体”,一起调整训练了。

但是这谈何容易?

你调整得动作轻微,那么效果不会明显;如果你调整过了劲儿,“身体”部分的预训练模型通过海量数据积累的参数经验,就会被破坏掉。

两难啊,两难!

好在,聪明的研究者提出了一个巧妙的解决之道。这非常符合我们不只一次提及的“第一性原理”,那就是返回到事情的本源,问出一句:

谁说调整的速度,要全模型都一致?!

深度卷积神经网络,是一个典型的层次模型。

模型靠近输入的地方,捕获的是底层的特征。例如边缘形状等。

模型靠近输出的地方,捕获的是高层特征,例如某种物体的形貌。

对于底层特征,我们相信哆啦A梦、瓦力和原先训练的那些自然界事物,有很多相似之处,因此应该少调整。

反之,原先模型用于捕获猫、狗、兔子的那些特征部分,我们是用不上的,因此越靠近输出位置的层次,我们就应该多调整。

这种不同力度的调整,是通过学习速率(learning rate)来达成的。具体到我们的这种区分,专用名词叫做“歧视性学习速率”(discriminative learning rate)。

你可能想放弃了,这么难!我不玩儿了!

且慢,看看 fast.ai 怎么实现“歧视性学习速率”。

learn.unfreeze()
learn.fit_one_cycle(3, slice(1e-5,3e-4))

对,只需在这里指定一下,底层和上层,选择什么不同的起始速率。搞定。

没错,就是这么不讲道理地智能化

img_6516512e383954122e78ffb70235a416.jpe

这次,训练了3个循环(cycle)。

注意,虽然准确率没有变化(一直是100%,也不可能提升了),但是损失数值,不论是训练集,还是验证集上的,都在减小。

这证明模型在努力地学东西。

你可能会担心:这样会不会导致过拟合啊?

看看就知道了,训练集上的损失数值,一直高于验证集,这就意味着,没有过拟合发生的征兆。

好了,拿着这个微调优化过后的模型,我们再来试试测试集吧。

首先我们强迫症似地看看测试集文件顺序有没有变化:

data.test_dl.dl.dataset.ds.x
img_acb596717a6c051063771e694c3bbf43.jpe

既然没有变,我们就放心了。

下面我们执行预测:

preds,y = learn.get_preds(is_test=True)

然后,观察结果:

np.argmax(preds, axis=1)
img_e67ea53abf669da1749a377061f31a0d.png

如你所见,这次全部判断正确。

可见,我们的微调,是真实有用的。

小结

本文为你介绍了如何用 fast.ai 1.0 框架进行图像深度迁移学习。可以看到, fast.ai 不仅简洁、功能强大,而且足够智能化。所有可以帮用户做的事情,它全都替你代劳。作为研究者,你只需要关注“数据”、“模型结构”和“损失度量”这3个关键问题,以改进学习效果。

我希望你不要满足于把代码跑下来。用你获得的300美金,换上自己的数据跑一跑,看看能否获得足够满意的结果。

祝(深度)学习愉快!

喜欢请点赞和打赏。还可以微信关注和置顶我的公众号“玉树芝兰”(nkwangshuyi)

如果你对 Python 与数据科学感兴趣,不妨阅读我的系列教程索引贴《如何高效入门数据科学?》,里面还有更多的有趣问题及解法。

目录
相关文章
|
12天前
|
机器学习/深度学习 人工智能 算法
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
宠物识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了37种常见的猫狗宠物种类数据集【'阿比西尼亚猫(Abyssinian)', '孟加拉猫(Bengal)', '暹罗猫(Birman)', '孟买猫(Bombay)', '英国短毛猫(British Shorthair)', '埃及猫(Egyptian Mau)', '缅因猫(Maine Coon)', '波斯猫(Persian)', '布偶猫(Ragdoll)', '俄罗斯蓝猫(Russian Blue)', '暹罗猫(Siamese)', '斯芬克斯猫(Sphynx)', '美国斗牛犬
83 29
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
|
5天前
|
人工智能 API 语音技术
TEN Agent:开源的实时多模态 AI 代理框架,支持语音、文本和图像的实时通信交互
TEN Agent 是一个开源的实时多模态 AI 代理框架,集成了 OpenAI Realtime API 和 RTC 技术,支持语音、文本和图像的多模态交互,具备实时通信、模块化设计和多语言支持等功能,适用于智能客服、实时语音助手等多种场景。
70 15
TEN Agent:开源的实时多模态 AI 代理框架,支持语音、文本和图像的实时通信交互
|
8天前
|
机器学习/深度学习 人工智能
SNOOPI:创新 AI 文本到图像生成框架,提升单步扩散模型的效率和性能
SNOOPI是一个创新的AI文本到图像生成框架,通过增强单步扩散模型的指导,显著提升模型性能和控制力。该框架包括PG-SB和NASA两种技术,分别用于增强训练稳定性和整合负面提示。SNOOPI在多个评估指标上超越基线模型,尤其在HPSv2得分达到31.08,成为单步扩散模型的新标杆。
49 10
SNOOPI:创新 AI 文本到图像生成框架,提升单步扩散模型的效率和性能
|
8天前
|
人工智能 搜索推荐 开发者
Aurora:xAI 为 Grok AI 推出新的图像生成模型,xAI Premium 用户可无限制访问
Aurora是xAI为Grok AI助手推出的新图像生成模型,专注于生成高逼真度的图像,特别是在人物和风景图像方面。该模型支持文本到图像的生成,并能处理包括公共人物和版权形象在内的多种图像生成请求。Aurora的可用性因用户等级而异,免费用户每天能生成三张图像,而Premium用户则可享受无限制访问。
46 11
Aurora:xAI 为 Grok AI 推出新的图像生成模型,xAI Premium 用户可无限制访问
|
15天前
|
机器学习/深度学习 人工智能 编解码
OminiControl:AI图像生成框架,实现图像主题控制和空间精确控制
OminiControl 是一个高度通用且参数高效的 AI 图像生成框架,专为扩散变换器模型设计,能够实现图像主题控制和空间精确控制。该框架通过引入极少量的额外参数(0.1%),支持主题驱动控制和空间对齐控制,适用于多种图像生成任务。
61 10
OminiControl:AI图像生成框架,实现图像主题控制和空间精确控制
|
14天前
|
Web App开发 机器学习/深度学习 人工智能
Magic Copy:开源的 AI 抠图工具,在浏览器中自动识别图像进行抠图
Magic Copy 是一款开源的 AI 抠图工具,支持 Chrome 浏览器扩展。它基于 Meta 的 Segment Anything Model 技术,能够自动识别图像中的前景对象并提取出来,简化用户从图片中提取特定元素的过程,提高工作效率。
56 7
Magic Copy:开源的 AI 抠图工具,在浏览器中自动识别图像进行抠图
|
23天前
|
人工智能 自然语言处理 前端开发
VideoChat:高效学习新神器!一键解读音视频内容,结合 AI 生成总结内容、思维导图和智能问答
VideoChat 是一款智能音视频内容解读助手,支持批量上传音视频文件并自动转录为文字。通过 AI 技术,它能快速生成内容总结、详细解读和思维导图,并提供智能对话功能,帮助用户更高效地理解和分析音视频内容。
94 6
VideoChat:高效学习新神器!一键解读音视频内容,结合 AI 生成总结内容、思维导图和智能问答
|
27天前
|
机器学习/深度学习 人工智能 自然语言处理
Documind:开源 AI 文档处理工具,将 PDF 转换为图像提取结构化数据
Documind 是一款利用 AI 技术从 PDF 中提取结构化数据的先进文档处理工具,支持灵活的本地或云端部署。
88 8
Documind:开源 AI 文档处理工具,将 PDF 转换为图像提取结构化数据
|
17小时前
|
机器学习/深度学习 人工智能
Leffa:Meta AI 开源精确控制人物外观和姿势的图像生成框架,在生成穿着的同时保持人物特征
Leffa 是 Meta 开源的图像生成框架,通过引入流场学习在注意力机制中精确控制人物的外观和姿势。该框架不增加额外参数和推理成本,适用于多种扩散模型,展现了良好的模型无关性和泛化能力。
25 11
Leffa:Meta AI 开源精确控制人物外观和姿势的图像生成框架,在生成穿着的同时保持人物特征
|
17天前
|
机器学习/深度学习 人工智能 自然语言处理
AI驱动的个性化学习路径优化
在当前教育领域,个性化学习正逐渐成为一种趋势。本文探讨了如何利用人工智能技术来优化个性化学习路径,提高学习效率和质量。通过分析学生的学习行为、偏好和表现,AI可以动态调整学习内容和难度,实现真正的因材施教。文章还讨论了实施这种技术所面临的挑战和潜在的解决方案。
51 7