- 首先导入所需的包
import torch
from torch import nn
- VGG11 由 8 个卷积层和 3 个全连接层组成（共 11 个带权重的层）。注意：池化只改变特征图的大小，不改变通道数。
class MyNet(nn.Module):
    """VGG11 (configuration "A" of the VGG paper): 8 conv layers + 3 FC layers.

    Every convolution is 3x3 with stride 1 and padding 1, so it keeps the
    spatial size of the feature map and only changes the channel count.
    Every 2x2 max-pool halves the spatial size and keeps the channel count.
    A ReLU follows every convolution and the first two fully connected layers.

    Args:
        num_classes: number of logits produced by ``fc3`` (default 1000).
        in_channels: number of channels of the input image (default 3).
        hidden_features: width of ``fc1`` / ``fc2`` (default 4096).
        verbose: if True (the default, matching the original tutorial code),
            ``forward`` prints the tensor shape after every layer.
    """

    def __init__(
        self,
        num_classes: int = 1000,
        in_channels: int = 3,
        hidden_features: int = 4096,
        verbose: bool = True,
    ) -> None:
        super().__init__()
        self.verbose = verbose

        # (1) conv3-64; padding=1 keeps the feature-map size unchanged.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels, out_channels=64, kernel_size=3, stride=1, padding=1
        )
        # Pooling changes only the feature-map size, not the channel count.
        self.max_pool_1 = nn.MaxPool2d(2)

        # (2) conv3-128
        self.conv2 = nn.Conv2d(
            in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1
        )
        self.max_pool_2 = nn.MaxPool2d(2)

        # (3) conv3-256, conv3-256
        self.conv3_1 = nn.Conv2d(
            in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1
        )
        self.conv3_2 = nn.Conv2d(
            in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1
        )
        self.max_pool_3 = nn.MaxPool2d(2)

        # (4) conv3-512, conv3-512
        self.conv4_1 = nn.Conv2d(
            in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1
        )
        self.conv4_2 = nn.Conv2d(
            in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1
        )
        self.max_pool_4 = nn.MaxPool2d(2)

        # (5) conv3-512, conv3-512
        self.conv5_1 = nn.Conv2d(
            in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1
        )
        self.conv5_2 = nn.Conv2d(
            in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1
        )
        self.max_pool_5 = nn.MaxPool2d(2)

        # (6) Classifier. AdaptiveAvgPool2d pins the feature map to 7x7 so fc1
        # always receives 512 * 7 * 7 = 25088 features. For 224x224 inputs
        # pool5 already outputs 7x7, and this pooling is an identity.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # Parameter-free modules: the state_dict keys match the original layout.
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(512 * 7 * 7, hidden_features)
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        self.fc3 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Run the network on ``x`` and return the ``fc3`` logits.

        ``x`` is expected to be an ``(N, in_channels, H, W)`` image batch.
        H and W must be at least 32 (2**5) so the five 2x2 max-pools leave a
        non-empty feature map.
        """
        x = self._trace(self.relu(self.conv1(x)))
        x = self._trace(self.max_pool_1(x))
        x = self._trace(self.relu(self.conv2(x)))
        x = self._trace(self.max_pool_2(x))
        x = self._trace(self.relu(self.conv3_1(x)))
        x = self._trace(self.relu(self.conv3_2(x)))
        x = self._trace(self.max_pool_3(x))
        x = self._trace(self.relu(self.conv4_1(x)))
        x = self._trace(self.relu(self.conv4_2(x)))
        x = self._trace(self.max_pool_4(x))
        x = self._trace(self.relu(self.conv5_1(x)))
        x = self._trace(self.relu(self.conv5_2(x)))
        x = self._trace(self.max_pool_5(x))
        # Not traced, so the printed shape sequence stays the same as before;
        # for 224x224 inputs its output shape equals pool5's anyway.
        x = self.avgpool(x)
        x = self._trace(torch.flatten(x, 1))
        x = self._trace(self.relu(self.fc1(x)))
        x = self._trace(self.relu(self.fc2(x)))
        return self.fc3(x)

    def _trace(self, x):
        """Print ``x.shape`` when ``self.verbose`` is set; return ``x`` unchanged."""
        if self.verbose:
            print(x.shape)
        return x
- 给定输入 x，查看网络最终的输出形状
if __name__ == "__main__":
    # Sanity check: a batch of 128 RGB 224x224 images should map to 1000 logits.
    x = torch.rand(128, 3, 224, 224)
    net = MyNet()
    # Inference only. no_grad stops autograd from keeping every intermediate
    # activation alive; conv1's output alone is 128*64*224*224 floats (~1.6 GB).
    with torch.no_grad():
        out = net(x)
    print(out.shape)  # torch.Size([128, 1000])
|