pytorch 批次遍历数据集打印数据的例子
我就废话不多说了,直接上代码吧!
fromosimportlistdir importos fromtimeimporttime importtorch.utils.dataasdata importtorchvision.transformsastransforms fromtorch.utils.dataimportDataLoader defprintProgressBar(iteration,total,prefix='',suffix='',decimals=1,length=100, fill='=',empty='',tip='>',begin='[',end=']',done="[DONE]",clear=True): percent=("{0:."+str(decimals)+"f}").format(100*(iteration/float(total))) filledLength=int(length*iteration//total) bar=fill*filledLength ifiteration!=total: bar=bar+tip bar=bar+empty*(length-filledLength-len(tip)) display='\r{prefix}{begin}{bar}{end}{percent}%{suffix}'\ .format(prefix=prefix,begin=begin,bar=bar,end=end,percent=percent,suffix=suffix) print(display,end=''),#commaafterprint()requiredforpython2 ifiteration==total:#printwithnewlineoncomplete ifclear:#displaygivencompletemessagewithspacesto'erase'previousprogressbar finish='\r{prefix}{done}'.format(prefix=prefix,done=done) ifhasattr(str,'decode'):#handlepython2non-unicodestringsforproperlengthmeasure finish=finish.decode('utf-8') display=display.decode('utf-8') clear=''*max(len(display)-len(finish),0) print(finish+clear) else: print('') classDatasetFromFolder(data.Dataset): def__init__(self,image_dir): super(DatasetFromFolder,self).__init__() self.photo_path=os.path.join(image_dir,"a") self.sketch_path=os.path.join(image_dir,"b") self.image_filenames=[xforxinlistdir(self.photo_path)ifis_image_file(x)] transform_list=[transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))] self.transform=transforms.Compose(transform_list) def__getitem__(self,index): #LoadImage input=load_img(os.path.join(self.photo_path,self.image_filenames[index])) input=self.transform(input) target=load_img(os.path.join(self.sketch_path,self.image_filenames[index])) target=self.transform(target) returninput,target def__len__(self): returnlen(self.image_filenames) if__name__=='__main__': dataset=DatasetFromFolder("./dataset/facades/train") dataloader=DataLoader(dataset=dataset,num_workers=8,batch_size=1,shuffle=True) total=len(dataloader) forepochinrange(20): t0=time() fori,batchinenumerate(dataloader): real_a,real_b=batch[0],batch[1] printProgressBar(i+1,total+1, length=20, prefix='Epoch%s'%str(1), suffix=',d_loss:%d'%1) printProgressBar(total,total, done='Epoch[%s]'%str(epoch)+ ',time:%.2fs'%(time()-t0) )
以上这篇pytorch批次遍历数据集打印数据的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。