Pytorch之保存读取模型实例
pytorch保存数据
pytorch保存数据的格式为.t7文件或者.pth文件,t7文件是沿用torch7中读取模型权重的方式。而pth文件是python中存储文件的常用格式。而在keras中则是使用.h5文件。
#保存模型示例代码 print('===>Savingmodels...') state={ 'state':model.state_dict(), 'epoch':epoch#将epoch一并保存 } ifnotos.path.isdir('checkpoint'): os.mkdir('checkpoint') torch.save(state,'./checkpoint/autoencoder.t7')
保存用到torch.save函数,注意该函数第一个参数可以是单个值也可以是字典,字典可以存更多你要保存的参数(不仅仅是权重数据)。
pytorch读取数据
pytorch读取数据使用的方法和我们平时使用预训练参数所用的方法是一样的,都是使用load_state_dict这个函数。
下方的代码和上方的保存代码可以搭配使用。
print('===>Tryresumefromcheckpoint') ifos.path.isdir('checkpoint'): try: checkpoint=torch.load('./checkpoint/autoencoder.t7') model.load_state_dict(checkpoint['state'])#从字典中依次读取 start_epoch=checkpoint['epoch'] print('===>Loadlastcheckpointdata') exceptFileNotFoundError: print('Can\'tfoundautoencoder.t7') else: start_epoch=0 print('===>Startfromscratch')
以上是pytorch读取的方法汇总,但是要注意,在使用官方的预处理模型进行读取时,一般使用的格式是pth,使用官方的模型读取命令会检查你模型的格式是否正确,如果不是使用官方提供模型通过下面的函数强行读取模型(将其他模型例如caffe模型转过来的模型放到指定目录下)会发生错误。
defvgg19(pretrained=False,**kwargs): """VGG19-layermodel(configuration"E") Args: pretrained(bool):IfTrue,returnsamodelpre-trainedonImageNet """ model=VGG(make_layers(cfg['E']),**kwargs) ifpretrained: model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) returnmodel
假如我们有从caffe模型转过来的pytorch模型([0-255,BGR]),我们可以使用:
model_dir='自己的模型地址' model=VGG() model.load_state_dict(torch.load(model_dir+'vgg_conv.pth'))
也就是pytorch的读取函数进行读取即可。
以上这篇Pytorch之保存读取模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。