定义ModuleList
我们可以将我们需要的层放入到一个集合中,然后将这个集合作为参数传入nn.ModuleList中,但是这个子类并不可以直接使用,因为这个子类并没有实现forward函数,所以要使用还需要放在继承了nn.Module的模型中进行使用。
model_list = nn.ModuleList([nn.Conv2d(1, 5, 2), nn.Linear(10, 2), nn.Sigmoid()]) x = torch.randn(32, 3, 24, 24) for model in model_list: model_list(x)
使用ModuleList定义网络
class Net(nn.Module): def __init__(self): super().__init__() self.model_list = nn.ModuleList([nn.Conv2d(1, 5, 2), nn.Linear(10, 2), nn.Sigmoid()]) def forward(self, x): return self.model_list(x)
打印网络层结构
model = Net() print(model)
Net( (model_list): ModuleList( (0): Conv2d(1, 5, kernel_size=(2, 2), stride=(1, 1)) (1): Linear(in_features=10, out_features=2, bias=True) (2): Sigmoid() ) )