【MindSpore深度学习框架】MindSpore中的Cell模块

简介: 【MindSpore深度学习框架】MindSpore中的Cell模块

欢迎回到MindSpore神经网络编程系列。在这篇文章中,我们将通过MindSpore来讲解下该框架中的Cell模块。废话不多说,我们开始吧。


一、概述

MindSpore的Cell类是构建所有网络的基类,也是网络的基本单元。当用户需要自定义网络时,需要继承Cell类,并重写__init__方法和construct方法。

损失函数、优化器和模型层等本质上也属于网络结构,也需要继承Cell类才能实现功能,同样用户也可以根据业务需求自定义这部分内容。

本节内容介绍Cell类的关键成员函数,“构建网络”中将介绍基于Cell实现的MindSpore内置损失函数、优化器和模型层及使用方法,以及通过实例介绍如何利用Cell类构建自定义网络。

  • _init_:在该函数中定义网络所需要的层或者变量等信息
  • construct:该函数实现网络的执行流程

二、ops构建网络

这里首先我们使用了ops中的Conv2D算子,然后又使用了bias_add算子,同时定义下模型的权重,之后在construct函数中定义网络的执行流程。

代码样例如下:

class Net(nn.Cell):
    def __init__(self,in_channels=10,out_channels=20,kernel_size=3):
        super(Net,self).__init__()
        self.conv2d=ops.Conv2D(out_channels,kernel_size)
        self.bias_add=ops.BiasAdd()
        self.weight=Parameter(initializer('normal',[out_channels,in_channels,kernel_size,kernel_size]),name='conv.weight')
    def construct(self,x):
        output=self.conv2d(x,self.weight)
        output=self.bias_add(output,self.bias)
        return output

三、nn构建网络

对于nn模块构建网络,非常的方便,它是mindSpore封装的高阶API,简单调用。

代码样例如下:

class Net(nn.Cell):
    def __init__(self,in_channels=10,out_channels=20,kernel_size=3):
        super(Net,self).__init__()
        self.conv2d=nn.Conv2d(in_channels,out_channels,kernel_size,has_bias=True,weight_init=Normal(0.02))
    def construct(self,x):
        output=self.conv2d(x)
        return output

四、nn模块和ops的关系

MindSpore的nn模块是Python实现的模型组件,是对低阶API的封装,主要包括各种模型层、损失函数、优化器等。

同时nn也提供了部分与Primitive算子同名的接口,主要作用是对Primitive算子进行进一步封装,为用户提供更友好的API。

重新分析上文介绍construct方法的用例,此用例是MindSpore的nn.Conv2d源码简化内容,内部会调用ops.Conv2Dnn.Conv2d卷积API增加输入参数校验功能并判断是否bias等,是一个高级封装的模型层。

五、网络的常用方法

1.parameters_dict()

该方法会以字典的形式返回网络的所有参数,键为参数的名称,值为对应的参数

class Net(nn.Cell):
    def __init__(self,in_channels=10,out_channels=20,kernel_size=3):
        super(Net,self).__init__()
        self.conv2d=ops.Conv2D(out_channels,kernel_size)
        self.bias_add=ops.BiasAdd()
        self.weight=Parameter(initializer('normal',[out_channels,in_channels,kernel_size,kernel_size]),name='conv.weight')
    def construct(self,x):
        output=self.conv2d(x,self.weight)
        output=self.bias_add(output,self.bias)
        return output
net=Net()
net.parameters_dict()
>>>OrderedDict([('conv.weight',
              Parameter (name=conv.weight, shape=(20, 10, 3, 3), dtype=Float32, requires_grad=True))])

2.get_parameters()

该方法返回一个迭代器,返回的是模型的参数,就是返回上个方法的所以值

iter=net.get_parameters()
next(iter)
>>>Parameter (name=conv2d.weight, shape=(20, 10, 3, 3), dtype=Float32, requires_grad=True)

3.name_cells()

返回网络中所有单元格的迭代器

net.name_cells()
>>>OrderedDict([('conv2d',
              Conv2d<input_channels=10, output_channels=20, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.Normal object at 0x0000023FABF21248>, bias_init=zeros, format=NCHW>)])

4.cells_and_names()

返回网络中所有单元格的迭代器,包括单元格的名称和它本身

注意第一个返回的是整个网络,键对应着空

names=[]
for m in net.cells_and_names():
    print(m)
    names.append(m[0]) if m[0] else None
print('-------names-------')
print(names)
('', Net<
  (conv2d): Conv2d<input_channels=10, output_channels=20, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.Normal object at 0x0000023FABF21248>, bias_init=zeros, format=NCHW>
  >)
('conv2d', Conv2d<input_channels=10, output_channels=20, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.Normal object at 0x0000023FABF21248>, bias_init=zeros, format=NCHW>)
-------names-------
['conv2d']

5.cells()

返回对直接单元格的迭代器

net.cells()
>>>odict_values([Conv2d<input_channels=10, output_channels=20, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.Normal object at 0x0000023FABF21248>, bias_init=zeros, format=NCHW>])

上面的几个方法几乎差不多,按照自己喜好进行选择使用


目录
相关文章
|
8月前
|
机器学习/深度学习 API 语音技术
|
4天前
|
机器学习/深度学习 存储 人工智能
MNN:阿里开源的轻量级深度学习推理框架,支持在移动端等多种终端上运行,兼容主流的模型格式
MNN 是阿里巴巴开源的轻量级深度学习推理框架,支持多种设备和主流模型格式,具备高性能和易用性,适用于移动端、服务器和嵌入式设备。
43 18
MNN:阿里开源的轻量级深度学习推理框架,支持在移动端等多种终端上运行,兼容主流的模型格式
|
2月前
|
机器学习/深度学习 监控 PyTorch
深度学习工程实践:PyTorch Lightning与Ignite框架的技术特性对比分析
在深度学习框架的选择上,PyTorch Lightning和Ignite代表了两种不同的技术路线。本文将从技术实现的角度,深入分析这两个框架在实际应用中的差异,为开发者提供客观的技术参考。
53 7
|
2月前
|
机器学习/深度学习 自然语言处理 并行计算
DeepSpeed分布式训练框架深度学习指南
【11月更文挑战第6天】随着深度学习模型规模的日益增大,训练这些模型所需的计算资源和时间成本也随之增加。传统的单机训练方式已难以应对大规模模型的训练需求。
197 3
|
5月前
|
机器学习/深度学习 算法 TensorFlow
深入探索强化学习与深度学习的融合:使用TensorFlow框架实现深度Q网络算法及高效调试技巧
【8月更文挑战第31天】强化学习是机器学习的重要分支,尤其在深度学习的推动下,能够解决更为复杂的问题。深度Q网络(DQN)结合了深度学习与强化学习的优势,通过神经网络逼近动作价值函数,在多种任务中表现出色。本文探讨了使用TensorFlow实现DQN算法的方法及其调试技巧。DQN通过神经网络学习不同状态下采取动作的预期回报Q(s,a),处理高维状态空间。
78 1
|
8月前
|
机器学习/深度学习 PyTorch 算法框架/工具
深度学习框架:Pytorch与Keras的区别与使用方法
深度学习框架:Pytorch与Keras的区别与使用方法
|
5月前
|
测试技术 数据库
探索JSF单元测试秘籍!如何让您的应用更稳固、更高效?揭秘成功背后的测试之道!
【8月更文挑战第31天】在 JavaServer Faces(JSF)应用开发中,确保代码质量和可维护性至关重要。本文详细介绍了如何通过单元测试实现这一目标。首先,阐述了单元测试的重要性及其对应用稳定性的影响;其次,提出了提高 JSF 应用可测试性的设计建议,如避免直接访问外部资源和使用依赖注入;最后,通过一个具体的 `UserBean` 示例,展示了如何利用 JUnit 和 Mockito 框架编写有效的单元测试。通过这些方法,不仅能够确保代码质量,还能提高开发效率和降低维护成本。
62 0
|
5月前
|
UED 开发者
哇塞!Uno Platform 数据绑定超全技巧大揭秘!从基础绑定到高级转换,优化性能让你的开发如虎添翼
【8月更文挑战第31天】在开发过程中,数据绑定是连接数据模型与用户界面的关键环节,可实现数据自动更新。Uno Platform 提供了简洁高效的数据绑定方式,使属性变化时 UI 自动同步更新。通过示例展示了基本绑定方法及使用 `Converter` 转换数据的高级技巧,如将年龄转换为格式化字符串。此外,还可利用 `BindingMode.OneTime` 提升性能。掌握这些技巧能显著提高开发效率并优化用户体验。
71 0
|
5月前
|
机器学习/深度学习 PyTorch TensorFlow
深度学习框架之争:全面解析TensorFlow与PyTorch在功能、易用性和适用场景上的比较,帮助你选择最适合项目的框架
【8月更文挑战第31天】在深度学习领域,选择合适的框架至关重要。本文通过开发图像识别系统的案例,对比了TensorFlow和PyTorch两大主流框架。TensorFlow由Google开发,功能强大,支持多种设备,适合大型项目和工业部署;PyTorch则由Facebook推出,强调灵活性和速度,尤其适用于研究和快速原型开发。通过具体示例代码展示各自特点,并分析其适用场景,帮助读者根据项目需求和个人偏好做出明智选择。
118 0
|
6月前
|
机器学习/深度学习 PyTorch TensorFlow
PAI DLC与其他深度学习框架如TensorFlow或PyTorch的异同
PAI DLC与其他深度学习框架如TensorFlow或PyTorch的异同