vlambda博客
学习文章列表

网络模型传入Input核验



import torch
import torch.nn as nn
# One random sample with 3 features: a 1x3 tensor drawn from N(0, 1).
# NOTE(review): `input` shadows the Python builtin; it is kept because the
# model call further down refers to this name.
input = torch.randn((1, 3))
print(input)

class Bottleneck(nn.Module):
    """A single fully connected layer mapping ``inplanes`` features to ``planes``.

    Despite the name, this is not a ResNet-style bottleneck block: it only
    wraps one ``nn.Linear``. In this snippet it is used to check that the
    width of the tensor fed to the network matches the layer's input size.

    Args:
        inplanes: Expected size of the last dimension of the input tensor.
        planes: Size of the last dimension of the output tensor.
    """

    def __init__(self, inplanes: int, planes: int) -> None:
        super().__init__()
        self.inplanes = inplanes
        self.planes = planes
        # in_features must equal the input's last dimension; otherwise
        # forward() fails with a shape-mismatch error from nn.Linear.
        self.connected_layer = nn.Linear(inplanes, planes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x``.

        Args:
            x: Tensor whose last dimension equals ``inplanes``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``planes``.
        """
        return self.connected_layer(x)
# inplanes must match the last dimension of `input` (3). A single argument
# such as `Bottleneck(31)` raises TypeError because `planes` is required.
model = Bottleneck(3, 1)
print(model)
out = model(input)
print(out.size())  # (1, 3) input through Linear(3, 1) -> torch.Size([1, 1])


手生了……