AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比

简介: AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比

一、引言

本文是上一篇AI智能体研发之路-模型篇(四):一文入门pytorch开发的番外篇,对上文中pytorch的网络结构和tensorflow的模型结构部分进一步详细对比与说明(水一篇为了得到当天的流量卷哈哈,如果想更详细的了解pytorch,辛苦移步上一篇哈。

二、pytorch模型结构定义

def __init__(self, input_size, hidden_size, output_size):
        super(ThreeLayerDNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)  # 第一层全连接层
        self.fc2 = nn.Linear(hidden_size, hidden_size)  # 第二层全连接层
        self.fc3 = nn.Linear(hidden_size, output_size)  # 输出层
        self.sigmoid = nn.Sigmoid()

首先定义了一个名为`ThreeLayerDNN`的类,它是基于PyTorch框架的,用于构建一个具有三个全连接层(也称为密集层)的深度神经网络,特别适用于二分类问题。下面是对代码的详细解释:

  • `__init__`: 这是Python中的构造函数,当创建`ThreeLayerDNN`类的新实例时会被调用。
  • `super(ThreeLayerDNN, self).__init__()`: 这行代码调用父类的初始化方法。因为`ThreeLayerDNN`继承自PyTorch的`nn.Module`类,这一步确保了`ThreeLayerDNN`具有`nn.Module`的所有基本属性和方法。
  • `self.fc1 = nn.Linear(input_size, hidden_size)`: 这里定义了神经网络的第一层全连接层(fully connected layer)。`input_size`是输入数据的特征数量,`hidden_size`是这一层的神经元数量。全连接层意味着输入数据的每个特征都将与这一层的每个神经元相连接。
  • `self.fc2 = nn.Linear(hidden_size, hidden_size)`: 定义了第二层全连接层,结构与第一层相同,保持了相同的隐藏层大小,这在某些架构中用于加深网络而不立即增加模型复杂度。
  • `self.fc3 = nn.Linear(hidden_size, output_size)`: 这是网络的输出层,其输入大小与隐藏层相同,输出大小为`output_size`,对于二分类问题,通常为1。
  • `self.sigmoid = nn.Sigmoid()`: 这行代码定义了一个Sigmoid激活函数,它将在网络的输出层之后应用。Sigmoid函数将输出映射到(0, 1)之间,非常适合二分类问题,其中输出可以解释为属于正类的概率。

综上所述,这段代码构建了一个基础的神经网络结构,适合进行二分类任务,通过全连接层提取特征,并使用Sigmoid函数将网络输出转换为概率估计。

三、tensorflow模型结构定义

model = Sequential([
    Dense(512, input_shape=(X_train.shape[1],)),  # 第一层
    Activation('relu'),
    Dense(512),  # 第二层
    Activation('relu'),
    Dense(1),  # 输出层
    Activation('sigmoid')  # 二分类使用sigmoid
])

使用Keras库(现在是TensorFlow的一个部分)定义了一个简单的深度学习模型,具体来说是一个顺序(Sequential)模型,适用于进行二分类任务。下面是对这段代码的详细解释:

  • Sequential模型: 这是一种线性堆叠层的模型,适合于简单的前向传播神经网络。
  • Dense层: 也称为全连接层,每个神经元都与前一层的所有神经元相连。
  • Dense(512, input_shape=(X_train.shape[1],)): 第一层,有512个神经元,input_shape=(X_train.shape[1],)指定了输入数据的形状,这里假设X_train是一个二维数组,其中每一行是一个样本,X_train.shape[1]表示每个样本的特征数量。
  • Dense(512): 第二层,同样有512个神经元,由于是在Sequential模型中,它自动接收前一层的输出作为输入。
  • Dense(1): 输出层,只有一个神经元,适用于二分类问题。
  • Activation层: 激活函数层,为神经网络引入非线性。
  • Activation('relu'): 使用ReLU(Rectified Linear Unit)作为激活函数,它在输入大于0时输出输入值,小于0时输出0,有助于解决梯度消失问题。
  • 最后一层使用Activation('sigmoid'): 二分类任务中,输出层常用sigmoid激活函数,将输出映射到(0, 1)之间,便于解释为概率。

四、总结

两种框架在定义模型结构时思路基本相同,pytorch基于动态图,更加灵活。tensorflow基于静态图,更加稳定。


目录
相关文章
|
7月前
|
机器学习/深度学习 人工智能 安全
AI加速疫苗研发:从十年磨一剑到一年出成果
AI加速疫苗研发:从十年磨一剑到一年出成果
376 27
|
7月前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
669 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
|
6月前
|
边缘计算 人工智能 PyTorch
130_知识蒸馏技术:温度参数与损失函数设计 - 教师-学生模型的优化策略与PyTorch实现
随着大型语言模型(LLM)的规模不断增长,部署这些模型面临着巨大的计算和资源挑战。以DeepSeek-R1为例,其671B参数的规模即使经过INT4量化后,仍需要至少6张高端GPU才能运行,这对于大多数中小型企业和研究机构来说成本过高。知识蒸馏作为一种有效的模型压缩技术,通过将大型教师模型的知识迁移到小型学生模型中,在显著降低模型复杂度的同时保留核心性能,成为解决这一问题的关键技术之一。
568 6
|
7月前
|
人工智能 运维 安全
AI来了,网络安全运维还能靠“人海战术”吗?
AI来了,网络安全运维还能靠“人海战术”吗?
339 28
|
7月前
|
机器学习/深度学习 资源调度 算法框架/工具
AI-ANNE: 将神经网络迁移到微控制器的深度探索——论文阅读
AI-ANNE框架探索将深度学习模型迁移至微控制器的可行路径,基于MicroPython在Raspberry Pi Pico上实现神经网络核心组件,支持本地化推理,推动TinyML在边缘设备中的应用。
433 10
|
6月前
|
机器学习/深度学习 数据采集 人工智能
深度学习实战指南:从神经网络基础到模型优化的完整攻略
🌟 蒋星熠Jaxonic,AI探索者。深耕深度学习,从神经网络到Transformer,用代码践行智能革命。分享实战经验,助你构建CV、NLP模型,共赴二进制星辰大海。
|
7月前
|
机器学习/深度学习 传感器 算法
【无人车路径跟踪】基于神经网络的数据驱动迭代学习控制(ILC)算法,用于具有未知模型和重复任务的非线性单输入单输出(SISO)离散时间系统的无人车的路径跟踪(Matlab代码实现)
【无人车路径跟踪】基于神经网络的数据驱动迭代学习控制(ILC)算法,用于具有未知模型和重复任务的非线性单输入单输出(SISO)离散时间系统的无人车的路径跟踪(Matlab代码实现)
484 2
|
7月前
|
机器学习/深度学习 并行计算 算法
【CPOBP-NSWOA】基于豪冠猪优化BP神经网络模型的多目标鲸鱼寻优算法研究(Matlab代码实现)
【CPOBP-NSWOA】基于豪冠猪优化BP神经网络模型的多目标鲸鱼寻优算法研究(Matlab代码实现)
180 8
|
7月前
|
人工智能 安全 网络安全
从不确定性到确定性,“动态安全+AI”成网络安全破题密码
2025年国家网络安全宣传周以“网络安全为人民,靠人民”为主题,聚焦AI安全、个人信息保护等热点。随着AI技术滥用加剧,智能化攻击频发,瑞数信息推出“动态安全+AI”防护体系,构建“三层防护+两大闭环”,实现风险前置识别与全链路防控,助力企业应对新型网络威胁,筑牢数字时代安全防线。(238字)
461 1
|
7月前
|
人工智能 监控 数据可视化
如何破解AI推理延迟难题:构建敏捷多云算力网络
本文探讨了AI企业在突破算力瓶颈后,如何构建高效、稳定的网络架构以支撑AI产品化落地。文章分析了典型AI IT架构的四个层次——流量接入层、调度决策层、推理服务层和训练算力层,并深入解析了AI架构对网络提出的三大核心挑战:跨云互联、逻辑隔离与业务识别、网络可视化与QoS控制。最终提出了一站式网络解决方案,助力AI企业实现多云调度、业务融合承载与精细化流量管理,推动AI服务高效、稳定交付。

热门文章

最新文章

推荐镜像

更多