卷积神经元网络中常用卷积核理解及基于Pytorch的实例应用(附完整代码)

简介: 卷积神经元网络中常用卷积核理解及基于Pytorch的实例应用(附完整代码)
0.前言

本文的目的:说明在卷积操作中常用的卷积核及其作用原理,并基于Pytorch框架通过实例使用这些卷积核。


看本文前要掌握的基础知识:需要了解卷积神经元网络CNN中卷积的运算原理,CSDN上此类文章很多,不再赘述。推荐一篇:RGB彩色图像的卷积过程(gif动图演示)


1.常用卷积核

卷积核(kernel)是卷积神经元网络CNN中最重要的权重参数,在卷积神经元网络学习过程中,主要也是为了学习到合适的卷积核。下面将通过对图像的处理方式将常用的卷积核分为3类。


1.1 边缘识别类卷积核

这类卷积核是研究最多的,卷积核多种多样。这类卷积核的共同特征是:卷积核内所有的值求和为0,这是因为边缘的区域,图像的像素值会发生突变,与这样的卷积核做卷积会得到一个不为0的值。而非边缘的区域,像素值很接近,与这样的卷积核做卷积会得到一个约等于0的值。


常用的边缘检测卷积核有:

①Robert算子

image.png

image.png

②Prewitt算子

image.png

image.png

③Sobel算子

image.png

image.png

④Laplace算子

image.png

1.2 模糊化卷积核

这类卷积核的作用原理是对一片区域内的像素值求平均值,使得像素变化更加平缓,达到模糊化的目的,例如:

image.png

1.3 锐利化卷积核

这类卷积核的作用是凸显像素值有变化的区域,使得本来像素值梯度就比较大的区域(边缘区域)变得像素值梯度更大。在边缘检测中,卷积核的设计要求卷积核内的所有值求和为0,这里的要求刚好相反,要求卷积核内的所有值应该不为0,凸显出像素值梯度较大的区域,例如以下卷积核:

image.png

只要把握以上卷积核的原理,就可以自己设计卷积核。比如模糊化卷积核,这样也是可以的

image.png
2. 基于Pytorch的卷积核实例应用
2.1 tensor与图像的互相转换

①image转tensor:使用PIL中Image.open()打开图像,然后使用torchvision.transforms.ToTensor()转为tensor:

image = Image.open('image_path').convert('RGB') #导入图片
image_to_tensor = torchvision.transforms.ToTensor()   #实例化ToTensor
original_image_tensor = image_to_tensor(image).unsqueeze(0)     #把图片转换成tensor

这里使用.unsqueeze(0)升维的目的是为了后面进行卷积操作准备,因为Conv2d要求输入的tensor维度为4维,即[batch, channel, H, W]。这里对应要增加batch这个维度。

②tensor转image:使用torchvision.utils.save_image():

torchvision.utils.save_image(tensor, 'save_image_path')
2.2 卷积核的指定

因为nn.Conv2d()方法中,卷积核(权重)是默认随机的,所以需要先指定好卷积核:

#卷积核:laplace
conv_laplace = torch.nn.Conv2d(in_channels=3,out_channels=1,kernel_size=3,padding=0,bias=False)
conv_laplace.weight.data = torch.tensor([[[[-1,0,1],[-1,0,1],[-1,0,1]],
                                            [[-1,0,1],[-1,0,1],[-1,0,1]],
                                            [[-1,0,1],[-1,0,1],[-1,0,1]]]], dtype=torch.float32)

注意卷积核的维度,也是上面说的要有4个维度。

3.不同卷积核的图像处理结果

①原图像:

②边缘检测卷积核(prewitt横向算子)卷积后:

③边缘检测卷积核(prewitt纵向算子)卷积后:

④边缘检测卷积核(laplace算子)卷积后:

体验一下同样是边缘检测,以上3个算子的输出差异。尤其是横向和纵向上的差别。

⑤模糊化卷积核卷积后:

⑥锐利化卷积核卷积后:

这里把背景的噪声点都凸显出来了。

4.完整代码
import torch
from PIL import Image
import torchvision


image = Image.open('girl.png').convert('RGB') #导入图片
image_to_tensor = torchvision.transforms.ToTensor()   #实例化ToTensor
original_image_tensor = image_to_tensor(image).unsqueeze(0)     #把图片转换成tensor


#卷积核:prewitt横向
conv_prewitt_h = torch.nn.Conv2d(in_channels=3,out_channels=1,kernel_size=3,padding=0,bias=False)  #bias要设定成False,要不然会随机生成bias,每次结果都不一样
conv_prewitt_h.weight.data = torch.tensor([[[[-1,-1,-1],[0,0,0],[1,1,1]],
                                            [[-1,-1,-1],[0,0,0],[1,1,1]],
                                            [[-1,-1,-1],[0,0,0],[1,1,1]]]], dtype=torch.float32)


#卷积核:prewitt纵向
conv_prewitt_l = torch.nn.Conv2d(in_channels=3,out_channels=1,kernel_size=3,padding=0,bias=False)
conv_prewitt_l.weight.data = torch.tensor([[[[-1,0,1],[-1,0,1],[-1,0,1]],
                                            [[-1,0,1],[-1,0,1],[-1,0,1]],
                                            [[-1,0,1],[-1,0,1],[-1,0,1]]]], dtype=torch.float32)

#卷积核:laplace
conv_laplace = torch.nn.Conv2d(in_channels=3,out_channels=1,kernel_size=3,padding=0,bias=False)
conv_laplace.weight.data = torch.tensor([[[[-1,0,1],[-1,0,1],[-1,0,1]],
                                            [[-1,0,1],[-1,0,1],[-1,0,1]],
                                            [[-1,0,1],[-1,0,1],[-1,0,1]]]], dtype=torch.float32)


#卷积核:模糊化
conv_blur = torch.nn.Conv2d(in_channels=3,out_channels=1,kernel_size=5,padding=0,bias=False)
conv_blur.weight.data = torch.full((1,3,5,5),0.04)


#卷积核:锐利化
conv_sharp = torch.nn.Conv2d(in_channels=3,out_channels=1,kernel_size=3,padding=0,bias=False)
conv_sharp.weight.data = torch.tensor([[[[-1,-1,1],[-1,-1,-1],[-1,-1,-1]],
                                            [[-1,-1,1],[-1,22,-1],[-1,-1,-1]],
                                            [[-1,-1,1],[-1,-1,-1],[-1,-1,-1]]]], dtype=torch.float32)

#生成并保存图片
tensor_prewitt_h = conv_prewitt_h(original_image_tensor)
torchvision.utils.save_image(tensor_prewitt_h, 'prewitt_h.png')

tensor_prewitt_l = conv_prewitt_l(original_image_tensor)
torchvision.utils.save_image(tensor_prewitt_l, 'prewitt_l.png')

tensor_laplace = conv_laplace(original_image_tensor)
torchvision.utils.save_image(tensor_laplace, 'laplace.png')

tensor_blur = conv_blur(original_image_tensor)
torchvision.utils.save_image(tensor_blur, 'blur.png')

tensor_sharp = conv_sharp(original_image_tensor)
torchvision.utils.save_image(tensor_sharp, 'sharp.png')


相关文章
|
24天前
|
机器学习/深度学习 关系型数据库 MySQL
大模型中常用的注意力机制GQA详解以及Pytorch代码实现
GQA是一种结合MQA和MHA优点的注意力机制,旨在保持MQA的速度并提供MHA的精度。它将查询头分成组,每组共享键和值。通过Pytorch和einops库,可以简洁实现这一概念。GQA在保持高效性的同时接近MHA的性能,是高负载系统优化的有力工具。相关论文和非官方Pytorch实现可进一步探究。
71 4
|
3天前
|
移动开发 Java Android开发
构建高效Android应用:采用Kotlin协程优化网络请求
【4月更文挑战第24天】 在移动开发领域,尤其是对于Android平台而言,网络请求是一个不可或缺的功能。然而,随着用户对应用响应速度和稳定性要求的不断提高,传统的异步处理方式如回调地狱和RxJava已逐渐显示出局限性。本文将探讨如何利用Kotlin协程来简化异步代码,提升网络请求的效率和可读性。我们将深入分析协程的原理,并通过一个实际案例展示如何在Android应用中集成和优化网络请求。
|
9天前
|
存储 监控 安全
网络安全与信息安全:防范漏洞、应用加密、提升意识
【4月更文挑战第18天】 在数字化时代,网络安全与信息安全保障已成为维护国家安全、企业利益和个人隐私的关键。本文深入探讨网络安全的多面性,包括识别和防御网络漏洞、应用加密技术保护数据以及提升全民网络安全意识的重要性。通过对这些关键领域的分析,文章旨在为读者提供实用的策略和建议,以增强其网络环境的安全防护能力。
10 0
|
10天前
|
数据采集 机器学习/深度学习 数据挖掘
网络数据处理中的NumPy应用实战
【4月更文挑战第17天】本文介绍了NumPy在网络数据处理中的应用,包括数据预处理、流量分析和模式识别。通过使用NumPy进行数据清洗、格式化和聚合,以及处理时间序列数据和计算统计指标,可以有效进行流量分析和异常检测。此外,NumPy还支持相关性分析、周期性检测和聚类分析,助力模式识别。作为强大的科学计算库,NumPy在处理日益增长的网络数据中发挥着不可或缺的作用。
|
16天前
|
机器学习/深度学习 PyTorch API
|
18天前
|
传感器 监控 安全
|
18天前
|
安全 SDN 数据中心
|
18天前
|
安全 网络安全 网络虚拟化
虚拟网络设备与网络安全:深入分析与实践应用
在数字化时代📲,网络安全🔒成为了企业和个人防御体系中不可或缺的一部分。随着网络攻击的日益复杂和频繁🔥,传统的物理网络安全措施已经无法满足快速发展的需求。虚拟网络设备🖧,作为网络架构中的重要组成部分,通过提供灵活的配置和强大的隔离能力🛡️,为网络安全提供了新的保障。本文将从多个维度深入分析虚拟网络设备是如何保障网络安全的,以及它们的实际意义和应用场景。
|
22天前
|
PyTorch 算法框架/工具
使用Pytorch Geometric 进行链接预测代码示例
该代码示例使用PyTorch和`torch_geometric`库实现了一个简单的图卷积网络(GCN)模型,处理Cora数据集。模型包含两层GCNConv,每层后跟ReLU激活和dropout。模型在训练集上进行200轮训练,使用Adam优化器和交叉熵损失函数。最后,计算并打印测试集的准确性。
15 6
|
1月前
|
机器学习/深度学习 数据采集 人工智能
m基于深度学习网络的手势识别系统matlab仿真,包含GUI界面
m基于深度学习网络的手势识别系统matlab仿真,包含GUI界面
43 0