打印网络模型的conv.weight(二)
上一篇博文讲到要区分module的“总”与“分”
需要添加一个判断语句
for name, module in model.named_modules(): # module中有所有的分支if hasattr(module, 'repvgg_convert'):
所以“总”就是我们定义的网络结构, 应该是一个class-类, 其中要有一个函数是:
repvgg_convert()。
当判断module中有该函数时, 进到下一步, 否则跳过
代码:
MyNet---构建2个block, 各自有一个"conv"
repvgg_convert---判断“总”与“分”, 并将“总”送进
get_equivalent_kernel_bias(self)函数
get_equivalent_kernel_bias(self)---将“总”中定义的2个
block传入_fuse_bn_tensor(self, branch)函数
_fuse_bn_tensor(self, branch)---得到branch.conv.weight
及branch.conv.bias并且各自相加
import torchimport torch.nn as nnclass MyNet(nn.Module):def __init__(self, inplanes, planes):super(MyNet, self,).__init__()self.conv_block=torch.nn.Sequential()self.conv_block.add_module("conv",torch.nn.Conv2d(inplanes, planes, 3, 1, 1))self.conv_block.add_module("relu1",torch.nn.ReLU())self.conv_block.add_module("pool1",torch.nn.MaxPool2d(2))self.dense_block = torch.nn.Sequential()self.dense_block.add_module("conv",torch.nn.Conv2d(planes, planes, 3,1,1) )self.dense_block.add_module("relu2",torch.nn.ReLU())self.dense_block.add_module("dense2",torch.nn.Linear(32 * 3 * 3, 10))def forward(self, x):conv_out = self.conv_block(x)res = conv_out.view(conv_out.size(0), -1)out = self.dense_block(res)return outdef get_equivalent_kernel_bias(self):kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv_block)kernel1x1, bias1x1 = self._fuse_bn_tensor(self.dense_block)return kernel3x3 +(kernel1x1), bias3x3 + bias1x1# def _pad_1x1_to_3x3_tensor(self, kernel1x1):# if kernel1x1 is None:# return 0# else:# return torch.nn.functional.pad(kernel1x1, [1,1,1,1])def _fuse_bn_tensor(self, branch): # 将BN合并到卷积中if branch is None:return 0, 0if isinstance(branch, nn.Sequential): # 判断branch是否为nn.Sequential# branch在get_equivalent_kernel_bias中调用;而branch的定义则来自28--33行。定义的卷积、正则化kernel = branch.conv.weightbias = branch.conv.biasreturn kernel, bias# 调用融合后的卷积与偏置def repvgg_convert(self):kernel, bias = self.get_equivalent_kernel_bias()return kernel.detach().cpu().numpy(), bias.detach().cpu().numpy(),model = MyNet(32,32)# print(model.conv_block.conv1.weight)for name, module in model.named_modules():if hasattr(module, 'repvgg_convert'):kernel, bias = module.repvgg_convert()print(kernel)# print('modules:', module)# print("******************")# print(module.conv_block.conv1.weight)
