从零手写Resnet50实战——利用 torch 识别出了虎猫和萨摩耶

简介: 利用 torch 识别出了虎猫和萨摩耶

大家好啊,我是董董灿。

自从前几天手写了一个慢速卷积之后(从零手写Resnet50实战—手写龟速卷积),我便一口气将 Resnet50 中剩下的算法都写完了。

然后,暴力的,按照 Resnet50 的结构,将手写的算法一层层地连接了起来。

out = compute_conv_layer(out, "conv1")
out = compute_bn_layer(out, "bn1")
out = compute_relu_layer(out)
out = compute_maxpool_layer(out)
# layer1 
out = compute_bottleneck(out, "layer1_bottleneck0", down_sample = True)
out = compute_bottleneck(out, "layer1_bottleneck1", down_sample = False)
out = compute_bottleneck(out, "layer1_bottleneck2", down_sample = False)
# layer2
out = compute_bottleneck(out, "layer2_bottleneck0", down_sample = True)
out = compute_bottleneck(out, "layer2_bottleneck1", down_sample = False)
out = compute_bottleneck(out, "layer2_bottleneck2", down_sample = False)
out = compute_bottleneck(out, "layer2_bottleneck3", down_sample = False)
# layer3
out = compute_bottleneck(out, "layer3_bottleneck0", down_sample = True)
out = compute_bottleneck(out, "layer3_bottleneck1", down_sample = False)
out = compute_bottleneck(out, "layer3_bottleneck2", down_sample = False)
out = compute_bottleneck(out, "layer3_bottleneck3", down_sample = False)
out = compute_bottleneck(out, "layer3_bottleneck4", down_sample = False)
out = compute_bottleneck(out, "layer3_bottleneck5", down_sample = False)
# layer4
out = compute_bottleneck(out, "layer4_bottleneck0", down_sample = True)
out = compute_bottleneck(out, "layer4_bottleneck1", down_sample = False)
out = compute_bottleneck(out, "layer4_bottleneck2", down_sample = False)
# avg pool
out = compute_avgpool_layer(out)
# Linear
out = compute_fc_layer(out, "fc")

算法的手写和网络的搭建,没有调用任何第三方库,这也是这个项目的初衷。

相关代码都已经上传至:项目根目录/python/inference.py,项目地址在文章末尾。

试运行

在将网络搭建完的那一刻,迫不及待的我,赶紧运行了这个网络,试图让它能识别出这张猫,结果如我所想——识别错误!

image.png

image.png

我手写手搭的神经网络将这只猫的类别识别成了 bucket——一个水桶!

不过没关系,这不很正常么,谁能确保刚刚手写的一个神经网络,第一次运行就能成功呢?

马斯克不刚刚发射星舰还失败了吗?

识别错误了,那就开始调试。

于是,我快速地使用 torch 搭建了一个官方的 resent50 模型。然后利用这个官方的模型来推理了两张图片,一张是上面的小猫,一张是下面的狗子。

image.png

结果很明显,官方模型推理正确。

image.png

看到推理结果我才意识到,那只狗子是个 Samoyed——萨摩耶。

在确认了官方的模型和算法可以正确地识别出猫咪和狗子之后,我开始了漫长的debug(调试)之路。

开始调试

调试方法很简单,将 torch 官方的 resnet50 计算每一层得出的结果,和我手写的算法计算的结果对比。
在对比的过程中,真的发现了一个问题。

保存的权值数据与算法不匹配

torch 的权值默认是按照 NCHW 的格式存储的,而我手写算法的时候,习惯按照 NHWC 的格式来写,于是,我的第一层卷积就算错了。

于是,在导出权值的时候(从零手写Resnet50实战——权值另存为),额外增加一个 transpose 操作,将 torch 默认的 NCHW 的权值,转置为我手写算法需要的 NHWC 的权值。

然后继续保存到 txt 中。

在将权值的问题修复完之后,重新对比结果,竟出奇地顺畅。

一路绿灯。

从 conv1->bn1->maxpool,再到第一个layer 中的 conv1->bn1->conv2->bn2->conv3->bn3->relu ,甚至第一个layer中下采样中的计算 conv->bn。

这几个地方竟然全部能和官方resnet50计算的结果对的上!

image.png

也就是,在上图中,绿色部分都没问题,红色部分仍然存在计算错误。出错部分刚好是残差结构的加法操作。

不过这个结果确实是惊到我了,这说明——

我手写的 Conv2d、BatchNormal、MaxPool算法没问题!

从做这件事开始,我就担心我手写的算法可能会有问题,不论是功能上还是精度上,不过在我采用了 float64 数据类型后,精度问题不存在了。

而上面和官方的结果的对比验证,也证实了手写的算法在功能上没问题。

反而是我从来没担心的网络结构搭建上,却恰恰出了问题——残差结构有错误,不过既然定位到了,后面有时间再继续调试吧。

最起码,今天离我手写的神经网络出猫,又近了一步。

本文作者原创,请勿随意转载

相关文章
|
8月前
|
算法 文件存储 计算机视觉
【YOLOv8改进】MobileNetV3替换Backbone (论文笔记+引入代码)
YOLO目标检测专栏探讨了MobileNetV3的创新改进,该模型通过硬件感知的NAS和NetAdapt算法优化,适用于手机CPU。引入的新架构包括反转残差结构和线性瓶颈层,提出高效分割解码器LR-ASPP,提升了移动设备上的分类、检测和分割任务性能。MobileNetV3-Large在ImageNet上准确率提升3.2%,延迟降低20%,COCO检测速度增快25%。MobileNetV3-Small则在保持相近延迟下,准确率提高6.6%。此外,还展示了MobileNetV3_InvertedResidual模块的代码实现。
|
机器学习/深度学习 计算机视觉 异构计算
Darknet53详细原理(含torch版源码)
Darknet53详细原理(含torch版源码)—— cifar10
494 0
Darknet53详细原理(含torch版源码)
|
机器学习/深度学习 编解码
MobileNetV1详细原理(含torch源码)
MobilenetV1(含torch源码)—— cifar10
397 0
MobileNetV1详细原理(含torch源码)
|
8月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
【Keras+计算机视觉+Tensorflow】DCGAN对抗生成网络在MNIST手写数据集上实战(附源码和数据集 超详细)
【Keras+计算机视觉+Tensorflow】DCGAN对抗生成网络在MNIST手写数据集上实战(附源码和数据集 超详细)
147 0
|
8月前
|
机器学习/深度学习 文字识别 算法
【Keras计算机视觉OCR】文字识别算法中DenseNet、LSTM、CTC、Attention的讲解(图文解释 超详细)
【Keras计算机视觉OCR】文字识别算法中DenseNet、LSTM、CTC、Attention的讲解(图文解释 超详细)
314 0
|
机器学习/深度学习 计算机视觉 异构计算
MobileNetV2详细原理(含torch源码)
MobileNetV2详细原理(含torch源码)—— cifar10
522 0
MobileNetV2详细原理(含torch源码)
|
机器学习/深度学习 存储 编解码
MobileNetV3详细原理(含torch源码)
MobilneNetV3详细原理(含torch源码)—— cifar10
821 0
MobileNetV3详细原理(含torch源码)
|
PyTorch 算法框架/工具
GoogLeNet InceptionV1代码复现+超详细注释(PyTorch)
GoogLeNet InceptionV1代码复现+超详细注释(PyTorch)
376 1
|
PyTorch 算法框架/工具 机器学习/深度学习
GoogLeNet InceptionV3代码复现+超详细注释(PyTorch)
GoogLeNet InceptionV3代码复现+超详细注释(PyTorch)
453 0
|
机器学习/深度学习 PyTorch 算法框架/工具
使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下
使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下