pytorch中目前已经实现好了3中循环神经网络,分别是RNN
、GRU
、LSTM
,但是发现在nn模块中还存在RNNCell()
、LSTMCell()
这个模块。
对于循环神经网络常用来处理序列数据
,可以理解为依次处理每个时间片的数据,但是对于Cell层只能够处理序列数据中的一个时间片
的数据,所以要想使用Cell层达到RNN的目的,就需要不断循环处理每个时间片的数据。
下面使用LSTM和LSTMCell这两个模块来示例:
nn.LSTM()
input = torch.randn(10, 32, 100) lstm = nn.LSTM(100, 8, 1) output, _ = lstm(input) print(output.shape)
该段代码定义了输入数据维度为【10,32,100】,批次大小为32,序列长度为10,每个时间片对应的维度为100。
定义了一个LSTM层,输入维度为100,隐藏状态维度为8,只有1层,经过LSTM后得到所有时间片的输出结果,维度为【10,32,8】。
torch.Size([10, 32, 8])
nn.LSTMCell()
要想使用LSTMCell来达到同样效果就需要不断使用这个Cell循环处理每个时间片的数据,然后将每次循环得到的输出结果进行堆叠即可。
input = torch.randn(10, 32, 100) lstm = nn.LSTMCell(100, 8) output = [] for time_data in input: out, _ = lstm(time_data) output.append(out) output = torch.stack(output) print(output.shape)
torch.Size([10, 32, 8])
那么pytorch中已经有了LSTM层,为什么要定义这个LSTMCell层呢?
原因很简单,就是能够提高定义模型的灵活性,可以根据自定义的网络模块来组合调用LSTMCell层。