Keras RNN 源码分析

简介: 在 keras 源码中, layers/recurrent.py 中看到 RNN 实现方式RNN 中的循环体使用 RNNCell 来进行定义的,在 RNN(Layer) 中的 compute_output_shape 函数可以查看到 RNN 输出维度的计算方法, 可以看出维度为 (输入维度, 输出维度) .

在 keras 源码中, layers/recurrent.py 中看到 RNN 实现方式

RNN 中的循环体使用 RNNCell 来进行定义的,

在 RNN(Layer) 中的 compute_output_shape 函数可以查看到 RNN 输出维度的计算方法, 可以看出维度为 (输入维度, 输出维度) .代码如下:

    def compute_output_shape(self, input_shape):
        if isinstance(input_shape, list):
            input_shape = input_shape[0]

        if hasattr(self.cell.state_size, '__len__'):
            output_dim = self.cell.state_size[0]
        else:
            output_dim = self.cell.state_size

        if self.return_sequences:
            output_shape = (input_shape[0], input_shape[1], output_dim)
        else:
            output_shape = (input_shape[0], output_dim)

        if self.return_state:
            state_shape = [(input_shape[0], output_dim) for _ in self.states]
            return [output_shape] + state_shape
        else:
            return output_shape

其中通过查看 LSTMCell 中的定义内容,如下:

    def __init__(self, units,
                 ....
                 **kwargs):
        super(LSTMCell, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias
                 ....
        self.dropout = min(1., max(0., dropout))
        self.recurrent_dropout = min(1., max(0., recurrent_dropout))
        self.implementation = implementation
        self.state_size = (self.units, self.units)
        self._dropout_mask = None
        self._recurrent_dropout_mask = None

因此, 对于 LSTMCell 来说输出的 shape 即为 (input_shape[0], units, units), 在代码中可以看到 RNN 是通过 state 来管理当前 RNNLayer 使用哪个 LSTMCell 进行当前计算.

在 RNNLayer 中存在 recurrent_kernel ,该只用来存放再传入下个 state 时使用的 kernel,

目录
相关文章
|
机器学习/深度学习 移动开发 算法
DL之CNN:基于CRNN_OCR算法(keras,CNN+RNN)利用数据集(torch,mdb格式)训练来实现新图片上不定长度字符串进行识别—训练过程
DL之CNN:基于CRNN_OCR算法(keras,CNN+RNN)利用数据集(torch,mdb格式)训练来实现新图片上不定长度字符串进行识别—训练过程
DL之CNN:基于CRNN_OCR算法(keras,CNN+RNN)利用数据集(torch,mdb格式)训练来实现新图片上不定长度字符串进行识别—训练过程
|
机器学习/深度学习 算法 算法框架/工具
DL之CNN:基于CRNN_OCR算法(keras,CNN+RNN)利用数据集(torch,mdb格式)训练来实现新图片上不定长度字符串进行识别—预测过程
DL之CNN:基于CRNN_OCR算法(keras,CNN+RNN)利用数据集(torch,mdb格式)训练来实现新图片上不定长度字符串进行识别—预测过程
DL之CNN:基于CRNN_OCR算法(keras,CNN+RNN)利用数据集(torch,mdb格式)训练来实现新图片上不定长度字符串进行识别—预测过程
|
机器学习/深度学习 TensorFlow 算法框架/工具
|
机器学习/深度学习 算法框架/工具 算法
使用Keras进行深度学习:(五)RNN和双向RNN讲解及实践
欢迎大家关注我们的网站和系列教程:http://www.tensorflownews.com/,学习更多的机器学习、深度学习的知识! 笔者:Ray 介绍 通过对前面文章的学习,对深度神经网络(DNN)和卷积神经网络(CNN)有了一定的了解,也感受到了这些神经网络在各方面的应用都有不错的效果。
2905 0
|
7月前
|
机器学习/深度学习
【从零开始学习深度学习】33.语言模型的计算方式及循环神经网络RNN简介
【从零开始学习深度学习】33.语言模型的计算方式及循环神经网络RNN简介
【从零开始学习深度学习】33.语言模型的计算方式及循环神经网络RNN简介
|
3月前
|
机器学习/深度学习 数据采集 存储
时间序列预测新突破:深入解析循环神经网络(RNN)在金融数据分析中的应用
【10月更文挑战第7天】时间序列预测是数据科学领域的一个重要课题,特别是在金融行业中。准确的时间序列预测能够帮助投资者做出更明智的决策,比如股票价格预测、汇率变动预测等。近年来,随着深度学习技术的发展,尤其是循环神经网络(Recurrent Neural Networks, RNNs)及其变体如长短期记忆网络(LSTM)和门控循环单元(GRU),在处理时间序列数据方面展现出了巨大的潜力。本文将探讨RNN的基本概念,并通过具体的代码示例展示如何使用这些模型来进行金融数据分析。
488 2
|
7月前
|
机器学习/深度学习 自然语言处理 算法
RNN-循环神经网络
自然语言处理(Nature language Processing, NLP)研究的主要是通过计算机算法来理解自然语言。对于自然语言来说,处理的数据主要就是人类的语言,我们在进行文本数据处理时,需要将文本进行数据值化,然后进行后续的训练工作。
|
7月前
|
机器学习/深度学习 自然语言处理 算法
【从零开始学习深度学习】49.Pytorch_NLP项目实战:文本情感分类---使用循环神经网络RNN
【从零开始学习深度学习】49.Pytorch_NLP项目实战:文本情感分类---使用循环神经网络RNN
|
8月前
|
机器学习/深度学习 自然语言处理 语音技术
深度学习500问——Chapter06: 循环神经网络(RNN)(3)
深度学习500问——Chapter06: 循环神经网络(RNN)(3)
162 3
|
8月前
|
机器学习/深度学习 自然语言处理 PyTorch
使用Python实现循环神经网络(RNN)的博客教程
使用Python实现循环神经网络(RNN)的博客教程
842 1

热门文章

最新文章