1. Deriving the formulas
1.1 CNN Params
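A minimal sketch of the usual parameter count for a 2D convolution layer, assuming c_in input channels, c_out output channels, a k_h x k_w kernel, and an optional bias term per output channel. The function name conv2d_params is hypothetical and only used for illustration. The count is c_out * (c_in / groups) * k_h * k_w weights, plus c_out bias values.

```python
def conv2d_params(c_in, c_out, k_h, k_w, bias=True, groups=1):
    """Number of learnable parameters in a 2D convolution layer."""
    # Each output channel owns one kernel of shape (c_in / groups, k_h, k_w)
    weights = c_out * (c_in // groups) * k_h * k_w
    # One bias value per output channel, if the layer uses bias
    return weights + (c_out if bias else 0)


# First layer of resnet18: 3 -> 64 channels, 7x7 kernel, no bias
print(conv2d_params(3, 64, 7, 7, bias=False))  # 9408
```

The value 9408 agrees with the params column reported for conv1 in the torchstat output in section 2.3.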
1.2 CNN Flops
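A minimal sketch of the computation count for a 2D convolution, assuming the output feature map is h_out x w_out and ignoring bias. Conventions differ between tools: here conv2d_macs counts one multiply-accumulate per kernel element, while conv2d_madd counts multiplications and additions separately. Both function names are hypothetical.

```python
def conv2d_macs(c_in, c_out, k_h, k_w, h_out, w_out, groups=1):
    """Multiply-accumulate operations for one forward pass (bias ignored)."""
    macs_per_output = (c_in // groups) * k_h * k_w
    return macs_per_output * c_out * h_out * w_out


def conv2d_madd(c_in, c_out, k_h, k_w, h_out, w_out, groups=1):
    """Multiplications plus additions, counted separately (bias ignored)."""
    n = (c_in // groups) * k_h * k_w
    # n multiplications and n - 1 additions per output element
    return (2 * n - 1) * c_out * h_out * w_out


# First layer of resnet18 on a 3x224x224 input: output is 64x112x112
print(conv2d_macs(3, 64, 7, 7, 112, 112))  # 118013952
print(conv2d_madd(3, 64, 7, 7, 112, 112))  # 235225088
```

These two values agree with the Flops and MAdd columns reported for conv1 in the torchstat output in section 2.3, which suggests that torchstat's Flops column follows the one-operation-per-multiply-accumulate convention.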
1.3 Linear Params
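A minimal sketch of the parameter count for a fully connected (linear) layer, assuming in_features inputs, out_features outputs, and an optional bias per output. The function name linear_params is hypothetical.

```python
def linear_params(in_features, out_features, bias=True):
    """Number of learnable parameters in a fully connected layer."""
    # Weight matrix of shape (out_features, in_features), plus optional bias
    weights = in_features * out_features
    return weights + (out_features if bias else 0)


# Final fc layer of resnet18: 512 -> 1000 with bias
print(linear_params(512, 1000))  # 513000
```

The value 513000 agrees with the params column reported for fc in the torchstat output in section 2.3.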
1.4 Linear Flops
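A minimal sketch of the computation count for a fully connected layer applied to a single input vector, ignoring bias. As with convolutions, linear_macs counts one multiply-accumulate per weight, while linear_madd counts multiplications and additions separately. Both function names are hypothetical.

```python
def linear_macs(in_features, out_features):
    """Multiply-accumulate operations for one input vector (bias ignored)."""
    return in_features * out_features


def linear_madd(in_features, out_features):
    """Multiplications plus additions, counted separately (bias ignored)."""
    # in_features multiplications and in_features - 1 additions per output
    return (2 * in_features - 1) * out_features


# Final fc layer of resnet18: 512 -> 1000
print(linear_macs(512, 1000))  # 512000
print(linear_madd(512, 1000))  # 1023000
```

These values agree with the Flops and MAdd columns reported for fc in the torchstat output in section 2.3.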
2. Computation methods
2.1 The parameters method
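import torch.nn as nn

# Replace this with your own model, e.g. CNN(), ResNet(), SegNet(), ...
model = nn.Linear(10, 5)  # placeholder so the snippet runs as-is

params = list(model.parameters())
total = 0
for param in params:
    count = 1
    print("Layer shape: " + str(list(param.size())))
    # Multiply the dimensions together to get this tensor's element count
    for dim in param.size():
        count *= dim
    print("Layer parameter count: " + str(count))
    total = total + count
print("Total parameter count: " + str(total))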
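The same total can be obtained more compactly with Tensor.numel(), which returns the number of elements in a tensor. The sketch below assumes PyTorch is installed; count_parameters is a hypothetical helper name and the small nn.Sequential model is only a stand-in for your own network.

```python
import torch.nn as nn


def count_parameters(model, trainable_only=False):
    """Sum the element counts of all parameter tensors in the model."""
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )


# Stand-in model: (10*5 + 5) + (5*2 + 2) = 67 parameters
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
print(count_parameters(model))  # 67
```

Passing trainable_only=True skips parameters whose requires_grad flag is False, which is useful when part of the network is frozen.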
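2.2 The thop method
thop is a library that can be installed with pip install thop. Once installed, its profile function returns the parameter count (params) and the amount of computation (flops). Note that thop counts multiply-accumulate operations, so the first return value is often named macs instead.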
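from thop import profile

model = ...  # instantiate your own model here
# inputs is required: a tuple holding the input tensor(s) passed to the model,
# e.g. (torch.randn(1, 3, 224, 224),)
flops, params = profile(model, inputs=(input_tensor,))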
import torch
from torchvision.models import resnet18
from thop import profile

model = resnet18()
dummy_input = torch.randn(1, 3, 128, 128)
flops, params = profile(model, inputs=(dummy_input,))
print('flops:{}'.format(flops))
print('params:{}'.format(params))
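The raw numbers printed this way are long and hard to read at a glance. A small formatting helper can make them easier to compare; format_count below is a hypothetical pure-Python function, not part of thop, and the sample values are the resnet18 totals reported by torchstat in section 2.3.

```python
def format_count(value, precision=2):
    """Format a large count with a G / M / K suffix (decimal units)."""
    for unit, scale in (("G", 1e9), ("M", 1e6), ("K", 1e3)):
        if value >= scale:
            return f"{value / scale:.{precision}f}{unit}"
    return str(value)


print(format_count(1821399040))  # 1.82G
print(format_count(11689512))    # 11.69M
```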
2.3 The torchstat method
As before, if you do not have it yet, install it first with pip install torchstat. Again using resnet18 as the example:
from torchstat import stat
from torchvision.models import resnet18

model = resnet18()
stat(model, (3, 224, 224))
[MAdd]: AdaptiveAvgPool2d is not supported!
[Flops]: AdaptiveAvgPool2d is not supported!
[Memory]: AdaptiveAvgPool2d is not supported!
    module name            input shape  output shape  params      memory(MB)  MAdd             Flops            MemRead(B)  MemWrite(B)  duration[%]  MemR+W(B)
0   conv1                  3 224 224    64 112 112    9408.0      3.06        235,225,088.0    118,013,952.0    639744.0    3211264.0    6.25%        3851008.0
1   bn1                    64 112 112   64 112 112    128.0       3.06        3,211,264.0      1,605,632.0      3211776.0   3211264.0    1.09%        6423040.0
2   relu                   64 112 112   64 112 112    0.0         3.06        802,816.0        802,816.0        3211264.0   3211264.0    0.28%        6422528.0
3   maxpool                64 112 112   64 56 56      0.0         0.77        1,605,632.0      802,816.0        3211264.0   802816.0     5.61%        4014080.0
4   layer1.0.conv1         64 56 56     64 56 56      36864.0     0.77        231,010,304.0    115,605,504.0    950272.0    802816.0     4.61%        1753088.0
5   layer1.0.bn1           64 56 56     64 56 56      128.0       0.77        802,816.0        401,408.0        803328.0    802816.0     0.25%        1606144.0
6   layer1.0.relu          64 56 56     64 56 56      0.0         0.77        200,704.0        200,704.0        802816.0    802816.0     0.09%        1605632.0
7   layer1.0.conv2         64 56 56     64 56 56      36864.0     0.77        231,010,304.0    115,605,504.0    950272.0    802816.0     3.72%        1753088.0
8   layer1.0.bn2           64 56 56     64 56 56      128.0       0.77        802,816.0        401,408.0        803328.0    802816.0     0.21%        1606144.0
9   layer1.1.conv1         64 56 56     64 56 56      36864.0     0.77        231,010,304.0    115,605,504.0    950272.0    802816.0     3.63%        1753088.0
10  layer1.1.bn1           64 56 56     64 56 56      128.0       0.77        802,816.0        401,408.0        803328.0    802816.0     0.20%        1606144.0
11  layer1.1.relu          64 56 56     64 56 56      0.0         0.77        200,704.0        200,704.0        802816.0    802816.0     0.09%        1605632.0
12  layer1.1.conv2         64 56 56     64 56 56      36864.0     0.77        231,010,304.0    115,605,504.0    950272.0    802816.0     3.68%        1753088.0
13  layer1.1.bn2           64 56 56     64 56 56      128.0       0.77        802,816.0        401,408.0        803328.0    802816.0     0.20%        1606144.0
14  layer2.0.conv1         64 56 56     128 28 28     73728.0     0.38        115,505,152.0    57,802,752.0     1097728.0   401408.0     3.13%        1499136.0
15  layer2.0.bn1           128 28 28    128 28 28     256.0       0.38        401,408.0        200,704.0        402432.0    401408.0     0.19%        803840.0
16  layer2.0.relu          128 28 28    128 28 28     0.0         0.38        100,352.0        100,352.0        401408.0    401408.0     0.07%        802816.0
17  layer2.0.conv2         128 28 28    128 28 28     147456.0    0.38        231,110,656.0    115,605,504.0    991232.0    401408.0     4.24%        1392640.0
18  layer2.0.bn2           128 28 28    128 28 28     256.0       0.38        401,408.0        200,704.0        402432.0    401408.0     0.19%        803840.0
19  layer2.0.downsample.0  64 56 56     128 28 28     8192.0      0.38        12,744,704.0     6,422,528.0      835584.0    401408.0     1.59%        1236992.0
20  layer2.0.downsample.1  128 28 28    128 28 28     256.0       0.38        401,408.0        200,704.0        402432.0    401408.0     0.22%        803840.0
21  layer2.1.conv1         128 28 28    128 28 28     147456.0    0.38        231,110,656.0    115,605,504.0    991232.0    401408.0     3.54%        1392640.0
22  layer2.1.bn1           128 28 28    128 28 28     256.0       0.38        401,408.0        200,704.0        402432.0    401408.0     0.19%        803840.0
23  layer2.1.relu          128 28 28    128 28 28     0.0         0.38        100,352.0        100,352.0        401408.0    401408.0     0.07%        802816.0
24  layer2.1.conv2         128 28 28    128 28 28     147456.0    0.38        231,110,656.0    115,605,504.0    991232.0    401408.0     3.50%        1392640.0
25  layer2.1.bn2           128 28 28    128 28 28     256.0       0.38        401,408.0        200,704.0        402432.0    401408.0     0.17%        803840.0
26  layer3.0.conv1         128 28 28    256 14 14     294912.0    0.19        115,555,328.0    57,802,752.0     1581056.0   200704.0     3.33%        1781760.0
27  layer3.0.bn1           256 14 14    256 14 14     512.0       0.19        200,704.0        100,352.0        202752.0    200704.0     0.17%        403456.0
28  layer3.0.relu          256 14 14    256 14 14     0.0         0.19        50,176.0         50,176.0         200704.0    200704.0     0.08%        401408.0
29  layer3.0.conv2         256 14 14    256 14 14     589824.0    0.19        231,160,832.0    115,605,504.0    2560000.0   200704.0     5.48%        2760704.0
30  layer3.0.bn2           256 14 14    256 14 14     512.0       0.19        200,704.0        100,352.0        202752.0    200704.0     0.21%        403456.0
31  layer3.0.downsample.0  128 28 28    256 14 14     32768.0     0.19        12,794,880.0     6,422,528.0      532480.0    200704.0     1.37%        733184.0
32  layer3.0.downsample.1  256 14 14    256 14 14     512.0       0.19        200,704.0        100,352.0        202752.0    200704.0     0.17%        403456.0
33  layer3.1.conv1         256 14 14    256 14 14     589824.0    0.19        231,160,832.0    115,605,504.0    2560000.0   200704.0     4.35%        2760704.0
34  layer3.1.bn1           256 14 14    256 14 14     512.0       0.19        200,704.0        100,352.0        202752.0    200704.0     0.17%        403456.0
35  layer3.1.relu          256 14 14    256 14 14     0.0         0.19        50,176.0         50,176.0         200704.0    200704.0     0.08%        401408.0
36  layer3.1.conv2         256 14 14    256 14 14     589824.0    0.19        231,160,832.0    115,605,504.0    2560000.0   200704.0     3.91%        2760704.0
37  layer3.1.bn2           256 14 14    256 14 14     512.0       0.19        200,704.0        100,352.0        202752.0    200704.0     0.17%        403456.0
38  layer4.0.conv1         256 14 14    512 7 7       1179648.0   0.10        115,580,416.0    57,802,752.0     4919296.0   100352.0     5.84%        5019648.0
39  layer4.0.bn1           512 7 7      512 7 7       1024.0      0.10        100,352.0        50,176.0         104448.0    100352.0     0.21%        204800.0
40  layer4.0.relu          512 7 7      512 7 7       0.0         0.10        25,088.0         25,088.0         100352.0    100352.0     0.09%        200704.0
41  layer4.0.conv2         512 7 7      512 7 7       2359296.0   0.10        231,185,920.0    115,605,504.0    9537536.0   100352.0     9.87%        9637888.0
42  layer4.0.bn2           512 7 7      512 7 7       1024.0      0.10        100,352.0        50,176.0         104448.0    100352.0     0.25%        204800.0
43  layer4.0.downsample.0  256 14 14    512 7 7       131072.0    0.10        12,819,968.0     6,422,528.0      724992.0    100352.0     1.76%        825344.0
44  layer4.0.downsample.1  512 7 7      512 7 7       1024.0      0.10        100,352.0        50,176.0         104448.0    100352.0     0.18%        204800.0
45  layer4.1.conv1         512 7 7      512 7 7       2359296.0   0.10        231,185,920.0    115,605,504.0    9537536.0   100352.0     7.26%        9637888.0
46  layer4.1.bn1           512 7 7      512 7 7       1024.0      0.10        100,352.0        50,176.0         104448.0    100352.0     0.23%        204800.0
47  layer4.1.relu          512 7 7      512 7 7       0.0         0.10        25,088.0         25,088.0         100352.0    100352.0     0.08%        200704.0
48  layer4.1.conv2         512 7 7      512 7 7       2359296.0   0.10        231,185,920.0    115,605,504.0    9537536.0   100352.0     6.57%        9637888.0
49  layer4.1.bn2           512 7 7      512 7 7       1024.0      0.10        100,352.0        50,176.0         104448.0    100352.0     0.23%        204800.0
50  avgpool                512 7 7      512 1 1       0.0         0.00        0.0              0.0              0.0         0.0          0.25%        0.0
51  fc                     512          1000          513000.0    0.00        1,023,000.0      512,000.0        2054048.0   4000.0       0.71%        2058048.0
total                                                 11689512.0  25.65       3,638,757,912.0  1,821,399,040.0  2054048.0   4000.0       100.00%      101756992.0
================================================================================================================================================================
Total params: 11,689,512
----------------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 25.65MB
Total MAdd: 3.64GMAdd
Total Flops: 1.82GFlops
Total MemR+W: 97.04MB
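Note that torchstat validates the input size with the assertion below, so the size passed to stat must be a 3-element tuple or list such as (3, 224, 224), without a batch dimension:
assert isinstance(input_size, (tuple, list)) and len(input_size) == 3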
2.4 The torchsummary method
As before, install it with pip install torchsummary, then call it as in the example below.
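import torch
from torchvision.models import resnet18
from torchsummary import summary

model = resnet18()
# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
summary(model.to(device), input_size=(3, 224, 224), batch_size=8)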
----------------------------------------------------------------
Layer (type)            Output Shape          Param #
================================================================
Conv2d-1                [8, 64, 112, 112]     9,408
BatchNorm2d-2           [8, 64, 112, 112]     128
ReLU-3                  [8, 64, 112, 112]     0
MaxPool2d-4             [8, 64, 56, 56]       0
Conv2d-5                [8, 64, 56, 56]       36,864
BatchNorm2d-6           [8, 64, 56, 56]       128
ReLU-7                  [8, 64, 56, 56]       0
Conv2d-8                [8, 64, 56, 56]       36,864
BatchNorm2d-9           [8, 64, 56, 56]       128
ReLU-10                 [8, 64, 56, 56]       0
BasicBlock-11           [8, 64, 56, 56]       0
Conv2d-12               [8, 64, 56, 56]       36,864
BatchNorm2d-13          [8, 64, 56, 56]       128
ReLU-14                 [8, 64, 56, 56]       0
Conv2d-15               [8, 64, 56, 56]       36,864
BatchNorm2d-16          [8, 64, 56, 56]       128
ReLU-17                 [8, 64, 56, 56]       0
BasicBlock-18           [8, 64, 56, 56]       0
Conv2d-19               [8, 128, 28, 28]      73,728
BatchNorm2d-20          [8, 128, 28, 28]      256
ReLU-21                 [8, 128, 28, 28]      0
Conv2d-22               [8, 128, 28, 28]      147,456
BatchNorm2d-23          [8, 128, 28, 28]      256
Conv2d-24               [8, 128, 28, 28]      8,192
BatchNorm2d-25          [8, 128, 28, 28]      256
ReLU-26                 [8, 128, 28, 28]      0
BasicBlock-27           [8, 128, 28, 28]      0
Conv2d-28               [8, 128, 28, 28]      147,456
BatchNorm2d-29          [8, 128, 28, 28]      256
ReLU-30                 [8, 128, 28, 28]      0
Conv2d-31               [8, 128, 28, 28]      147,456
BatchNorm2d-32          [8, 128, 28, 28]      256
ReLU-33                 [8, 128, 28, 28]      0
BasicBlock-34           [8, 128, 28, 28]      0
Conv2d-35               [8, 256, 14, 14]      294,912
BatchNorm2d-36          [8, 256, 14, 14]      512
ReLU-37                 [8, 256, 14, 14]      0
Conv2d-38               [8, 256, 14, 14]      589,824
BatchNorm2d-39          [8, 256, 14, 14]      512
Conv2d-40               [8, 256, 14, 14]      32,768
BatchNorm2d-41          [8, 256, 14, 14]      512
ReLU-42                 [8, 256, 14, 14]      0
BasicBlock-43           [8, 256, 14, 14]      0
Conv2d-44               [8, 256, 14, 14]      589,824
BatchNorm2d-45          [8, 256, 14, 14]      512
ReLU-46                 [8, 256, 14, 14]      0
Conv2d-47               [8, 256, 14, 14]      589,824
BatchNorm2d-48          [8, 256, 14, 14]      512
ReLU-49                 [8, 256, 14, 14]      0
BasicBlock-50           [8, 256, 14, 14]      0
Conv2d-51               [8, 512, 7, 7]        1,179,648
BatchNorm2d-52          [8, 512, 7, 7]        1,024
ReLU-53                 [8, 512, 7, 7]        0
Conv2d-54               [8, 512, 7, 7]        2,359,296
BatchNorm2d-55          [8, 512, 7, 7]        1,024
Conv2d-56               [8, 512, 7, 7]        131,072
BatchNorm2d-57          [8, 512, 7, 7]        1,024
ReLU-58                 [8, 512, 7, 7]        0
BasicBlock-59           [8, 512, 7, 7]        0
Conv2d-60               [8, 512, 7, 7]        2,359,296
BatchNorm2d-61          [8, 512, 7, 7]        1,024
ReLU-62                 [8, 512, 7, 7]        0
Conv2d-63               [8, 512, 7, 7]        2,359,296
BatchNorm2d-64          [8, 512, 7, 7]        1,024
ReLU-65                 [8, 512, 7, 7]        0
BasicBlock-66           [8, 512, 7, 7]        0
AdaptiveAvgPool2d-67    [8, 512, 1, 1]        0
Linear-68               [8, 1000]             513,000
================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 4.59
Forward/backward pass size (MB): 502.34
Params size (MB): 44.59
Estimated Total Size (MB): 551.53
----------------------------------------------------------------
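The size figures at the bottom of the summary can be reproduced by hand. The sketch below assumes every value is stored as a 4-byte float32 and that 1 MB means 1024 * 1024 bytes; tensor_size_mb is a hypothetical helper name. Under those assumptions it recovers the "Params size" and "Input size" lines from the parameter total and the batched input shape (8, 3, 224, 224).

```python
BYTES_PER_FLOAT32 = 4
BYTES_PER_MB = 1024 ** 2


def tensor_size_mb(num_elements, bytes_per_element=BYTES_PER_FLOAT32):
    """Memory taken by num_elements values, in MB (1 MB = 1024**2 bytes)."""
    return num_elements * bytes_per_element / BYTES_PER_MB


params_mb = tensor_size_mb(11_689_512)        # total params from the summary
input_mb = tensor_size_mb(8 * 3 * 224 * 224)  # batch_size x C x H x W

print(f"{params_mb:.2f}")  # 44.59
print(f"{input_mb:.2f}")   # 4.59
```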