vlambda博客
学习文章列表

打印网络模型的conv.weight(二)


上一篇博文讲到要区分module的“总”与“分”

需要添加一个判断语句


 for name, module in model.named_modules(): # module中有所有的分支 if hasattr(module, 'repvgg_convert'):

所以“总”就是我们定义的网络结构, 应该是一个class-类, 其中要有一个函数是:

repvgg_convert()。



当判断module中有该函数时, 进到下一步, 否则跳过



代码:


MyNet---构建2个block, 各自有一个"conv"


repvgg_convert---判断“总”与“分”, 并将“总”送进

                 get_equivalent_kernel_bias(self)函数


get_equivalent_kernel_bias(self)---将“总”中定义的2个   

               block传入_fuse_bn_tensor(self, branch)函数


_fuse_bn_tensor(self, branch)---得到branch.conv.weight

                            及branch.conv.bias并且各自相加


import torchimport torch.nn as nn
class MyNet(nn.Module): def __init__(self, inplanes, planes): super(MyNet, self,).__init__() self.conv_block=torch.nn.Sequential() self.conv_block.add_module("conv",torch.nn.Conv2d(inplanes, planes, 3, 1, 1)) self.conv_block.add_module("relu1",torch.nn.ReLU()) self.conv_block.add_module("pool1",torch.nn.MaxPool2d(2)) self.dense_block = torch.nn.Sequential() self.dense_block.add_module("conv",torch.nn.Conv2d(planes, planes, 3,1,1) ) self.dense_block.add_module("relu2",torch.nn.ReLU()) self.dense_block.add_module("dense2",torch.nn.Linear(32 * 3 * 3, 10)) def forward(self, x): conv_out = self.conv_block(x) res = conv_out.view(conv_out.size(0), -1) out = self.dense_block(res) return out
def get_equivalent_kernel_bias(self):
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv_block) kernel1x1, bias1x1 = self._fuse_bn_tensor(self.dense_block) return kernel3x3 +(kernel1x1), bias3x3 + bias1x1
# def _pad_1x1_to_3x3_tensor(self, kernel1x1): # if kernel1x1 is None: # return 0 # else: # return torch.nn.functional.pad(kernel1x1, [1,1,1,1])
def _fuse_bn_tensor(self, branch): # 将BN合并到卷积中 if branch is None: return 0, 0 if isinstance(branch, nn.Sequential): # 判断branch是否为nn.Sequential # branch在get_equivalent_kernel_bias中调用;而branch的定义则来自28--33行。定义的卷积、正则化 kernel = branch.conv.weight bias = branch.conv.bias return kernel, bias
# 调用融合后的卷积与偏置 def repvgg_convert(self): kernel, bias = self.get_equivalent_kernel_bias() return kernel.detach().cpu().numpy(), bias.detach().cpu().numpy(),

model = MyNet(32,32)
# print(model.conv_block.conv1.weight)
for name, module in model.named_modules():
if hasattr(module, 'repvgg_convert'): kernel, bias = module.repvgg_convert() print(kernel)
# print('modules:', module) # print("******************")
# print(module.conv_block.conv1.weight)