TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11(3)https://developer.aliyun.com/article/1427022
前面的方法使用容器来构建栈,该容器添加存储在assets
文件夹中的棋盘图像。 栈的下一个子项是居中对齐的容器,其中所有片段图像都通过对buildChessBoard()
的调用以小部件的形式添加为行和列包装。 整个栈作为子级添加到容器中并返回,以便出现在屏幕上。
此时,应用显示棋盘,以及所有放置在其初始位置的棋子。 如下所示:
现在,让我们使这些棋子变得可移动,以便我们可以玩一个真实的游戏。
使片段移动
在本节中,我们将用可拖动的工具包装每块棋子,以便用户能够将棋子拖动到所需位置。 让我们详细看一下实现:
- 回想一下,我们声明了一个哈希图来存储片段的位置。 移动将包括从一个盒子中移出一块并将其放在另一个盒子中。 假设我们有两个变量
'from'
和'to'
,它们存储用于移动片段的盒子的索引。 进行移动后,我们拿起'from'
处的片段并将其放入'to'
中。 因此,'from'
的框变为空。 按照相同的逻辑,我们将定义refreshBoard()
方法,该方法在每次移动时都会调用:
void refreshBoard(String from, String to) { setState(() { board[to] = board[from]; board[from] = " "; }); }
from
和to
变量存储源和目标正方形的索引。 这些值在board
HasMhap 中用作键。 进行移动时,from
处的棋子会移至to.
。此后,from
处的方块应该变空。 它包含在setState()
中,以确保每次移动后都更新 UI。
- 现在,让我们将其拖曳。 为此,我们将拖动项附加到
getPieceImage()
方法返回的木板的每个图像小部件上。 我们通过修改方法来做到这一点:
Widget getImage(String squareName) { return Expanded( child: DragTarget<List>(builder: (context, accepted, rejected) { return Draggable<List>( child: mapImages(squareName), feedback: mapImages(squareName), onDragCompleted: () {}, data: [ squareName, ], ); }, onWillAccept: (willAccept) { return true; }, onAccept: (List moveInfo) { String from = moveInfo[0]; String to = squareName; refreshBoard(from, to); }) ); }
在前面的函数中,我们首先将特定正方形的图像包装在Draggable
中。 此类用于感测和跟随屏幕上的拖动手势。 child
属性用于指定要拖动的窗口小部件,而反馈内部的窗口小部件用于跟踪手指在屏幕上的移动。 当拖动完成并且用户抬起手指时,目标将有机会接受所携带的数据。 由于我们正在源和目标之间移动,因此我们将添加Draggable
作为DragTarget
的子代,以便可以在源和目标之间移动小部件。 onWillAccept
设置为true
,以便可以进行所有移动。
可以修改此属性,以使其具有可以区分合法象棋动作并且不允许拖动非法动作的功能。 放下片段并完成拖动后,将调用onAccept
。 moveInfo
列表保存有关拖动源的信息。 在这里,我们调用refreshBoard()
,并传入from
和to
的值,以便屏幕可以反映运动。 至此,我们完成了向用户显示初始棋盘的操作,并使棋子可以在盒子之间移动。
在下一节中,我们将通过对托管的国际象棋服务器进行 API 调用来增加应用的交互性。 这些将使游戏栩栩如生。
将国际象棋引擎 API 与 UI 集成
托管的棋牌服务器将作为对手玩家添加到应用中。 用户将是白色的一面,而服务器将是黑色的一面。 这里要实现的游戏逻辑非常简单。 第一步是提供给应用用户。 用户进行移动时,他们将棋盘的状态从状态 X 更改为状态 Y。棋盘的状态由 FEN 字符串表示。 同样,他们将一块from
移到一个特定的正方形to
移到一个特定的正方形,这有助于他们的移动。 当用户完成移动时,状态 X 的 FEN 字符串及其当前移动(通过将from
和to
正方形连接在一起而获得)以POST
请求的形式发送到服务器。 作为回报,服务器从其侧面进行下一步移动,然后将其反映在 UI 上。
让我们看一下此逻辑的代码:
- 首先,我们定义一个名为
getPositionString()
的方法来为应用的特定状态生成 FEN 字符串:
String getPositionString(String move) { String s = ""; for(int i = 8; i >= 1; i--) { int count = 0; for(int j = 97; j <= 104; j++) { String ch = String.fromCharCode(j)+'$i'; if(board[ch] == " ") { count += 1; if(j == 104) s = s + "$count"; } else { if(count > 0) s = s + "$count"; s = s + board[ch];count = 0; } } s = s + "/"; } String position = s.substring(0, s.length-1) + " w KQkq - 0 1"; var json = jsonEncode({"position": position, "moves": move}); }
在前面的方法中,我们将move
作为参数,它是from
和to
变量的连接。 接下来,我们为棋盘的当前状态创建 FEN 字符串。 创建 FEN 字符串背后的逻辑是,我们遍历电路板的每一行并为该行创建一个字符串。 然后将生成的字符串连接到最终字符串。
让我们借助示例更好地理解这一点。 考虑一个rnbqkbnr/pp1ppppp/8/1p6/8/3P4/PPP1PPPP/RNBQKBNR w KQkq - 0 1
的 FEN 字符串。 在此,每行可以用八个或更少的字符表示。 特定行的状态通过使用分隔符“/”与另一行分开。 对于特定的行,每件作品均以其指定的符号表示,其中P
表示白兵,b
表示黑相。 每个占用的正方形均由件符号明确表示。 例如,PpkB
指示板上的前四个正方形被白色棋子,黑色棋子,黑色国王和白色主教占据。 对于空盒子,使用整数,该数字表示可传染的空盒子的数量。 注意示例 FEN 字符串中的8
。 这表示该行的所有 8 个正方形均为空。 3P4
表示前三个正方形为空,第四个方框被白色棋子占据,并且四个正方形为空。
在getPositionString()
方法中,我们迭代从 8 到 1 的每一行,并为每行生成一个状态字符串。 对于每个非空框,我们只需在's'
变量中添加一个表示该块的字符。 对于每个空框,当找到非空框或到达行末时,我们将count
的值增加 1 并将其连接到's'
字符串。 遍历每一行后,我们添加“/”以分隔两行。 最后,我们通过将生成的's'
字符串与w KQkq - 0 1
连接来生成位置字符串。 然后,我们通过将jsonEncode()
与键值对结合使用来生成所需的 JSON 对象
- 我们使用“步骤 1”的“步骤 1”中的
from
和to
变量来保存用户的当前移动。 我们可以通过在refreshBoard()
方法中添加两行来实现:
void refreshBoard(String from, String to) { String move= from + to; getPositionString(move); ..... }
在前面的代码片段中,我们将from
和to
的值连接起来,并将它们存储在名为move
的字符串变量中。 然后,我们调用getPositionString()
,并将move
的值传递给参数。
- 接下来,我们使用在上一步中
makePOSTRequest()
方法中生成的JSON
向服务器发出POST
请求:
void makePOSTRequest(var json) async{ var url = 'http://35.200.253.0:8080/play'; var response = await http.post(url, headers: {"Content-Type": "application/json"} ,body: json); String rsp = response.body; String from = rsp.substring(0,3); String to = rsp.substring(3); }
首先,将国际象棋服务器的 IP 地址存储在url
变量中。 然后,我们使用http.post()
发出HTTP POST
请求,并为 URL,标头和正文传递正确的值。 POST 请求的响应包含服务器端的下一个动作,并存储在变量响应中。 我们解析响应的主体并将其存储在名为rsp
的字符串变量中。 响应基本上是一个字符串,是服务器端的源方和目标方的连接。 例如,响应字符串f4a3
表示国际象棋引擎希望将棋子以f4
正方形移动到a3
正方形。 我们使用substring()
分隔源和目标,并将值存储在from
和to
变量中。
- 现在,通过将调用添加到
makePOSTrequest()
来从getPositionString()
发出 POST 请求:
String getPositionString(String move) { ..... makePOSTRequest(json); }
在 FEN 字符串生成板的给定状态之后,对makePOSTrequest()
的调用添加在函数的最后。
- 最后,我们使用
refreshBoardFromServer()
方法刷新板以反映服务器在板上的移动:
void refreshBoardFromServer(String from, String to) { setState(() { board[to] = board[from]; board[from] = " "; }); }
前述方法中的逻辑非常简单。 首先,我们将映射到from
索引正方形的片段移动到to
索引正方形,然后清空from
索引正方形。
- 最后,我们调用适当的方法以用最新的动作更新 UI:
void makePOSTRequest(var json) async{ ...... refreshBoardFromServer(from, to); buildChessBoard(); }
发布请求成功完成后,我们收到了服务器的响应,我们将调用refreshBoardFromServer()
以更新板上的映射。 最后,我们调用buildChessBoard()
以在应用屏幕上反映国际象棋引擎所做的最新动作。
以下屏幕快照显示了国际象棋引擎进行移动后的更新的用户界面:
请注意,黑色的块在白色的块之后移动。 这就是代码的工作方式。 首先,用户采取行动。 它以板的初始状态发送到服务器。 然后,服务器以其移动进行响应,更新 UI。 作为练习,您可以尝试实现一些逻辑以区分有效动作和无效动作。
可以在这个页面中找到此代码。
现在,让我们通过创建材质应用来包装应用。
创建材质应用
现在,我们将在main.dart
中创建最终的材质应用。 让我们从以下步骤开始:
- 首先,我们创建无状态窗口小部件
MyApp
,并覆盖其build()
方法,如下所示:
class MyApp extends StatelessWidget { @override Widget build(BuildContext context) { return MaterialApp( title: 'Chess', theme: ThemeData(primarySwatch: Colors.blue,), home: MyHomePage(title: 'Chess'), ); } }
- 我们创建一个单独的
StatefulWidget
,称为MyHomePage
,以便将 UI 放置在屏幕中央。MyHomePage
的build()
方法如下所示:
@override Widget build(BuildContext context) { return Scaffold( appBar: AppBar(title: Text('Chess'),), body: Center( child: Column( mainAxisAlignment: MainAxisAlignment.center, children: <Widget>[ChessGame() ], ), ), ); }
- 最后,我们通过在
main.dart
中添加以下行来执行整个代码:
void main() => runApp(MyApp());
而已! 现在,我们有一个交互式的国际象棋游戏应用,您可以与聪明的对手一起玩。 希望你赢!
整个文件的代码可以在这个页面中找到。
总结
在此项目中,我们介绍了强化学习的概念以及为什么强化学习在创建游戏性 AI 的开发人员中很受欢迎。 我们讨论了 Google DeepMind 的 AlphaGo 及其兄弟项目,并深入研究了它们的工作算法。 接下来,我们创建了一个类似的程序来玩 Connect 4,然后下棋。 我们将基于 AI 的国际象棋引擎作为 API 部署到 GPU 实例的 GCP 上,并将其与基于 Flutter 的应用集成。 我们还了解了如何使用 UCI 促进国际象棋的无状态游戏。 完成此项目后,您将对如何将游戏转换为强化学习环境,如何以编程方式定义游戏规则以及如何创建用于玩这些游戏的自学智能体有很好的了解。
在下一章中,我们将创建一个应用,该应用可以使低分辨率图像变成非常高分辨率的图像。 我们将在 AI 的帮助下进行此操作。
八、深度神经网络
在本章中,我们将回顾机器学习,深度神经网络中最先进的技术,也是研究最多的领域之一。
深度神经网络定义
这是一个新闻技术领域蓬勃发展的领域,每天我们都听到成功地将 DNN 用于解决新问题的实验,例如计算机视觉,自动驾驶,语音和文本理解等。
在前几章中,我们使用了与 DNN 相关的技术,尤其是在涉及卷积神经网络的技术中。
出于实际原因,我们将指深度学习和深度神经网络,即其中层数明显优于几个相似层的架构,我们将指代具有数十个层的神经网络架构,或者复杂结构的组合。
穿越时空的深度网络架构
在本节中,我们将回顾从 LeNet5 开始在整个深度学习历史中出现的里程碑架构。
LeNet 5
在 1980 年代和 1990 年代,神经网络领域一直保持沉默。 尽管付出了一些努力,但是架构非常简单,并且需要大的(通常是不可用的)机器力量来尝试更复杂的方法。
1998 年左右,在贝尔实验室中,在围绕手写校验数字分类的研究中,Ian LeCun 开始了一种新趋势,该趋势实现了所谓的“深度学习——卷积神经网络”的基础,我们已经在第 5 章,简单的前馈神经网络中对其进行了研究。
在那些年里,SVM 和其他更严格定义的技术被用来解决这类问题,但是有关 CNN 的基础论文表明,与当时的现有方法相比,神经网络的表现可以与之媲美或更好。
Alexnet
经过几年的中断(即使 LeCun 继续将其网络应用到其他任务,例如人脸和物体识别),可用结构化数据和原始处理能力的指数增长,使团队得以增长和调整模型, 在某种程度上被认为是不可能的,因此可以增加模型的复杂性,而无需等待数月的训练。
来自许多技术公司和大学的计算机研究团队开始竞争一些非常艰巨的任务,包括图像识别。 对于以下挑战之一,即 Imagenet 分类挑战,开发了 Alexnet 架构:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-lULpcW1A-1681785128423)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00125.jpg)]
Alexnet 架构
主要功能
从其第一层具有卷积运算的意义上讲,Alexnet 可以看作是增强的 LeNet5。 但要添加未使用过的最大池化层,然后添加一系列密集的连接层,以建立最后的输出类别概率层。 视觉几何组(VGG)模型
图像分类挑战的其他主要竞争者之一是牛津大学的 VGG。
VGG 网络架构的主要特征是它们将卷积滤波器的大小减小到一个简单的3x3
,并按顺序组合它们。
微小的卷积内核的想法破坏了 LeNet 及其后继者 Alexnet 的最初想法,后者最初使用的过滤器高达11x11
过滤器,但复杂得多且表现低下。 过滤器大小的这种变化是当前趋势的开始:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-r8DheOZh-1681785128423)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00126.jpg)]
VGG 中每层的参数编号摘要
然而,使用一系列小的卷积权重的积极变化,总的设置是相当数量的参数(数以百万计的数量级),因此它必须受到许多措施的限制。
原始的初始模型
在由 Alexnet 和 VGG 主导的两个主要研究周期之后,Google 凭借非常强大的架构 Inception 打破了挑战,该架构具有多次迭代。
这些迭代的第一个迭代是从其自己的基于卷积神经网络层的架构版本(称为 GoogLeNet)开始的,该架构的名称让人想起了始于网络的方法。
GoogLenet(InceptionV1)
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ofFPZuno-1681785128424)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00127.jpg)]
InceptionV1
GoogLeNet 是这项工作的第一个迭代,如下图所示,它具有非常深的架构,但是它具有九个链式初始模块的令人毛骨悚然的总和,几乎没有或根本没有修改:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-EPjLvndu-1681785128424)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00128.jpg)]
盗梦空间原始架构
与两年前发布的 Alexnet 相比,它是如此复杂,但它设法减少了所需的参数数量并提高了准确率。
但是,由于几乎所有结构都由相同原始结构层构建块的确定排列和重复组成,因此提高了此复杂架构的理解和可伸缩性。
批量归一化初始化(V2)
2015 年最先进的神经网络在提高迭代效率的同时,还存在训练不稳定的问题。
为了理解问题的构成,首先我们将记住在前面的示例中应用的简单正则化步骤。 它主要包括将这些值以零为中心,然后除以最大值或标准偏差,以便为反向传播的梯度提供良好的基线。
在训练非常大的数据集的过程中,发生的事情是,经过大量训练示例之后,不同的值振荡开始放大平均参数值,就像在共振现象中一样。 我们非常简单地描述的被称为协方差平移。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-sz6uJZfR-1681785128424)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00129.jpg)]
有和没有批量归一化的表现比较
这是开发批归一化技术的主要原因。
再次简化了过程描述,它不仅包括对原始输入值进行归一化,还对每一层上的输出值进行了归一化,避免了在层之间出现不稳定性之前就开始影响或漂移这些值。
这是 Google 在 2015 年 2 月发布的改进版 GoogLeNet 实现中提供的主要功能,也称为 InceptionV2。
InceptionV3
快进到 2015 年 12 月,Inception 架构有了新的迭代。 两次发行之间月份的不同使我们对新迭代的开发速度有了一个想法。
此架构的基本修改如下:
- 将卷积数减少到最大
3x3
- 增加网络的总体深度
- 在每一层使用宽度扩展技术来改善特征组合
下图说明了如何解释改进的启动模块:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-6JVkvOHu-1681785128424)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00130.jpg)]
InceptionV3 基本模块
这是整个 V3 架构的表示形式,其中包含通用构建模块的许多实例:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-doHC5UCK-1681785128424)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00131.jpg)]
InceptionV3 总体图
残差网络(ResNet)
残差网络架构于 2015 年 12 月出现(与 InceptionV3 几乎同时出现),它带来了一个简单而新颖的想法:不仅使用每个构成层的输出,还将该层的输出与原始输入结合。
在下图中,我们观察到 ResNet 模块之一的简化视图。 它清楚地显示了卷积层栈末尾的求和运算,以及最终的 relu 运算:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-lrxuW1RM-1681785128425)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00132.jpg)]
ResNet 一般架构
模块的卷积部分包括将特征从 256 个值减少到 64 个值,一个保留特征数的3x3
过滤层以及一个从 64 x 256 个值增加1x1
层的特征。 在最近的发展中,ResNet 的使用深度还不到 30 层,分布广泛。
其他深度神经网络架构
最近开发了很多神经网络架构。 实际上,这个领域是如此活跃,以至于我们每年或多或少都有新的杰出架构外观。 最有前途的神经网络架构的列表是:
- SqueezeNet:此架构旨在减少 Alexnet 的参数数量和复杂性,声称减少了 50 倍的参数数量
- 高效神经网络(Enet):旨在构建更简单,低延迟的浮点运算数量,具有实时结果的神经网络
- Fractalnet:它的主要特征是非常深的网络的实现,不需要残留的架构,将结构布局组织为截断的分形
示例 – 风格绘画 – VGG 风格迁移
在此示例中,我们将配合 Leon Gatys 的论文《艺术风格的神经算法》的实现。
注意
此练习的原始代码由 Anish Athalye 提供。
我们必须注意,此练习没有训练内容。 我们将仅加载由 VLFeat 提供的预训练系数矩阵,该矩阵是预训练模型的数据库,可用于处理模型,从而避免了通常需要大量计算的训练:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-tTvYAxhb-1681785128425)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00133.jpg)]
风格迁移主要概念
有用的库和方法
- 使用
scipy.io.loadmat
加载参数文件
- 我们将使用的第一个有用的库是
scipy.io
模块,用于加载系数数据,该数据另存为 matlab 的 MAT 格式。
- 上一个参数的用法:
scipy.io.loadmat(file_name, mdict=None, appendmat=True, **kwargs)
- 返回前一个参数:
mat_dict : dict :dictionary
,变量名作为键,加载的矩阵作为值。 如果填充了mdict
参数,则将结果分配给它。
数据集说明和加载
为了解决这个问题,我们将使用预训练的数据集,即 VGG 神经网络的再训练系数和 Imagenet 数据集。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ht3GTlIo-1681785128425)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00134.jpg)]
数据集预处理
假设系数是在加载的参数矩阵中给出的,那么关于初始数据集的工作就不多了。
模型架构
模型架构主要分为两部分:风格和内容。
为了生成最终图像,使用了没有最终完全连接层的 VGG 网络。
损失函数
该架构定义了两个不同的损失函数来优化最终图像的两个不同方面,一个用于内容,另一个用于风格。
内容损失函数
content_loss
函数的代码如下:
# content loss content_loss = content_weight * (2 * tf.nn.l2_loss( net[CONTENT_LAYER] - content_features[CONTENT_LAYER]) / content_features[CONTENT_LAYER].size)
风格损失函数
损失优化循环
损耗优化循环的代码如下:
best_loss = float('inf') best = None with tf.Session() as sess: sess.run(tf.initialize_all_variables()) for i in range(iterations): last_step = (i == iterations - 1) print_progress(i, last=last_step) train_step.run() if (checkpoint_iterations and i % checkpoint_iterations == 0) or last_step: this_loss = loss.eval() if this_loss < best_loss: best_loss = this_loss best = image.eval() yield ( (None if last_step else i), vgg.unprocess(best.reshape(shape[1:]), mean_pixel) )
收敛性测试
在此示例中,我们将仅检查指示的迭代次数(迭代参数)。
程序执行
为了以良好的迭代次数(大约 1000 个)执行该程序,我们建议至少有 8GB 的 RAM 内存可用:
python neural_style.py --content examples/2-content.jpg --styles examples/2-style1.jpg --checkpoint-iterations=100 --iterations=1000 --checkpoint-output=out%s.jpg --output=outfinal
前面命令的结果如下:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-IISmOKsl-1681785128425)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/00135.jpg)]
风格迁移步骤
控制台输出如下:
Iteration 1/1000 Iteration 2/1000 Iteration 3/1000 Iteration 4/1000 ... Iteration 999/1000 Iteration 1000/1000 content loss: 908786 style loss: 261789 tv loss: 25639.9 total loss: 1.19621e+06
完整源代码
neural_style.py
的代码如下:
import os import numpy as np import scipy.misc from stylize import stylize import math from argparse import ArgumentParser # default arguments CONTENT_WEIGHT = 5e0 STYLE_WEIGHT = 1e2 TV_WEIGHT = 1e2 LEARNING_RATE = 1e1 STYLE_SCALE = 1.0 ITERATIONS = 100 VGG_PATH = 'imagenet-vgg-verydeep-19.mat' def build_parser(): parser = ArgumentParser() parser.add_argument('--content', dest='content', help='content image', metavar='CONTENT', required=True) parser.add_argument('--styles', dest='styles', nargs='+', help='one or more style images', metavar='STYLE', required=True) parser.add_argument('--output', dest='output', help='output path', metavar='OUTPUT', required=True) parser.add_argument('--checkpoint-output', dest='checkpoint_output', help='checkpoint output format', metavar='OUTPUT') parser.add_argument('--iterations', type=int, dest='iterations', help='iterations (default %(default)s)', metavar='ITERATIONS', default=ITERATIONS) parser.add_argument('--width', type=int, dest='width', help='output width', metavar='WIDTH') parser.add_argument('--style-scales', type=float, dest='style_scales', nargs='+', help='one or more style scales', metavar='STYLE_SCALE') parser.add_argument('--network', dest='network', help='path to network parameters (default %(default)s)', metavar='VGG_PATH', default=VGG_PATH) parser.add_argument('--content-weight', type=float, dest='content_weight', help='content weight (default %(default)s)', metavar='CONTENT_WEIGHT', default=CONTENT_WEIGHT) parser.add_argument('--style-weight', type=float, dest='style_weight', help='style weight (default %(default)s)', metavar='STYLE_WEIGHT', default=STYLE_WEIGHT) parser.add_argument('--style-blend-weights', type=float, dest='style_blend_weights', help='style blending weights', nargs='+', metavar='STYLE_BLEND_WEIGHT') parser.add_argument('--tv-weight', type=float, dest='tv_weight', help='total variation regularization weight (default %(default)s)', metavar='TV_WEIGHT', default=TV_WEIGHT) parser.add_argument('--learning-rate', type=float, dest='learning_rate', help='learning rate (default %(default)s)', metavar='LEARNING_RATE', default=LEARNING_RATE) parser.add_argument('--initial', dest='initial', help='initial image', metavar='INITIAL') parser.add_argument('--print-iterations', type=int, dest='print_iterations', help='statistics printing frequency', metavar='PRINT_ITERATIONS') parser.add_argument('--checkpoint-iterations', type=int, dest='checkpoint_iterations', help='checkpoint frequency', metavar='CHECKPOINT_ITERATIONS') return parser def main(): parser = build_parser() options = parser.parse_args() if not os.path.isfile(options.network): parser.error("Network %s does not exist. (Did you forget to download it?)" % options.network) content_image = imread(options.content) style_images = [imread(style) for style in options.styles] width = options.width if width is not None: new_shape = (int(math.floor(float(content_image.shape[0]) / content_image.shape[1] * width)), width) content_image = scipy.misc.imresize(content_image, new_shape) target_shape = content_image.shape for i in range(len(style_images)): style_scale = STYLE_SCALE if options.style_scales is not None: style_scale = options.style_scales[i] style_images[i] = scipy.misc.imresize(style_images[i], style_scale * target_shape[1] / style_images[i].shape[1]) style_blend_weights = options.style_blend_weights if style_blend_weights is None: # default is equal weights style_blend_weights = [1.0/len(style_images) for _ in style_images] else: total_blend_weight = sum(style_blend_weights) style_blend_weights = [weight/total_blend_weight for weight in style_blend_weights] initial = options.initial if initial is not None: initial = scipy.misc.imresize(imread(initial), content_image.shape[:2]) if options.checkpoint_output and "%s" not in options.checkpoint_output: parser.error("To save intermediate images, the checkpoint output " "parameter must contain `%s` (e.g. `foo%s.jpg`)") for iteration, image in stylize( network=options.network, initial=initial, content=content_image, styles=style_images, iterations=options.iterations, content_weight=options.content_weight, style_weight=options.style_weight, style_blend_weights=style_blend_weights, tv_weight=options.tv_weight, learning_rate=options.learning_rate, print_iterations=options.print_iterations, checkpoint_iterations=options.checkpoint_iterations ): output_file = None if iteration is not None: if options.checkpoint_output: output_file = options.checkpoint_output % iteration else: output_file = options.output if output_file: imsave(output_file, image) def imread(path): return scipy.misc.imread(path).astype(np.float) def imsave(path, img): img = np.clip(img, 0, 255).astype(np.uint8) scipy.misc.imsave(path, img) if __name__ == '__main__': main()
Stilize.py
的代码如下:
import vgg import tensorflow as tf import numpy as np from sys import stderr CONTENT_LAYER = 'relu4_2' STYLE_LAYERS = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1') try: reduce except NameError: from functools import reduce def stylize(network, initial, content, styles, iterations, content_weight, style_weight, style_blend_weights, tv_weight, learning_rate, print_iterations=None, checkpoint_iterations=None): """ Stylize images. This function yields tuples (iteration, image); `iteration` is None if this is the final image (the last iteration). Other tuples are yielded every `checkpoint_iterations` iterations. :rtype: iterator[tuple[int|None,image]] """ shape = (1,) + content.shape style_shapes = [(1,) + style.shape for style in styles] content_features = {} style_features = [{} for _ in styles] # compute content features in feedforward mode g = tf.Graph() with g.as_default(), g.device('/cpu:0'), tf.Session() as sess: image = tf.placeholder('float', shape=shape) net, mean_pixel = vgg.net(network, image) content_pre = np.array([vgg.preprocess(content, mean_pixel)]) content_features[CONTENT_LAYER] = net[CONTENT_LAYER].eval( feed_dict={image: content_pre}) # compute style features in feedforward mode for i in range(len(styles)): g = tf.Graph() with g.as_default(), g.device('/cpu:0'), tf.Session() as sess: image = tf.placeholder('float', shape=style_shapes[i]) net, _ = vgg.net(network, image) style_pre = np.array([vgg.preprocess(styles[i], mean_pixel)]) for layer in STYLE_LAYERS: features = net[layer].eval(feed_dict={image: style_pre}) features = np.reshape(features, (-1, features.shape[3])) gram = np.matmul(features.T, features) / features.size style_features[i][layer] = gram # make stylized image using backpropogation with tf.Graph().as_default(): if initial is None: noise = np.random.normal(size=shape, scale=np.std(content) * 0.1) initial = tf.random_normal(shape) * 0.256 else: initial = np.array([vgg.preprocess(initial, mean_pixel)]) initial = initial.astype('float32') image = tf.Variable(initial) net, _ = vgg.net(network, image) # content loss content_loss = content_weight * (2 * tf.nn.l2_loss( net[CONTENT_LAYER] - content_features[CONTENT_LAYER]) / content_features[CONTENT_LAYER].size) # style loss style_loss = 0 for i in range(len(styles)): style_losses = [] for style_layer in STYLE_LAYERS: layer = net[style_layer] _, height, width, number = map(lambda i: i.value, layer.get_shape()) size = height * width * number feats = tf.reshape(layer, (-1, number)) gram = tf.matmul(tf.transpose(feats), feats) / size style_gram = style_features[i][style_layer] style_losses.append(2 * tf.nn.l2_loss(gram - style_gram) / style_gram.size) style_loss += style_weight * style_blend_weights[i] * reduce(tf.add, style_losses) # total variation denoising tv_y_size = _tensor_size(image[:,1:,:,:]) tv_x_size = _tensor_size(image[:,:,1:,:]) tv_loss = tv_weight * 2 * ( (tf.nn.l2_loss(image[:,1:,:,:] - image[:,:shape[1]-1,:,:]) / tv_y_size) + (tf.nn.l2_loss(image[:,:,1:,:] - image[:,:,:shape[2]-1,:]) / tv_x_size)) # overall loss loss = content_loss + style_loss + tv_loss # optimizer setup train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss) def print_progress(i, last=False): stderr.write('Iteration %d/%d\n' % (i + 1, iterations)) if last or (print_iterations and i % print_iterations == 0): stderr.write(' content loss: %g\n' % content_loss.eval()) stderr.write(' style loss: %g\n' % style_loss.eval()) stderr.write(' tv loss: %g\n' % tv_loss.eval()) stderr.write(' total loss: %g\n' % loss.eval()) # optimization best_loss = float('inf') best = None with tf.Session() as sess: sess.run(tf.initialize_all_variables()) for i in range(iterations): last_step = (i == iterations - 1) print_progress(i, last=last_step) train_step.run() if (checkpoint_iterations and i % checkpoint_iterations == 0) or last_step: this_loss = loss.eval() if this_loss < best_loss: best_loss = this_loss best = image.eval() yield ( (None if last_step else i), vgg.unprocess(best.reshape(shape[1:]), mean_pixel) ) def _tensor_size(tensor): from operator import mul return reduce(mul, (d.value for d in tensor.get_shape()), 1) vgg.py import tensorflow as tf import numpy as np import scipy.io def net(data_path, input_image): layers = ( 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4' ) data = scipy.io.loadmat(data_path) mean = data['normalization'][0][0][0] mean_pixel = np.mean(mean, axis=(0, 1)) weights = data['layers'][0] net = {} current = input_image for i, name in enumerate(layers): kind = name[:4] if kind == 'conv': kernels, bias = weights[i][0][0][0][0] # matconvnet: weights are [width, height, in_channels, out_channels] # tensorflow: weights are [height, width, in_channels, out_channels] kernels = np.transpose(kernels, (1, 0, 2, 3)) bias = bias.reshape(-1) current = _conv_layer(current, kernels, bias) elif kind == 'relu': current = tf.nn.relu(current) elif kind == 'pool': current = _pool_layer(current) net[name] = current assert len(net) == len(layers) return net, mean_pixel def _conv_layer(input, weights, bias): conv = tf.nn.conv2d(input, tf.constant(weights), strides=(1, 1, 1, 1), padding='SAME') return tf.nn.bias_add(conv, bias) def _pool_layer(input): return tf.nn.max_pool(input, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='SAME') def preprocess(image, mean_pixel): return image - mean_pixel def unprocess(image, mean_pixel): return image + mean_pixel
总结
在本章中,我们一直在学习不同的深度神经网络架构。
我们了解了如何构建近年来最著名的架构之一 VGG,以及如何使用它来生成可转换艺术风格的图像。
在下一章中,我们将使用机器学习中最有用的技术之一:图形处理单元。 我们将回顾安装具有 GPU 支持的 TensorFlow 所需的步骤并对其进行训练,并将执行时间与唯一运行的模型 CPU 进行比较。
九、构建图像超分辨率应用
还记得上次和亲人一起旅行并拍了一些漂亮的照片作为记忆,但是当您回到家并刷过它们时,您发现它们非常模糊且质量低下吗? 现在,您剩下的所有美好时光就是您自己的心理记忆和那些模糊的照片。 如果可以使您的照片清晰透明并且可以看到其中的每个细节,那不是很好吗?
超分辨率是基于像素信息的近似将低分辨率图像转换为高分辨率图像的过程。 虽然今天可能还不完全是神奇的,但当技术发展到足以成为通用 AI 应用时,它肯定会在将来挽救生命。
在此项目中,我们将构建一个应用,该应用使用托管在 DigitalOcean Droplet 上的深度学习模型,该模型可以同时比较低分辨率和高分辨率图像,从而使我们更好地了解今天的技术。 我们将使用生成对抗网络(GAN)生成超分辨率图像。
在本章中,我们将介绍以下主题:
- 基本项目架构
- 了解 GAN
- 了解图像超分辨率的工作原理
- 创建 TensorFlow 模型以实现超分辨率
- 构建应用的 UI
- 从设备的本地存储中获取图片
- 在 DigitalOcean 上托管 TensorFlow 模型
- 在 Flutter 上集成托管的自定义模型
- 创建材质应用
让我们从了解项目的架构开始。
基本项目架构
让我们从了解项目的架构开始。
我们将在本章中构建的项目主要分为两个部分:
- Jupyter 笔记本,它创建执行超分辨率的模型。
- 使用该模型的 Flutter 应用,在 Jupyter 笔记本上接受训练后,将托管在 DigitalOcean 中的 Droplet 中。
从鸟瞰图可以用下图描述该项目:
将低分辨率图像放入模型中,该模型是从 Firebase 上托管的 ML Kit 实例中获取的,并放入 Flutter 应用中。 生成输出并将其作为高分辨率图像显示给用户。 该模型缓存在设备上,并且仅在开发人员更新模型时才更新,因此可以通过减少网络延迟来加快预测速度。
现在,让我们尝试更深入地了解 GAN。
了解 GAN
Ian Goodfellow,Yoshua Bengio 和其他人在 NeurIPS 2014 中引入的 GAN 席卷全球。 可以应用于各种领域的 GAN 会根据模型对实际数据样本的学习近似,生成新的内容或序列。 GAN 已被大量用于生成音乐和艺术的新样本,例如下图所示的面孔,而训练数据集中不存在这些面孔:
经过 60 个周期的训练后,GAN 生成的面孔。 该图像取自这里。
前面面孔中呈现的大量真实感证明了 GAN 的力量–在为他们提供良好的训练样本量之后,他们几乎可以学习生成任何类型的模式。
GAN 的核心概念围绕两个玩家玩游戏的想法。 在这个游戏中,一个人说出一个随机句子,另一个人仅仅考虑第一人称使用的单词就指出它是事实还是假。 第二个人唯一可以使用的知识是假句子和实句中常用的单词(以及如何使用)。 这可以描述为由 minimax 算法玩的两人游戏,其中每个玩家都试图以其最大能力抵消另一位玩家所做的移动。 在 GAN 中,第一个玩家是生成器(G
),第二个玩家是判别器(D
)。 G
和D
都是常规 GAN 中的神经网络。 生成器从训练数据集中给出的样本中学习,并基于其认为当观察者查看时可以作为真实样本传播的样本来生成新样本。
判别器从训练样本(正样本)和生成器生成的样本(负样本)中学习,并尝试对哪些图像存在于数据集中以及哪些图像进行分类。 它从G
获取生成的图像,并尝试将其分类为真实图像(存在于训练样本中)或生成图像(不存在于数据库中)。
通过反向传播,GAN 尝试不断减少判别器能够对生成器正确生成的图像进行分类的次数。 一段时间后,我们希望达到识别器在识别生成的图像时开始表现不佳的阶段。 这是 GAN 停止学习的地方,然后可以使用生成器生成所需数量的新样本。 因此,训练 GAN 意味着训练生成器以从随机输入产生输出,从而使判别器无法将其识别为生成的图像。
判别器将传递给它的所有图像分为两类:
- 真实图像:数据集中存在的图像或使用相机拍摄的图像
- 伪图像:使用某软件生成的图像
生成器欺骗判别器的能力越好,当向其提供任何随机输入序列时,生成的输出将越真实。
让我们以图表形式总结前面关于 GAN 进行的讨论:
GAN 具有许多不同的变体,所有变体都取决于它们正在执行的任务。 其中一些如下:
- 渐进式 GAN:在 ICLR 2018 上的一篇论文中介绍,渐进式 GAN 的生成器和判别器均以低分辨率图像开始,并随着图像层的增加而逐渐受到训练,从而使系统能够生成高分辨率图像。 例如,在第一次迭代中生成的图像为
10x10
像素,在第二代中它变为20x20
,依此类推,直到获得非常高分辨率的图像为止。 生成器和判别器都在深度上一起增长。 - 条件 GAN:假设您有一个 GAN 可以生成 10 个不同类别的样本,但是在某个时候,您希望它在给定类别或一组类别内生成样本。 这是有条件 GAN 起作用的时候。有条件 GAN 使我们可以生成 GAN 中经过训练可以生成的所有标签中任何给定标签的样本。 在图像到图像的翻译领域中,已经完成了条件 GAN 的一种非常流行的应用,其中将一个图像生成为相似或相同域的另一个更逼真的图像。 您可以通过这个页面上的演示来尝试涂鸦一些猫,并获得涂鸦的真实感版本。
- 栈式 GAN:栈式 GAN 的最流行的应用是基于文本描述生成图像。 在第一阶段,GAN 生成描述项的概述,在第二阶段,根据描述添加颜色。 然后,后续层中的 GAN 将更多细节添加到图像中,以生成图像的真实感版本,如描述中所述。 通过观察堆叠 GAN 的第一次迭代中的图像已经处于将要生成最终输出的尺寸,可以将栈式 GAN 与渐进式 GAN 区别开来。但是,与渐进式 GAN 相似,在第一次迭代中, 图像是最小的,并且需要进一步的层才能将其馈送到判别器。
在此项目中,我们将讨论 GAN 的另一种形式,称为超分辨率 GAN(SRGAN)。 我们将在下一部分中了解有关此变体的更多信息。
了解图像超分辨率的工作原理
几十年来,人们一直在追求并希望能够使低分辨率图像更加精细,以及使高分辨率图像化。 超分辨率是用于将低分辨率图像转换为超高分辨率图像的技术的集合,是图像处理工程师和研究人员最激动人心的工作领域之一。 已经建立了几种方法和方法来实现图像的超分辨率,并且它们都朝着自己的目标取得了不同程度的成功。 然而,近来,随着 SRGAN 的发展,关于使用任何低分辨率图像可以实现的超分辨率的量有了显着的改进。
但是在讨论 SRGAN 之前,让我们了解一些与图像超分辨率有关的概念。
了解图像分辨率
用质量术语来说,图像的分辨率取决于其清晰度。 分辨率可以归类为以下之一:
- 像素分辨率
- 空间分辨率
- 时间分辨率
- 光谱分辨率
- 辐射分辨率
让我们来看看每个。
像素分辨率
指定分辨率的最流行格式之一,像素分辨率最通常是指形成图像时涉及的像素数量。 单个像素是可以在任何给定查看设备上显示的最小单个单元。 可以将几个像素组合在一起以形成图像。 在本书的前面,我们讨论了图像处理,并将像素称为存储在矩阵中的颜色信息的单个单元,它代表图像。 像素分辨率定义了形成数字图像所需的像素元素总数,该总数可能与图像上可见的有效像素数不同。
标记图像像素分辨率的一种非常常见的表示法是以百万像素表示。 给定NxM
像素分辨率的图像,其分辨率可以写为(NxM / 1000000
)百万像素。 因此,尺寸为2,000x3,000
的图像将具有 6,000,000 像素,其分辨率可以表示为 6 兆像素。
空间分辨率
这是观察图像的人可以分辨图像中紧密排列的线条的程度的度量。 在这里,严格说来,图像的像素越多,清晰度越好。 这是由于具有较高像素数量的图像的空间分辨率较低。 因此,需要良好的空间分辨率以及具有良好的像素分辨率以使图像以良好的质量呈现。
它也可以定义为像素一侧所代表的距离量。
时间分辨率
分辨率也可能取决于时间。 例如,卫星或使用无人飞行器(UAV)无人机拍摄的同一区域的图像可能会随时间变化。 重新捕获相同区域的图像所需的时间称为时间分辨率。
时间分辨率主要取决于捕获图像的设备。 如在图像捕捉的情况下,这可以是变型,例如当在路边的速度陷阱照相机中触发特定传感器时执行图像捕捉。 它也可以是常数。 例如,在配置为每x
间隔拍照的相机中。
光谱分辨率
光谱分辨率是指图像捕获设备可以记录的波段数。 也可以将其定义为波段的宽度或每个波段的波长范围。 在数字成像方面,光谱分辨率类似于图像中的通道数。 理解光谱分辨率的另一种方法是在任何给定图像或频带记录中可区分的频带数。
黑白图像中的波段数为 1,而彩色(RGB)图像中的波段数为 3。可以捕获数百个波段的图像,其中其他波段可提供有关图像的不同种类的信息。 图片。
辐射分辨率
辐射分辨率是捕获设备表示在任何频带/通道上接收到的强度的能力。 辐射分辨率越高,设备可以更准确地捕获其通道上的强度,并且图像越真实。
辐射分辨率类似于图像每个像素的位数。 虽然 8 位图像像素可以表示 256 个不同的强度,但是 256 位图像像素可以表示2 ^ 256
个不同的强度。 黑白图像的辐射分辨率为 1 位,这意味着每个像素只能有两个不同的值,即 0 和 1。
现在,让我们尝试了解 SRGAN。
了解 SRGAN
SRGAN 是一类 GAN,主要致力于从低分辨率图像创建超分辨率图像。
SRGAN 算法的功能描述如下:该算法从数据集中选取高分辨率图像,然后将其采样为低分辨率图像。 然后,生成器神经网络尝试从低分辨率图像生成高分辨率图像。 从现在开始,我们将其称为超分辨率图像。 将超分辨率图像发送到鉴别神经网络,该神经网络已经在高分辨率图像和一些基本的超分辨率图像的样本上进行了训练,以便可以对它们进行分类。
判别器将由生成器发送给它的超分辨率图像分类为有效的高分辨率图像,伪高分辨率图像或超分辨率图像。 如果将图像分类为超分辨率图像,则 GAN 损失会通过生成器网络反向传播,以便下次产生更好的伪造图像。 随着时间的流逝,生成器将学习如何创建更好的伪造品,并且判别器开始无法正确识别超分辨率图像。 GAN 在这里停止学习,被列为受过训练的人。
可以用下图来总结:
现在,让我们开始创建用于超分辨率的 SRGAN 模型。
创建 TensorFlow 模型来实现超分辨率
现在,我们将开始构建在图像上执行超分辨率的 GAN 模型。 在深入研究代码之前,我们需要了解如何组织项目目录。
项目目录结构
本章中包含以下文件和文件夹:
api/
:model /
:__init __.py
:此文件指示此文件的父文件夹可以像模块一样导入。common.py
:包含任何 GAN 模型所需的常用函数。srgan.py
:其中包含开发 SRGAN 模型所需的函数。weights/
:gan_generator.h5
:模型的预训练权重文件。 随意使用它来快速运行并查看项目的工作方式。data.py
:用于在 DIV2K 数据集中下载,提取和加载图像的工具函数。flask_app.py
:我们将使用此文件来创建将在 DigitalOcean 上部署的服务器。train.py
:模型训练文件。 我们将在本节中更深入地讨论该文件。
您可以在这个页面中找到项目此部分的源代码。
多样 2K(DIV2K)数据集由图像恢复和增强的新趋势(NTIRE)2017 单张图像超分辨率挑战赛引入,也用于挑战赛的 2018 版本中。
在下一节中,我们将构建 SRGAN 模型脚本。
创建用于超分辨率的 SRGAN 模型
首先,我们将从处理train.py
文件开始:
- 让我们从将必要的模块导入项目开始:
import os from data import DIV2K from model.srgan import generator, discriminator from train import SrganTrainer, SrganGeneratorTrainer
前面的导入引入了一些现成的类,例如SrganTrainer
,SrganGeneratorTrainer
等。 在完成此文件的工作后,我们将详细讨论它们。
- 现在,让我们为权重创建一个目录。 我们还将使用此目录来存储中间模型:
weights_dir = 'weights' weights_file = lambda filename: os.path.join(weights_dir, filename) os.makedirs(weights_dir, exist_ok=True)
- 接下来,我们将从 DIV2K 数据集中下载并加载图像。 我们将分别下载训练和验证图像。 对于这两组图像,可以分为两对:高分辨率和低分辨率。 但是,这些是单独下载的:
div2k_train = DIV2K(scale=4, subset='train', downgrade='bicubic') div2k_valid = DIV2K(scale=4, subset='valid', downgrade='bicubic')
- 将数据集下载并加载到变量后,我们需要将训练图像和验证图像都转换为 TensorFlow 数据集对象。 此步骤还将两个数据集中的高分辨率和低分辨率图像结合在一起:
train_ds = div2k_train.dataset(batch_size=16, random_transform=True) valid_ds = div2k_valid.dataset(batch_size=16, random_transform=True, repeat_count=1)
- 现在,回想一下我们在“了解 GAN”部分中提供的 GAN 的定义。 为了使生成器开始产生判别器可以评估的伪造品,它需要学习创建基本的伪造品。 为此,我们将快速训练神经网络,以便它可以生成基本的超分辨率图像。 我们将其命名为预训练器。 然后,我们将预训练器的权重迁移到实际的 SRGAN,以便它可以通过使用判别器来学习更多。 让我们构建并运行预训练器:
pre_trainer = SrganGeneratorTrainer(model=generator(), checkpoint_dir=f'.ckpt/pre_generator') pre_trainer.train(train_ds, valid_ds.take(10), steps=1000000, evaluate_every=1000, save_best_only=False) pre_trainer.model.save_weights(weights_file('pre_generator.h5'))
现在,我们已经训练了一个基本模型并保存了权重。 我们可以随时更改 SRGAN 并通过加载其权重从基础训练中重新开始。
- 现在,让我们将预训练器权重加载到 SRGAN 对象中,并执行训练迭代:
gan_generator = generator() gan_generator.load_weights(weights_file('pre_generator.h5')) gan_trainer = SrganTrainer(generator=gan_generator, discriminator=discriminator()) gan_trainer.train(train_ds, steps=200000)
请注意,在具有 8 GB RAM 和 Intel i7 处理器的普通计算机上,上述代码中的训练操作可能会花费大量时间。 建议在具有图形处理器(GPU)的基于云的虚拟机中执行此训练。
- 现在,让我们保存 GAN 生成器和判别器的权重:
gan_trainer.generator.save_weights(weights_file('gan_generator.h5')) gan_trainer.discriminator.save_weights(weights_file('gan_discriminator.h5'))
现在,我们准备继续进行下一部分,在该部分中将构建将使用此模型的 Flutter 应用的 UI。
构建应用的 UI
现在,我们了解了图像超分辨率模型的基本功能并为其创建了一个模型,让我们深入研究构建 Flutter 应用。 在本节中,我们将构建应用的 UI。
该应用的用户界面非常简单:它将包含两个图像小部件和按钮小部件。 当用户单击按钮小部件时,他们将能够从设备的库中选择图像。 相同的图像将作为输入发送到托管模型的服务器。 服务器将返回增强的图像。 屏幕上将放置的两个图像小部件将用于显示服务器的输入和服务器的输出。
下图说明了应用的基本结构和最终流程:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-mNqyudKm-1681785128426)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/mobi-dl-tflite/img/86a43bbe-4673-4dcb-8a8c-591d7c952df0.png)]
该应用的三个主要小部件可以简单地排列在一列中。 该应用的小部件树如下所示:
现在,让我们编写代码以构建主屏幕。 以下步骤讨论了该应用小部件的创建和放置:
- 首先,我们创建一个名为
image_super_resolution.dart
的新文件。 这将包含一个名为ImageSuperResolution
的无状态窗口小部件。 该小部件将包含应用主屏幕的代码。 - 接下来,我们将定义一个名为
buildImageInput()
的函数,该函数返回一个小部件,该小部件负责显示用户选择的图像:
Widget buildImage1() { return Expanded( child: Container( width: 200, height: 200, child: img1 ) ); }
此函数返回带有Container
作为其child.
的Expanded
小部件。Container
的width
和height
为200
。 Container
的子元素最初是存储在资产文件夹中的占位符图像,可以通过img1
变量进行访问,如下所示:
var img1 = Image.asset('assets/place_holder_image.png');
我们还将在pubspec.yaml
文件中添加图像的路径,如下所示:
flutter: assets: - assets/place_holder_image.png
- 现在,我们将创建另一个函数
buildImageOutput()
,该函数返回一个小部件,该小部件负责显示模型返回的增强图像:
Widget buildImageOutput() { return Expanded( child: Container( width: 200, height: 200, child: imageOutput ) ); }
此函数返回一个以其Container
作为其子元素的Expanded
小部件。 Container
的宽度和高度设置为200
。 Container
的子级是名为imageOutput
的小部件。 最初,imageOutput
还将包含一个占位符图像,如下所示:
Widget imageOutput = Image.asset('assets/place_holder_image.png');
将模型集成到应用中后,我们将更新imageOutput
。
- 现在,我们将定义第三个函数
buildPickImageButton()
,该函数返回一个Widget
,我们可以使用它从设备的图库中选择图像:
Widget buildPickImageButton() { return Container( margin: EdgeInsets.all(8), child: FloatingActionButton( elevation: 8, child: Icon(Icons.camera_alt), onPressed: () => {}, ) ); }
此函数返回以FloatingActionButton
作为其子元素的Container
。 按钮的elevation
属性控制其下方阴影的大小,并设置为8
。 为了反映该按钮用于选择图像,通过Icon
类为它提供了摄像机的图标。 当前,我们已经将按钮的onPressed
属性设置为空白。 我们将在下一部分中定义一个函数,使用户可以在按下按钮时从设备的图库中选择图像。
- 最后,我们将覆盖
build
方法以返回应用的Scaffold
:
@override Widget build(BuildContext context) { return Scaffold( appBar: AppBar(title: Text('Image Super Resolution')), body: Container( child: Column( crossAxisAlignment: CrossAxisAlignment.center, children: <Widget>[ buildImageInput(), buildImageOutput(), buildPickImageButton() ] ) ) ); }
Scaffold
包含一个appBar
,其标题设置为“图像超分辨率”。 Scaffold
的主体为Container
,其子代为Column
。 该列的子级是我们在先前步骤中构建的三个小部件。 另外,我们将Column
的crossAxisAlignment
属性设置为CrossAxisAlignment.center
,以确保该列位于屏幕的中央。
至此,我们已经成功构建了应用的初始状态。 以下屏幕截图显示了该应用现在的外观:
尽管屏幕看起来很完美,但目前无法正常工作。 接下来,我们将向应用添加功能。 我们将添加让用户从图库中选择图像的功能。
TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11(5)https://developer.aliyun.com/article/1427024