问题
我在使用pytorch的 LSTM (RNN) 构建多类文本分类网络时遇到此错误,网络结构没有问题,能够运行起来,但是运行到几个batch后就报错Expected hidden[0] size (2, 136, 256), got [2, 256, 256]
分析
该错误是由于的训练数据不能被批量大小整除造成的。前面的batch都是256个,但是最后一个batch不足256,只有136个。
假设训练数据有 100个,batch大小为 16,划分为6个batch,最后一个batch将只有 4 个(100%16 = 4)个。
解决方案
(1)方法一
修改batchsize,让数据集大小能整除batchsize
(2)方法二
如果使用Dataloader,设置一个参数drop_last=True,会自动舍弃最后不足batchsize的batch
from torch.utils.data import DataLoader
train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size, drop_last=True)