Printing a Network Model's conv.weight (Part 2)
The previous post explained that we need to distinguish a module's "whole" from its "parts". To do that, we add a check inside the iteration:

for name, module in model.named_modules():  # named_modules() yields the model itself plus every submodule (all the branches)
    if hasattr(module, 'repvgg_convert'):

So the "whole" is the network structure we defined: it should be a class, and it must provide a method named repvgg_convert(). When a module has this method, we proceed to the next step; otherwise we skip it.
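To make the "whole" vs "parts" distinction concrete, here is a minimal sketch (assuming the MyNet class defined in the code below) that prints every entry yielded by named_modules() and whether it passes the hasattr check:

model = MyNet(32, 32)
for name, module in model.named_modules():
    # the top-level MyNet appears under the empty name ''; everything else is a submodule
    print(name or "''", type(module).__name__, hasattr(module, 'repvgg_convert'))
# Only the top-level MyNet (the "whole") defines repvgg_convert, so only it prints True.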
Code flow:

MyNet --- builds 2 blocks, each containing a submodule named "conv"
repvgg_convert --- the method used to tell the "whole" from the "parts"; it sends the "whole" into get_equivalent_kernel_bias(self)
get_equivalent_kernel_bias(self) --- passes the 2 blocks defined in the "whole" into _fuse_bn_tensor(self, branch)
_fuse_bn_tensor(self, branch) --- returns branch.conv.weight and branch.conv.bias; the weights of the two branches are then added together, and likewise the biases
import torch
import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self, inplanes, planes):
        super(MyNet, self).__init__()
        # both blocks register their convolution under the same name "conv",
        # so _fuse_bn_tensor can fetch branch.conv.weight from either branch
        self.conv_block = torch.nn.Sequential()
        self.conv_block.add_module("conv", torch.nn.Conv2d(inplanes, planes, 3, 1, 1))
        self.conv_block.add_module("relu1", torch.nn.ReLU())
        self.conv_block.add_module("pool1", torch.nn.MaxPool2d(2))
        self.dense_block = torch.nn.Sequential()
        self.dense_block.add_module("conv", torch.nn.Conv2d(planes, planes, 3, 1, 1))
        self.dense_block.add_module("relu2", torch.nn.ReLU())
        self.dense_block.add_module("dense2", torch.nn.Linear(32 * 3 * 3, 10))
    def forward(self, x):
        conv_out = self.conv_block(x)                                        # conv -> relu1 -> pool1
        conv_out = self.dense_block.relu2(self.dense_block.conv(conv_out))   # second conv branch
        res = conv_out.view(conv_out.size(0), -1)                            # flatten only before the Linear layer
        out = self.dense_block.dense2(res)                                   # expects 32 * 3 * 3 features, i.e. a 6x6 input with planes=32
        return out
    def get_equivalent_kernel_bias(self):
        # the names kernel3x3/kernel1x1 are kept from RepVGG; here both branches are 3x3 convs of the same shape
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv_block)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.dense_block)
        return kernel3x3 + kernel1x1, bias3x3 + bias1x1

    # def _pad_1x1_to_3x3_tensor(self, kernel1x1):
    #     if kernel1x1 is None:
    #         return 0
    #     else:
    #         return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
    def _fuse_bn_tensor(self, branch):  # in RepVGG this fuses BN into the conv; this simplified version just returns the conv parameters
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):  # check whether branch is an nn.Sequential
            # branch is passed in from get_equivalent_kernel_bias and is one of the
            # two Sequential blocks (conv_block / dense_block) built in __init__
            kernel = branch.conv.weight
            bias = branch.conv.bias
            return kernel, bias

    # return the fused kernel and bias
    def repvgg_convert(self):
        kernel, bias = self.get_equivalent_kernel_bias()
        return kernel.detach().cpu().numpy(), bias.detach().cpu().numpy()
model = MyNet(32, 32)
# print(model.conv_block.conv.weight)
for name, module in model.named_modules():
    if hasattr(module, 'repvgg_convert'):  # only the top-level MyNet (the "whole") has this method
        kernel, bias = module.repvgg_convert()
        print(kernel)
        # print('modules:', module)
        # print("******************")
        # print(module.conv_block.conv.weight)
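As a quick sanity check, the kernel returned by repvgg_convert() should equal the element-wise sum of the two branches' conv weights, and the bias the sum of their biases. A minimal verification sketch, assuming the loop above has just run so kernel, bias, and model are still in scope:

import numpy as np

expected_kernel = (model.conv_block.conv.weight + model.dense_block.conv.weight).detach().cpu().numpy()
expected_bias = (model.conv_block.conv.bias + model.dense_block.conv.bias).detach().cpu().numpy()
print(np.allclose(kernel, expected_kernel))  # True
print(np.allclose(bias, expected_bias))      # True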