从零手写Resnet50实战——利用 torch 识别出了虎猫和萨摩耶

简介: 利用 torch 识别出了虎猫和萨摩耶

大家好啊,我是董董灿。

自从前几天手写了一个慢速卷积之后(从零手写Resnet50实战—手写龟速卷积),我便一口气将 Resnet50 中剩下的算法都写完了。

然后,暴力的,按照 Resnet50 的结构,将手写的算法一层层地连接了起来。

out = compute_conv_layer(out, "conv1")
out = compute_bn_layer(out, "bn1")
out = compute_relu_layer(out)
out = compute_maxpool_layer(out)
# layer1 
out = compute_bottleneck(out, "layer1_bottleneck0", down_sample = True)
out = compute_bottleneck(out, "layer1_bottleneck1", down_sample = False)
out = compute_bottleneck(out, "layer1_bottleneck2", down_sample = False)
# layer2
out = compute_bottleneck(out, "layer2_bottleneck0", down_sample = True)
out = compute_bottleneck(out, "layer2_bottleneck1", down_sample = False)
out = compute_bottleneck(out, "layer2_bottleneck2", down_sample = False)
out = compute_bottleneck(out, "layer2_bottleneck3", down_sample = False)
# layer3
out = compute_bottleneck(out, "layer3_bottleneck0", down_sample = True)
out = compute_bottleneck(out, "layer3_bottleneck1", down_sample = False)
out = compute_bottleneck(out, "layer3_bottleneck2", down_sample = False)
out = compute_bottleneck(out, "layer3_bottleneck3", down_sample = False)
out = compute_bottleneck(out, "layer3_bottleneck4", down_sample = False)
out = compute_bottleneck(out, "layer3_bottleneck5", down_sample = False)
# layer4
out = compute_bottleneck(out, "layer4_bottleneck0", down_sample = True)
out = compute_bottleneck(out, "layer4_bottleneck1", down_sample = False)
out = compute_bottleneck(out, "layer4_bottleneck2", down_sample = False)
# avg pool
out = compute_avgpool_layer(out)
# Linear
out = compute_fc_layer(out, "fc")

算法的手写和网络的搭建,没有调用任何第三方库,这也是这个项目的初衷。

相关代码都已经上传至:项目根目录/python/inference.py,项目地址在文章末尾。

试运行

在将网络搭建完的那一刻,迫不及待的我,赶紧运行了这个网络,试图让它能识别出这张猫,结果如我所想——识别错误!

image.png

image.png

我手写手搭的神经网络将这只猫的类别识别成了 bucket——一个水桶!

不过没关系,这不很正常么,谁能确保刚刚手写的一个神经网络,第一次运行就能成功呢?

马斯克不刚刚发射星舰还失败了吗?

识别错误了,那就开始调试。

于是,我快速地使用 torch 搭建了一个官方的 resent50 模型。然后利用这个官方的模型来推理了两张图片,一张是上面的小猫,一张是下面的狗子。

image.png

结果很明显,官方模型推理正确。

image.png

看到推理结果我才意识到,那只狗子是个 Samoyed——萨摩耶。

在确认了官方的模型和算法可以正确地识别出猫咪和狗子之后,我开始了漫长的debug(调试)之路。

开始调试

调试方法很简单,将 torch 官方的 resnet50 计算每一层得出的结果,和我手写的算法计算的结果对比。
在对比的过程中,真的发现了一个问题。

保存的权值数据与算法不匹配

torch 的权值默认是按照 NCHW 的格式存储的,而我手写算法的时候,习惯按照 NHWC 的格式来写,于是,我的第一层卷积就算错了。

于是,在导出权值的时候(从零手写Resnet50实战——权值另存为),额外增加一个 transpose 操作,将 torch 默认的 NCHW 的权值,转置为我手写算法需要的 NHWC 的权值。

然后继续保存到 txt 中。

在将权值的问题修复完之后,重新对比结果,竟出奇地顺畅。

一路绿灯。

从 conv1->bn1->maxpool,再到第一个layer 中的 conv1->bn1->conv2->bn2->conv3->bn3->relu ,甚至第一个layer中下采样中的计算 conv->bn。

这几个地方竟然全部能和官方resnet50计算的结果对的上!

image.png

也就是,在上图中,绿色部分都没问题,红色部分仍然存在计算错误。出错部分刚好是残差结构的加法操作。

不过这个结果确实是惊到我了,这说明——

我手写的 Conv2d、BatchNormal、MaxPool算法没问题!

从做这件事开始,我就担心我手写的算法可能会有问题,不论是功能上还是精度上,不过在我采用了 float64 数据类型后,精度问题不存在了。

而上面和官方的结果的对比验证,也证实了手写的算法在功能上没问题。

反而是我从来没担心的网络结构搭建上,却恰恰出了问题——残差结构有错误,不过既然定位到了,后面有时间再继续调试吧。

最起码,今天离我手写的神经网络出猫,又近了一步。

本文作者原创,请勿随意转载

相关文章
|
XML 前端开发 Java
SpringMVC实现文件下载实践
SpringMVC实现文件下载实践
330 3
|
SQL 消息中间件 分布式计算
Apache Doris 系列: 入门篇-数据导入及查询
Apache Doris 系列: 入门篇-数据导入及查询
2183 0
|
Android开发 开发者
苹果开发者账号申请教程
登陆苹果官网注册账号 点击地址https://developer.apple.com/account/进入苹果官网 如果没有账号可以点击'Create Apple ID'进行账号注册,输入需要的信息后点击'continue'按钮进入网站 因为我已经有账号,所以直接点'Sign In'登陆进入网站 .
10562 0
|
机器学习/深度学习 算法 PyTorch
昇腾910-PyTorch 实现 ResNet50图像分类
本实验基于PyTorch,在昇腾平台上使用ResNet50对CIFAR10数据集进行图像分类训练。内容涵盖ResNet50的网络架构、残差模块分析及训练代码详解。通过端到端的实战讲解,帮助读者理解如何在深度学习中应用ResNet50模型,并实现高效的图像分类任务。实验包括数据预处理、模型搭建、训练与测试等环节,旨在提升模型的准确率和训练效率。
680 54
|
11月前
|
存储 容器 API
鸿蒙特效教程03-水波纹动画效果实现教程
本教程适合HarmonyOS初学者,通过简单到复杂的步骤,一步步实现漂亮的水波纹动画效果。
407 0
鸿蒙特效教程03-水波纹动画效果实现教程
|
Java 关系型数据库 数据库连接
MyBatis-Plus整合SpringBoot及使用
务必记住,随着MyBatis-Plus版本的更新,一些具体的配置和使用方式可能会有所变动。在实际开发过程中,建议参考MyBatis-Plus的官方文档,以获取最新和详细的指导。
832 1
|
人工智能 弹性计算 运维
开启运维新纪元!阿里云OS Copilot深度评测 & 体验分享
OS Copilot是Alibaba Cloud为Linux推出的一款基于大模型的智能助手,它能理解自然语言、辅助命令执行和系统运维。目前仅支持Alibaba Cloud Linux 3的x86_64架构。安装过程涉及线上和本地体验,包括申请试用、配置环境变量、安装组件等步骤。OS Copilot提供命令行和多轮交互模式,能进行代码生成和摘要,辅助开发和运维工作。产品体验评测中,OS Copilot因其自然语言理解和高效辅助得到高度评价,尤其对运维人员来说,能大幅提升工作效率。然而,目前仅限于特定操作系统,是其局限性。未来有望扩展更多功能和支持更多平台。
134265 26
|
算法 Linux 测试技术
Linux编程:测试-高效内存复制与随机数生成的性能
该文探讨了软件工程中的性能优化,重点关注内存复制和随机数生成。文章通过测试指出,`g_memmove`在内存复制中表现出显著优势,比简单for循环快约32倍。在随机数生成方面,`GRand`库在1000万次循环中的效率超过传统`rand()`。文中提供了测试代码和Makefile,建议在性能关键场景中使用`memcpy`、`g_memmove`以及高效的随机数生成库。
|
数据采集 存储 监控
构建高效爬虫系统:设计思路与案例分析
构建高效爬虫系统涉及关键模块如爬虫引擎、链接存储、内容处理器等,以及用户代理池、IP代理池等反反爬策略。评估项目复杂性考虑数据规模、网站结构、反爬虫机制等因素。案例分析展示了电子商务价格比较爬虫的设计,强调了系统模块化、错误处理和合规性的重要性。爬虫技术需要不断进化以应对复杂网络环境的挑战。
660 1
|
网络协议 网络安全 网络性能优化
期末复习【计算机网络】
期末复习【计算机网络】
289 0