目录
1、tf.contrib.rnn.DropoutWrapper函数解读与理解
2、tf.contrib.rnn.MultiRNNCell函数解读与理解
tensorflow官网API文档:https://tensorflow.google.cn/api_docs
1、tf.contrib.rnn.DropoutWrapper函数解读与理解
在机器学习的模型中,如果模型的参数太多,而训练样本又太少,训练出来的模型很容易产生过拟合的现象。在训练神经网络的时候经常会遇到过拟合的问题。过拟合具体表现在:模型在训练数据上损失函数较小,预测准确率较高;但是在测试数据上损失函数比较大,预测准确率较低。
机器学习模型训练中,过拟合现象实在令人头秃。而 2012 年 Geoffrey Hinton 提出的 Dropout 对防止过拟合有很好的效果。之后大量 Dropout 变体涌现,这项技术也成为机器学习研究者常用的训练 trick。万万没想到的是,谷歌为该项技术申请了专利,而且这项专利已经正式生效,2019-06-26 专利生效,2034-09-03 专利到期!
Dropout,指在神经网络中,每个神经单元在每次有数据流入时,以一定的概率keep_prob
正常工作,否则输出0值。这是一种有效的正则化方法,可以有效降低过拟合。在RNN中进行dropout时,对于RNN的部分不进行dropout,也就是说从t-1时候的状态传递到t时刻进行计算时,这个中间不进行memory的dropout;仅在同一个t时刻中,多层cell之间传递信息的时候进行dropout。在RNN中,这里的dropout是在输入,输出,或者不用的循环层之间使用,或者全连接层,不会在同一层的循环体中使用。
1.1、源代码解读
Operator adding dropout to inputs and outputs of the given cell. | 操作者将dropout添加到给定单元的输入和输出。 |
1. tf.compat.v1.nn.rnn_cell.DropoutWrapper( 2. *args, **kwargs 3. ) |
|
Args:
|
参数:
|
Methods
1. get_initial_state( 2. inputs=None, batch_size=None, dtype=None 3. )
1. zero_state( 2. batch_size, dtype 3. ) |
1.2、案例应用
相关文章:TF之LSTM:利用多层LSTM算法对MNIST手写数字识别数据集进行多分类
1. 2. lstm_cell = rnn.BasicLSTMCell(num_units=hidden_size, forget_bias=1.0, state_is_tuple=True) #定义一层 LSTM_cell,只需要说明 hidden_size, 它会自动匹配输入的 X 的维度 3. lstm_cell = rnn.DropoutWrapper(cell=lstm_cell, input_keep_prob=1.0, output_keep_prob=keep_prob) #添加 dropout layer, 一般只设置 output_keep_prob
2、tf.contrib.rnn.MultiRNNCell函数解读与理解
2.1、源代码解读
RNN cell composed sequentially of multiple simple cells. | RNN细胞由多个简单细胞依次组成。 |
1. tf.compat.v1.nn.rnn_cell.MultiRNNCell( 2. cells, state_is_tuple=True 3. ) |
|
Args:
|
参数: 单元格:按此顺序组成的RNNCells列表。 state_is_tuple:如果为真,则接受状态和返回状态为n元组,其中n = len(cell)。如果为假,则所有状态都沿着列轴连接。后一种行为很快就会被摒弃。 |
Methods
1. get_initial_state( 2. inputs=None, batch_size=None, dtype=None 3. )
1. zero_state( 2. batch_size, dtype 3. ) 4. Return zero-filled state tensor(s). Args:
|
|
Returns: If If |
返回 如果state_size是一个int或TensorShape,那么返回值就是一个包含0的shape [batch_size, state_size]的N-D张量。 如果state_size是一个嵌套列表或元组,那么返回值就是一个嵌套列表或元组(具有相同结构)的2-张量,其中每个s的形状[batch_size, s]为state_size中的每个s。 |
2.2、案例应用
相关文章:DL之LSTM:LSTM算法论文简介(原理、关键步骤、RNN/LSTM/GRU比较、单层和多层的LSTM)、案例应用之详细攻略
1. num_units = [128, 64] 2. cells = [BasicLSTMCell(num_units=n) for n in num_units] 3. stacked_rnn_cell = MultiRNNCell(cells)