使用TensorFlow创建能够图像重建的自编码器模型

本文涉及的产品
交互式建模 PAI-DSW,每月250计算时 3个月
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
模型训练 PAI-DLC,5000CU*H 3个月
简介: 使用TensorFlow创建能够图像重建的自编码器模型

想象你正在解决一个拼图游戏。你已经完成了大部分。假设您需要在一幅几乎完成的图片中间修复一块。你需要从盒子里选择一块,它既适合空间,又能完成整个画面。

640.png

我相信你很快就能做到。但是你的大脑是怎么做到的呢?

首先,它会分析空槽周围的图片(在这里你需要固定拼图的一块)。如果图片中有一棵树,你会寻找绿色的部分(这是显而易见的!)所以,简而言之,我们的大脑能够通过知道图像周围的环境来预测图像(它将适合放入槽中)。

在本教程中,我们的模型将执行类似的任务。它将学习图像的上下文,然后利用学习到的上下文预测图像的一部分(缺失的部分)。

在这篇文章之前,我们先看一下代码实现

我建议您在另一个选项卡中打开这个笔记本(TF实现),这样您就可以直观地了解发生了什么。

https://colab.research.google.com/drive/1zFe9TmMCK2ldUOsVXenvpbNY2FLrLh5k#scrollTo=UXjElGKzyiey&forceEdit=true&sandboxMode=true

问题

我们希望我们的模型能预测图像的一部分。给定一个有部份缺失图像(只有0的图像阵列的一部分),我们的模型将预测原始图像是完整的。

因此,我们的模型将利用它在训练中学习到的上下文重建图像中缺失的部分。

640.png

数据

我们将为任务选择一个域。我们选择了一些山地图像,它们是Puneet Bansal在Kaggle上的 Intel Image Classification数据集的一部分。

为什么只有山脉的图像?

在这里,我们选择属于某个特定域的图像。如果我们选择的数据集中有更广泛图像,我们的模型将不能很好地执行。因此,我们将其限制在一个域内。

使用wget下载我在GitHub上托管的数据

!wgethttps://github.com/shubham0204/Dataset_Archives/blob/master/mountain_images.zip?raw=true -O images.zip!unzipimages.zip

为了生成训练数据,我们将遍历数据集中的每个图像,并对其执行以下任务,

640.png

首先,我们将使用PIL.Image.open()读取图像文件。使用np.asarray()将这个图像对象转换为一个NumPy数组。

确定窗口大小。这是正方形的边长这是从原始图像中得到的。

[ 0 , image_dim — window_size ]范围内生成2个随机数。image_dim是我们的方形输入图像的大小。

这两个数字(称为px和py)是从原始图像剪裁的位置。选择图像数组的一部分,并将其替换为零数组。

代码如下

x= []
y= []
input_size= ( 228 , 228 , 3 )
#Takeoutasquareregionofside50px.
window_size=50#Storetheoriginalimagesastargetimages.
fornameinos.listdir( 'mountain_images/' ):
image=Image.open( 'mountain_images/{}'.format( name ) ).resize( input_size[0:2] )
image=np.asarray( image ).astype( np.uint8 )
y.append( image )
fornameinos.listdir( 'mountain_images/' ):
image=Image.open( 'mountain_images/{}'.format( name ) ).resize( input_size[0:2] )
image=np.asarray( image ).astype( np.uint8 )
#GeneraterandomXandYcoordinateswithintheimagebounds.
px , py=random.randint( 0 , input_size[0] -window_size ) , random.randint( 0 , input_size[0] -window_size )
#Takethatpartoftheimageandreplaceitwithazeroarray. Thismakesthe"missing"partoftheimage.
image[ px : px+window_size , py : py+window_size , 0:3 ] =np.zeros( ( window_size , window_size , 3 ) )
#Appendittoanarrayx.append( image )
#Normalizetheimagesx=np.array( x ) /255y=np.array( y ) /255#Traintestsplitx_train, x_test, y_train, y_test=train_test_split( x , y , test_size=0.2 )

自动编码器模型与跳连接

我们添加跳转连接到我们的自动编码器模型。这些跳过连接提供了更好的上采样。通过使用最大池层,许多空间信息会在编码过程中丢失。为了从它的潜在表示(由编码器产生)重建图像,我们添加了跳过连接,它将信息从编码器带到解码器。

alpha=0.2inputs=Input( shape=input_size )
conv1=Conv2D( 32 , kernel_size=( 3 , 3 ) , strides=1 )( inputs )
relu1=LeakyReLU( alpha )( conv1 )
conv2=Conv2D( 32 , kernel_size=( 3 , 3 ) , strides=1 )( relu1 )
relu2=LeakyReLU( alpha )( conv2 )
maxpool1=MaxPooling2D()( relu2 )
conv3=Conv2D( 64 , kernel_size=( 3 , 3 ) , strides=1 )( maxpool1 )
relu3=LeakyReLU( alpha )( conv3 )
conv4=Conv2D( 64 , kernel_size=( 3 , 3 ) , strides=1 )( relu3 )
relu4=LeakyReLU( alpha )( conv4 )
maxpool2=MaxPooling2D()( relu4 )
conv5=Conv2D( 128 , kernel_size=( 3 , 3 ) , strides=1 )( maxpool2 )
relu5=LeakyReLU( alpha )( conv5 )
conv6=Conv2D( 128 , kernel_size=( 3 , 3 ) , strides=1 )( relu5 )
relu6=LeakyReLU( alpha )( conv6 )
maxpool3=MaxPooling2D()( relu6 )
conv7=Conv2D( 256 , kernel_size=( 1 , 1 ) , strides=1 )( maxpool3 )
relu7=LeakyReLU( alpha )( conv7 )
conv8=Conv2D( 256 , kernel_size=( 1 , 1 ) , strides=1 )( relu7 )
relu8=LeakyReLU( alpha )( conv8 )
upsample1=UpSampling2D()( relu8 )
concat1=Concatenate()([ upsample1 , conv6 ])
convtranspose1=Conv2DTranspose( 128 , kernel_size=( 3 , 3 ) , strides=1)( concat1 )
relu9=LeakyReLU( alpha )( convtranspose1 )
convtranspose2=Conv2DTranspose( 128 , kernel_size=( 3 , 3 ) , strides=1 )( relu9 )
relu10=LeakyReLU( alpha )( convtranspose2 )
upsample2=UpSampling2D()( relu10 )
concat2=Concatenate()([ upsample2 , conv4 ])
convtranspose3=Conv2DTranspose( 64 , kernel_size=( 3 , 3 ) , strides=1)( concat2 )
relu11=LeakyReLU( alpha )( convtranspose3 )
convtranspose4=Conv2DTranspose( 64 , kernel_size=( 3 , 3 ) , strides=1 )( relu11 )
relu12=LeakyReLU( alpha )( convtranspose4 )
upsample3=UpSampling2D()( relu12 )
concat3=Concatenate()([ upsample3 , conv2 ])
convtranspose5=Conv2DTranspose( 32 , kernel_size=( 3 , 3 ) , strides=1)( concat3 )
relu13=LeakyReLU( alpha )( convtranspose5 )
convtranspose6=Conv2DTranspose( 3 , kernel_size=( 3 , 3 ) , strides=1 , activation='relu' )( relu13 )
model=tf.keras.models.Model( inputs , convtranspose6 )
model.compile( loss='mse' , optimizer='adam' , metrics=[ 'mse' ] )

最后,训练我们的自动编码器模型,


model.fit( x_train , y_train , epochs=150 , batch_size=25 , validation_data=( x_test , y_test ) )

image.png

结论

以上结果是在少数测试图像上得到的。我们观察到模型几乎已经学会了如何填充黑盒!但我们仍然可以分辨出盒子在原始图像中的位置。这样,我们就可以建立一个模型来预测图像缺失的部分。

这里我们只是用了一个简单的模型来作为样例,如果我们要推广到现实生活中,就需要使用更大的数据集和更深的网络,例如可以使用现有的sota模型,加上imagenet的图片进行训练。

目录
相关文章
|
2月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
将Keras训练好的.hdf5模型转换为TensorFlow的.pb模型,然后再转换为TensorRT支持的.uff格式,并提供了转换代码和测试步骤。
91 3
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
|
14天前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
50 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
14天前
|
机器学习/深度学习 人工智能 算法
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
蔬菜识别系统,本系统使用Python作为主要编程语言,通过收集了8种常见的蔬菜图像数据集('土豆', '大白菜', '大葱', '莲藕', '菠菜', '西红柿', '韭菜', '黄瓜'),然后基于TensorFlow搭建卷积神经网络算法模型,通过多轮迭代训练最后得到一个识别精度较高的模型文件。在使用Django开发web网页端操作界面,实现用户上传一张蔬菜图片识别其名称。
57 0
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
|
1月前
|
机器学习/深度学习 人工智能 算法
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
车辆车型识别,使用Python作为主要编程语言,通过收集多种车辆车型图像数据集,然后基于TensorFlow搭建卷积网络算法模型,并对数据集进行训练,最后得到一个识别精度较高的模型文件。再基于Django搭建web网页端操作界面,实现用户上传一张车辆图片识别其类型。
74 0
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
|
3月前
|
机器学习/深度学习 人工智能 算法
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
鸟类识别系统。本系统采用Python作为主要开发语言,通过使用加利福利亚大学开源的200种鸟类图像作为数据集。使用TensorFlow搭建ResNet50卷积神经网络算法模型,然后进行模型的迭代训练,得到一个识别精度较高的模型,然后在保存为本地的H5格式文件。在使用Django开发Web网页端操作界面,实现用户上传一张鸟类图像,识别其名称。
110 12
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
|
2月前
|
机器学习/深度学习 移动开发 TensorFlow
深度学习之格式转换笔记(四):Keras(.h5)模型转化为TensorFlow(.pb)模型
本文介绍了如何使用Python脚本将Keras模型转换为TensorFlow的.pb格式模型,包括加载模型、重命名输出节点和量化等步骤,以便在TensorFlow中进行部署和推理。
87 0
|
4月前
|
API UED 开发者
如何在Uno Platform中轻松实现流畅动画效果——从基础到优化,全方位打造用户友好的动态交互体验!
【8月更文挑战第31天】在开发跨平台应用时,确保用户界面流畅且具吸引力至关重要。Uno Platform 作为多端统一的开发框架,不仅支持跨系统应用开发,还能通过优化实现流畅动画,增强用户体验。本文探讨了Uno Platform中实现流畅动画的多个方面,包括动画基础、性能优化、实践技巧及问题排查,帮助开发者掌握具体优化策略,提升应用质量与用户满意度。通过合理利用故事板、减少布局复杂性、使用硬件加速等技术,结合异步方法与预设缓存技巧,开发者能够创建美观且流畅的动画效果。
84 0
|
4月前
|
C# 开发者 前端开发
揭秘混合开发新趋势:Uno Platform携手Blazor,教你一步到位实现跨平台应用,代码复用不再是梦!
【8月更文挑战第31天】随着前端技术的发展,混合开发日益受到开发者青睐。本文详述了如何结合.NET生态下的两大框架——Uno Platform与Blazor,进行高效混合开发。Uno Platform基于WebAssembly和WebGL技术,支持跨平台应用构建;Blazor则让C#成为可能的前端开发语言,实现了客户端与服务器端逻辑共享。二者结合不仅提升了代码复用率与跨平台能力,还简化了项目维护并增强了Web应用性能。文中提供了从环境搭建到示例代码的具体步骤,并展示了如何创建一个简单的计数器应用,帮助读者快速上手混合开发。
89 0
|
4月前
|
开发者 算法 虚拟化
惊爆!Uno Platform 调试与性能分析终极攻略,从工具运用到代码优化,带你攻克开发难题成就完美应用
【8月更文挑战第31天】在 Uno Platform 中,调试可通过 Visual Studio 设置断点和逐步执行代码实现,同时浏览器开发者工具有助于 Web 版本调试。性能分析则利用 Visual Studio 的性能分析器检查 CPU 和内存使用情况,还可通过记录时间戳进行简单分析。优化性能涉及代码逻辑优化、资源管理和用户界面简化,综合利用平台提供的工具和技术,确保应用高效稳定运行。
85 0
|
4月前
|
前端开发 开发者 设计模式
揭秘Uno Platform状态管理之道:INotifyPropertyChanged、依赖注入、MVVM大对决,帮你找到最佳策略!
【8月更文挑战第31天】本文对比分析了 Uno Platform 中的关键状态管理策略,包括内置的 INotifyPropertyChanged、依赖注入及 MVVM 框架。INotifyPropertyChanged 方案简单易用,适合小型项目;依赖注入则更灵活,支持状态共享与持久化,适用于复杂场景;MVVM 框架通过分离视图、视图模型和模型,使状态管理更清晰,适合大型项目。开发者可根据项目需求和技术栈选择合适的状态管理方案,以实现高效管理。
46 0

热门文章

最新文章