pytorch中获取模型input/output shape实例
PyTorch官方目前无法像TensorFlow、Caffe那样直接给出shape信息,详见
https://github.com/pytorch/pytorch/pull/3043
以下代码算一种workaround。由于CNN,RNN等模块实现不一样,添加其他模块支持可能需要改代码。
例如RNN中bias是bool类型,其权重也不存于weight属性中,不过我们只关注shape,这样已经够用了。
该方法必须先构造一个输入并调用forward(即调用model(x))之后,才可获取shape。
#coding:utf-8
from collections import OrderedDict
import json

import torch
from torch.autograd import Variable
import torch.nn as nn

import models.crnn as crnn
def get_output_size(summary_dict, output):
    """Record the shape(s) of a module's forward output in ``summary_dict``.

    Args:
        summary_dict: Mapping to fill in; it is mutated in place.
        output: Either an object exposing ``.size()`` (such as a tensor), or
            a tuple/list of such objects, nested to any depth.

    Returns:
        The same ``summary_dict``. A single output adds an ``'output_shape'``
        entry holding the size as a list. A tuple/list output adds one nested
        ``OrderedDict`` per element, keyed by the element's integer position.
    """
    if isinstance(output, (tuple, list)):
        # Recurse per element so nested outputs keep their structure.
        for index, item in enumerate(output):
            summary_dict[index] = get_output_size(OrderedDict(), item)
    else:
        summary_dict['output_shape'] = list(output.size())
    return summary_dict
def summary(input_size, model):
    """Collect per-layer input/output shapes by running one dummy forward pass.

    A forward hook is attached to every sub-module of ``model`` except
    ``nn.Sequential`` / ``nn.ModuleList`` containers and ``model`` itself.
    A random input with batch size 1 is then fed through the model and each
    hook records what it sees. Hooks are always removed afterwards.

    Args:
        input_size: Shape of one input without the batch dimension, e.g.
            ``[1, 32, 128]``; or a list/tuple of such shapes when the model
            takes several inputs.
        model: The ``nn.Module`` to inspect.

    Returns:
        An ``OrderedDict`` keyed by ``'<ClassName>-<n>'`` in the order the
        hooks fired. Each value holds ``'input_shape'``, the output shape(s)
        as written by ``get_output_size``, ``'trainable'`` (only when the
        module has a ``weight``) and ``'nb_params'``.

    Note:
        ``nb_params`` counts only the ``weight`` attribute; biases and
        parameters stored under other names are not counted.
    """
    layers = OrderedDict()
    hooks = []

    def hook(module, inputs, output):
        class_name = str(module.__class__).split('.')[-1].split("'")[0]
        m_key = '%s-%i' % (class_name, len(layers) + 1)
        entry = OrderedDict()
        layers[m_key] = entry
        entry['input_shape'] = list(inputs[0].size())
        layers[m_key] = get_output_size(entry, output)

        params = 0
        # ``weight`` may be missing or explicitly None; skip it in both cases.
        weight = getattr(module, 'weight', None)
        if weight is not None:
            # Store a plain int so the summary stays JSON-serialisable.
            params += int(weight.numel())
            layers[m_key]['trainable'] = bool(weight.requires_grad)
        layers[m_key]['nb_params'] = params

    def register_hook(module):
        # Containers and the root model only forward to children, so hooking
        # them would just duplicate the entries of the layers they wrap.
        if isinstance(module, (nn.Sequential, nn.ModuleList)) or module is model:
            return
        hooks.append(module.register_forward_hook(hook))

    # A list/tuple of shapes means the network takes multiple inputs.
    multiple_inputs = bool(input_size) and isinstance(input_size[0], (list, tuple))
    if multiple_inputs:
        x = [Variable(torch.rand(1, *in_size)) for in_size in input_size]
    else:
        x = Variable(torch.rand(1, *input_size))

    model.apply(register_hook)
    try:
        # Multiple inputs are unpacked into separate positional arguments.
        if multiple_inputs:
            model(*x)
        else:
            model(x)
    finally:
        # Remove the hooks even if the forward pass raised, so the model is
        # never left with stale hooks attached.
        for handle in hooks:
            handle.remove()
    return layers
# Example: summarise the CRNN model for input_size [1, 32, 128].
# The model gets its own name so the imported ``crnn`` module is not rebound.
model = crnn.CRNN(32, 1, 3755, 256, 1)
x = summary([1, 32, 128], model)
# The parenthesised form runs under both Python 2 and Python 3.
print(json.dumps(x))
以pytorch版CRNN为例,输出shape如下
{
"Conv2d-1":{
"input_shape":[1,1,32,128],
"output_shape":[1,64,32,128],
"trainable":true,
"nb_params":576
},
"ReLU-2":{
"input_shape":[1,64,32,128],
"output_shape":[1,64,32,128],
"nb_params":0
},
"MaxPool2d-3":{
"input_shape":[1,64,32,128],
"output_shape":[1,64,16,64],
"nb_params":0
},
"Conv2d-4":{
"input_shape":[1,64,16,64],
"output_shape":[1,128,16,64],
"trainable":true,
"nb_params":73728
},
"ReLU-5":{
"input_shape":[1,128,16,64],
"output_shape":[1,128,16,64],
"nb_params":0
},
"MaxPool2d-6":{
"input_shape":[1,128,16,64],
"output_shape":[1,128,8,32],
"nb_params":0
},
"Conv2d-7":{
"input_shape":[1,128,8,32],
"output_shape":[1,256,8,32],
"trainable":true,
"nb_params":294912
},
"BatchNorm2d-8":{
"input_shape":[1,256,8,32],
"output_shape":[1,256,8,32],
"trainable":true,
"nb_params":256
},
"ReLU-9":{
"input_shape":[1,256,8,32],
"output_shape":[1,256,8,32],
"nb_params":0
},
"Conv2d-10":{
"input_shape":[1,256,8,32],
"output_shape":[1,256,8,32],
"trainable":true,
"nb_params":589824
},
"ReLU-11":{
"input_shape":[1,256,8,32],
"output_shape":[1,256,8,32],
"nb_params":0
},
"MaxPool2d-12":{
"input_shape":[1,256,8,32],
"output_shape":[1,256,4,33],
"nb_params":0
},
"Conv2d-13":{
"input_shape":[1,256,4,33],
"output_shape":[1,512,4,33],
"trainable":true,
"nb_params":1179648
},
"BatchNorm2d-14":{
"input_shape":[1,512,4,33],
"output_shape":[1,512,4,33],
"trainable":true,
"nb_params":512
},
"ReLU-15":{
"input_shape":[1,512,4,33],
"output_shape":[1,512,4,33],
"nb_params":0
},
"Conv2d-16":{
"input_shape":[1,512,4,33],
"output_shape":[1,512,4,33],
"trainable":true,
"nb_params":2359296
},
"ReLU-17":{
"input_shape":[1,512,4,33],
"output_shape":[1,512,4,33],
"nb_params":0
},
"MaxPool2d-18":{
"input_shape":[1,512,4,33],
"output_shape":[1,512,2,34],
"nb_params":0
},
"Conv2d-19":{
"input_shape":[1,512,2,34],
"output_shape":[1,512,1,33],
"trainable":true,
"nb_params":1048576
},
"BatchNorm2d-20":{
"input_shape":[1,512,1,33],
"output_shape":[1,512,1,33],
"trainable":true,
"nb_params":512
},
"ReLU-21":{
"input_shape":[1,512,1,33],
"output_shape":[1,512,1,33],
"nb_params":0
},
"LSTM-22":{
"input_shape":[33,1,512],
"0":{
"output_shape":[33,1,512]
},
"1":{
"0":{
"output_shape":[2,1,256]
},
"1":{
"output_shape":[2,1,256]
}
},
"nb_params":0
},
"Linear-23":{
"input_shape":[33,512],
"output_shape":[33,256],
"trainable":true,
"nb_params":131072
},
"BidirectionalLSTM-24":{
"input_shape":[33,1,512],
"output_shape":[33,1,256],
"nb_params":0
},
"LSTM-25":{
"input_shape":[33,1,256],
"0":{
"output_shape":[33,1,512]
},
"1":{
"0":{
"output_shape":[2,1,256]
},
"1":{
"output_shape":[2,1,256]
}
},
"nb_params":0
},
"Linear-26":{
"input_shape":[33,512],
"output_shape":[33,3755],
"trainable":true,
"nb_params":1922560
},
"BidirectionalLSTM-27":{
"input_shape":[33,1,256],
"output_shape":[33,1,3755],
"nb_params":0
}
}
以上这篇pytorch中获取模型input/output shape实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。