Keras Layer自定义

简介: 简单层实现实现一个简单层需要首先继承 layers.Layer 类即可,如下是官方网站上的例子:from keras import backend as Kfrom keras.

简单层实现

实现一个简单层需要首先继承 layers.Layer 类即可,如下是官方网站上的例子:

from keras import backend as K
from keras.engine.topology import Layer
import numpy as np

class MyLayer(Layer):

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self.kernel = self.add_weight(name='kernel', 
                                      shape=(input_shape[1], self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        super(MyLayer, self).build(input_shape)  # Be sure to call this somewhere!

    def call(self, x):
        return K.dot(x, self.kernel)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)

如上所示, 其中有三个函数需要我们自己实现:

  • build() 用来初始化定义weights, 这里可以用父类的self.add_weight() 函数来初始化数据, 该函数必须将 self.built 设置为True, 以保证该 Layer 已经成功 build , 通常如上所示, 使用 super(MyLayer, self).build(input_shape) 来完成
  • call() 用来执行 Layer 的职能, 即当前 Layer 所有的计算过程均在该函数中完成
  • compute_output_shape() 用来计算输出张量的 shape

正常DL都是一个forward, backword, update 三个流程,而在 keras 中对于单层 Layer 来说,通过将可训练的权应该在这里被加入列表`self.trainable_weights中。其他的属性还包括self.non_trainabe_weights(列表)和self.updates(需要更新的形如(tensor, new_tensor)的tuple的列表)。你可以参考BatchNormalization层的实现来学习如何使用上面两个属性。这个方法必须设置self.built = True,可通过调用super([layer],self).build()实现

loss 以及参数更新

详细查看了下 add_weight 函数实现如下(keras/engine/topology.py):

    def add_weight(self,
                   name,
                   shape,
                   dtype=None,
                   initializer=None,
                   regularizer=None,
                   trainable=True,
                   constraint=None):
        """Adds a weight variable to the layer.
        # Arguments
            name: String, the name for the weight variable.
            shape: The shape tuple of the weight.
            dtype: The dtype of the weight.
            initializer: An Initializer instance (callable).
            regularizer: An optional Regularizer instance.
            trainable: A boolean, whether the weight should
                be trained via backprop or not (assuming
                that the layer itself is also trainable).
            constraint: An optional Constraint instance.
        # Returns
            The created weight variable.
        """
        initializer = initializers.get(initializer)
        if dtype is None:
            dtype = K.floatx()
        weight = K.variable(initializer(shape),
                            dtype=dtype,
                            name=name,
                            constraint=constraint)
        if regularizer is not None:
            self.add_loss(regularizer(weight))
        if trainable:
            self._trainable_weights.append(weight)
        else:
            self._non_trainable_weights.append(weight)
        return weight

从上述代码来看通过 add_weight 创建的参数,通过 regularizer 函数来计算 loss, 如果 trainable 设置 True ,则该生成的 self._trainable_weights, 可以通过 regularizer 来构建 loss

具体训练过程参见: keras/engine/training.py

目录
相关文章
|
25天前
|
人工智能 缓存 小程序
微信小游戏开发的方法
微信小游戏成中国最大创业风口!2026年“AI小程序成长计划”落地,支持混元大模型深度集成,涵盖智能NPC、AI生成内容等。Cocos/Unity/LayaAir多引擎适配,4MB首包限制、社交裂变与真机调试为关键要点。(239字)
|
Java
Java 面向对象编程的三大法宝:封装、继承与多态
本文介绍了Java面向对象编程中的三大核心概念:封装、继承和多态。
664 15
|
机器学习/深度学习 人工智能 自然语言处理
【人工智能】人工智能就业岗位发展方向有哪些?
人工智能领域的岗位多样,涵盖了从技术研发到应用实施、从产品设计到市场运营等各个方面,以下是人工智能就业岗位的主要发展方向
1634 59
|
8月前
|
存储 缓存 数据挖掘
阿里云服务器实例选购指南:经济型、通用算力型、计算型、通用型、内存型性能与适用场景解析
当我们在通过阿里云的活动页面挑选云服务器时,相同配置的云服务器通常会有多种不同的实例供我们选择,并且它们之间的价格差异较为明显。这是因为不同实例规格所采用的处理器存在差异,其底层架构也各不相同,比如常见的X86计算架构和Arm计算架构。正因如此,不同实例的云服务器在性能表现以及适用场景方面都各有特点。为了帮助大家在众多实例中做出更合适的选择,本文将针对阿里云服务器的经济型、通用算力型、计算型、通用型和内存型实例,介绍它们的性能特性以及对应的使用场景,以供大家参考和选择。
|
存储 机器学习/深度学习 分布式计算
【DSW Gallery】COMMON_IO使用指南
COMMON_IO模块提供了TableReader和TableWriter两个接口,使用TableReader可以读取ODPS Table中的数据,使用TableWriter可以将数据写入ODPS Table。
【DSW Gallery】COMMON_IO使用指南
|
Java 关系型数据库 MySQL
基于SpringBoot+Vue交流和分享平台的设计与实现(源码+部署说明+演示视频+源码介绍)(1)
基于SpringBoot+Vue交流和分享平台的设计与实现(源码+部署说明+演示视频+源码介绍)
687 1
|
机器学习/深度学习 数据采集 供应链
使用Python实现智能食品价格预测的深度学习模型
使用Python实现智能食品价格预测的深度学习模型
402 6
|
域名解析 网络协议
非阿里云注册域名如何在云解析DNS设置解析?
非阿里云注册域名如何在云解析DNS设置解析?
|
Java Spring NoSQL
Spring Boot 环境变量读取 和 属性对象的绑定
版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/catoop/article/details/50548009 凡是被Spring管理的类,实现接口 EnvironmentAware 重写方法 setEnvironment 可以在工程启动时,获取到系统环境变量和application配置文件中的变量。
2712 0