import collections
import math
import os
import shutil
import pandas as pd
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
我们提供包含前1000个训练图像和5个随机测试图像的数据集的小规模样本
# Register a small sample of the Kaggle CIFAR-10 data (first 1000 training
# images + 5 test images) for a quick demo run.
d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',
                                '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')

# Toggle between the tiny demo sample and the full Kaggle dataset on disk.
demo = True
data_dir = d2l.download_extract('cifar10_tiny') if demo else '../data/cifar-10/'
Downloading ../data/kaggle_cifar10_tiny.zip from http://d2l-data.s3-accelerate.amazonaws.com/kaggle_cifar10_tiny.zip...
整理数据集
def read_csv_labels(fname):
    """Read a two-column CSV (header skipped) into a {file name: label} dict."""
    with open(fname, 'r') as f:
        # Drop the header row; every remaining row is "name,label".
        rows = f.readlines()[1:]
    pairs = [row.rstrip().split(',') for row in rows]
    return {name: label for name, label in pairs}
# Map every training-image id (the file-name stem) to its class label.
labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
labels  # notebook cell: display the mapping
{'1': 'frog',
'2': 'truck',
'3': 'truck',
'4': 'deer',
'5': 'automobile',
'6': 'automobile',
'7': 'bird',
'8': 'horse',
'9': 'ship',
'10': 'cat',
'11': 'deer',
'12': 'horse',
'13': 'horse',
'14': 'bird',
'15': 'truck',
'16': 'truck',
'17': 'truck',
'18': 'cat',
'19': 'bird',
'20': 'frog',
'21': 'deer',
'22': 'cat',
'23': 'frog',
'24': 'frog',
'25': 'bird',
'26': 'frog',
'27': 'cat',
'28': 'dog',
'29': 'deer',
'30': 'airplane',
'31': 'airplane',
'32': 'truck',
'33': 'automobile',
'34': 'cat',
'35': 'deer',
'36': 'airplane',
'37': 'cat',
'38': 'horse',
'39': 'cat',
'40': 'cat',
'41': 'dog',
'42': 'bird',
'43': 'bird',
'44': 'horse',
'45': 'automobile',
'46': 'automobile',
'47': 'automobile',
'48': 'bird',
'49': 'bird',
'50': 'airplane',
'51': 'truck',
'52': 'dog',
'53': 'horse',
'54': 'truck',
'55': 'bird',
'56': 'bird',
'57': 'dog',
'58': 'bird',
'59': 'deer',
'60': 'cat',
'61': 'automobile',
'62': 'automobile',
'63': 'ship',
'64': 'bird',
'65': 'automobile',
'66': 'automobile',
'67': 'deer',
'68': 'truck',
'69': 'horse',
'70': 'ship',
'71': 'dog',
'72': 'truck',
'73': 'frog',
'74': 'horse',
'75': 'cat',
'76': 'automobile',
'77': 'truck',
'78': 'airplane',
'79': 'cat',
'80': 'automobile',
'81': 'cat',
'82': 'dog',
'83': 'deer',
'84': 'dog',
'85': 'horse',
'86': 'horse',
'87': 'deer',
'88': 'horse',
'89': 'truck',
'90': 'deer',
'91': 'bird',
'92': 'cat',
'93': 'ship',
'94': 'airplane',
'95': 'automobile',
'96': 'frog',
'97': 'automobile',
'98': 'automobile',
'99': 'deer',
'100': 'automobile',
'101': 'ship',
'102': 'cat',
'103': 'truck',
'104': 'frog',
'105': 'frog',
'106': 'automobile',
'107': 'ship',
'108': 'dog',
'109': 'bird',
'110': 'truck',
'111': 'truck',
'112': 'ship',
'113': 'automobile',
'114': 'horse',
'115': 'horse',
'116': 'airplane',
'117': 'airplane',
'118': 'frog',
'119': 'truck',
'120': 'automobile',
'121': 'bird',
'122': 'bird',
'123': 'truck',
'124': 'bird',
'125': 'frog',
'126': 'frog',
'127': 'automobile',
'128': 'truck',
'129': 'dog',
'130': 'airplane',
'131': 'deer',
'132': 'horse',
'133': 'frog',
'134': 'horse',
'135': 'automobile',
'136': 'ship',
'137': 'automobile',
'138': 'automobile',
'139': 'bird',
'140': 'ship',
'141': 'automobile',
'142': 'cat',
'143': 'cat',
'144': 'frog',
'145': 'bird',
'146': 'deer',
'147': 'truck',
'148': 'truck',
'149': 'dog',
'150': 'deer',
'151': 'cat',
'152': 'frog',
'153': 'horse',
'154': 'deer',
'155': 'frog',
'156': 'ship',
'157': 'dog',
'158': 'dog',
'159': 'deer',
'160': 'cat',
'161': 'automobile',
'162': 'ship',
'163': 'deer',
'164': 'horse',
'165': 'frog',
'166': 'airplane',
'167': 'truck',
'168': 'dog',
'169': 'automobile',
'170': 'cat',
'171': 'ship',
'172': 'bird',
'173': 'horse',
'174': 'dog',
'175': 'cat',
'176': 'deer',
'177': 'automobile',
'178': 'dog',
'179': 'horse',
'180': 'airplane',
'181': 'deer',
'182': 'horse',
'183': 'dog',
'184': 'dog',
'185': 'automobile',
'186': 'airplane',
'187': 'truck',
'188': 'frog',
'189': 'truck',
'190': 'airplane',
'191': 'ship',
'192': 'horse',
'193': 'ship',
'194': 'ship',
'195': 'bird',
'196': 'dog',
'197': 'bird',
'198': 'cat',
'199': 'dog',
'200': 'airplane',
'201': 'frog',
'202': 'automobile',
'203': 'truck',
'204': 'cat',
'205': 'frog',
'206': 'truck',
'207': 'automobile',
'208': 'cat',
'209': 'truck',
'210': 'frog',
'211': 'frog',
'212': 'horse',
'213': 'automobile',
'214': 'airplane',
'215': 'truck',
'216': 'dog',
'217': 'ship',
'218': 'dog',
'219': 'bird',
'220': 'truck',
'221': 'airplane',
'222': 'ship',
'223': 'ship',
'224': 'airplane',
'225': 'frog',
'226': 'truck',
'227': 'automobile',
'228': 'automobile',
'229': 'frog',
'230': 'cat',
'231': 'horse',
'232': 'frog',
'233': 'frog',
'234': 'airplane',
'235': 'frog',
'236': 'frog',
'237': 'automobile',
'238': 'horse',
'239': 'automobile',
'240': 'dog',
'241': 'ship',
'242': 'cat',
'243': 'frog',
'244': 'frog',
'245': 'ship',
'246': 'frog',
'247': 'ship',
'248': 'deer',
'249': 'frog',
'250': 'frog',
'251': 'automobile',
'252': 'cat',
'253': 'ship',
'254': 'cat',
'255': 'deer',
'256': 'automobile',
'257': 'horse',
'258': 'automobile',
'259': 'cat',
'260': 'ship',
'261': 'dog',
'262': 'automobile',
'263': 'automobile',
'264': 'deer',
'265': 'airplane',
'266': 'truck',
'267': 'cat',
'268': 'horse',
'269': 'deer',
'270': 'truck',
'271': 'truck',
'272': 'bird',
'273': 'deer',
'274': 'truck',
'275': 'truck',
'276': 'automobile',
'277': 'airplane',
'278': 'dog',
'279': 'truck',
'280': 'airplane',
'281': 'ship',
'282': 'bird',
'283': 'automobile',
'284': 'bird',
'285': 'airplane',
'286': 'dog',
'287': 'frog',
'288': 'cat',
'289': 'bird',
'290': 'horse',
'291': 'ship',
'292': 'ship',
'293': 'frog',
'294': 'airplane',
'295': 'horse',
'296': 'truck',
'297': 'deer',
'298': 'dog',
'299': 'frog',
'300': 'deer',
'301': 'bird',
'302': 'automobile',
'303': 'automobile',
'304': 'bird',
'305': 'automobile',
'306': 'dog',
'307': 'truck',
'308': 'truck',
'309': 'airplane',
'310': 'ship',
'311': 'deer',
'312': 'automobile',
'313': 'automobile',
'314': 'frog',
'315': 'cat',
'316': 'cat',
'317': 'truck',
'318': 'airplane',
'319': 'horse',
'320': 'truck',
'321': 'horse',
'322': 'horse',
'323': 'truck',
'324': 'automobile',
'325': 'dog',
'326': 'automobile',
'327': 'frog',
'328': 'frog',
'329': 'ship',
'330': 'horse',
'331': 'automobile',
'332': 'cat',
'333': 'airplane',
'334': 'cat',
'335': 'cat',
'336': 'bird',
'337': 'deer',
'338': 'dog',
'339': 'horse',
'340': 'dog',
'341': 'truck',
'342': 'airplane',
'343': 'cat',
'344': 'deer',
'345': 'airplane',
'346': 'deer',
'347': 'deer',
'348': 'frog',
'349': 'airplane',
'350': 'airplane',
'351': 'frog',
'352': 'frog',
'353': 'airplane',
'354': 'ship',
'355': 'automobile',
'356': 'frog',
'357': 'bird',
'358': 'truck',
'359': 'bird',
'360': 'dog',
'361': 'truck',
'362': 'frog',
'363': 'horse',
'364': 'deer',
'365': 'automobile',
'366': 'ship',
'367': 'horse',
'368': 'cat',
'369': 'frog',
'370': 'truck',
'371': 'cat',
'372': 'airplane',
'373': 'deer',
'374': 'airplane',
'375': 'dog',
'376': 'automobile',
'377': 'airplane',
'378': 'cat',
'379': 'deer',
'380': 'ship',
'381': 'dog',
'382': 'deer',
'383': 'horse',
'384': 'bird',
'385': 'cat',
'386': 'truck',
'387': 'horse',
'388': 'frog',
'389': 'horse',
'390': 'automobile',
'391': 'deer',
'392': 'horse',
'393': 'airplane',
'394': 'automobile',
'395': 'horse',
'396': 'cat',
'397': 'automobile',
'398': 'ship',
'399': 'deer',
'400': 'deer',
'401': 'bird',
'402': 'airplane',
'403': 'bird',
'404': 'bird',
'405': 'airplane',
'406': 'airplane',
'407': 'truck',
'408': 'airplane',
'409': 'truck',
'410': 'frog',
'411': 'ship',
'412': 'bird',
'413': 'horse',
'414': 'horse',
'415': 'deer',
'416': 'airplane',
'417': 'cat',
'418': 'airplane',
'419': 'ship',
'420': 'truck',
'421': 'deer',
'422': 'bird',
'423': 'horse',
'424': 'bird',
'425': 'dog',
'426': 'bird',
'427': 'dog',
'428': 'automobile',
'429': 'truck',
'430': 'deer',
'431': 'ship',
'432': 'dog',
'433': 'automobile',
'434': 'horse',
'435': 'deer',
'436': 'deer',
'437': 'airplane',
'438': 'frog',
'439': 'truck',
'440': 'airplane',
'441': 'horse',
'442': 'ship',
'443': 'ship',
'444': 'truck',
'445': 'truck',
'446': 'cat',
'447': 'cat',
'448': 'deer',
'449': 'airplane',
'450': 'deer',
'451': 'dog',
'452': 'frog',
'453': 'frog',
'454': 'airplane',
'455': 'automobile',
'456': 'airplane',
'457': 'ship',
'458': 'airplane',
'459': 'deer',
'460': 'ship',
'461': 'ship',
'462': 'automobile',
'463': 'dog',
'464': 'bird',
'465': 'frog',
'466': 'ship',
'467': 'automobile',
'468': 'airplane',
'469': 'airplane',
'470': 'horse',
'471': 'horse',
'472': 'dog',
'473': 'truck',
'474': 'frog',
'475': 'bird',
'476': 'ship',
'477': 'cat',
'478': 'deer',
'479': 'horse',
'480': 'cat',
'481': 'truck',
'482': 'airplane',
'483': 'automobile',
'484': 'bird',
'485': 'deer',
'486': 'ship',
'487': 'automobile',
'488': 'ship',
'489': 'frog',
'490': 'deer',
'491': 'deer',
'492': 'dog',
'493': 'horse',
'494': 'automobile',
'495': 'cat',
'496': 'truck',
'497': 'ship',
'498': 'airplane',
'499': 'automobile',
'500': 'horse',
'501': 'dog',
'502': 'ship',
'503': 'bird',
'504': 'ship',
'505': 'airplane',
'506': 'deer',
'507': 'automobile',
'508': 'ship',
'509': 'truck',
'510': 'ship',
'511': 'bird',
'512': 'truck',
'513': 'truck',
'514': 'bird',
'515': 'horse',
'516': 'dog',
'517': 'horse',
'518': 'cat',
'519': 'ship',
'520': 'ship',
'521': 'deer',
'522': 'deer',
'523': 'bird',
'524': 'horse',
'525': 'automobile',
'526': 'frog',
'527': 'deer',
'528': 'airplane',
'529': 'deer',
'530': 'frog',
'531': 'truck',
'532': 'horse',
'533': 'frog',
'534': 'bird',
'535': 'dog',
'536': 'dog',
'537': 'automobile',
'538': 'horse',
'539': 'bird',
'540': 'bird',
'541': 'bird',
'542': 'truck',
'543': 'dog',
'544': 'deer',
'545': 'bird',
'546': 'horse',
'547': 'ship',
'548': 'automobile',
'549': 'cat',
'550': 'deer',
'551': 'cat',
'552': 'horse',
'553': 'frog',
'554': 'truck',
'555': 'ship',
'556': 'airplane',
'557': 'frog',
'558': 'airplane',
'559': 'bird',
'560': 'bird',
'561': 'bird',
'562': 'automobile',
'563': 'ship',
'564': 'deer',
'565': 'airplane',
'566': 'automobile',
'567': 'ship',
'568': 'ship',
'569': 'automobile',
'570': 'dog',
'571': 'horse',
'572': 'frog',
'573': 'deer',
'574': 'dog',
'575': 'ship',
'576': 'horse',
'577': 'automobile',
'578': 'truck',
'579': 'automobile',
'580': 'truck',
'581': 'ship',
'582': 'deer',
'583': 'horse',
'584': 'cat',
'585': 'ship',
'586': 'ship',
'587': 'bird',
'588': 'frog',
'589': 'frog',
'590': 'horse',
'591': 'automobile',
'592': 'frog',
'593': 'ship',
'594': 'automobile',
'595': 'truck',
'596': 'horse',
'597': 'ship',
'598': 'cat',
'599': 'airplane',
'600': 'automobile',
'601': 'airplane',
'602': 'ship',
'603': 'ship',
'604': 'cat',
'605': 'airplane',
'606': 'airplane',
'607': 'automobile',
'608': 'dog',
'609': 'airplane',
'610': 'ship',
'611': 'ship',
'612': 'horse',
'613': 'truck',
'614': 'truck',
'615': 'airplane',
'616': 'truck',
'617': 'deer',
'618': 'automobile',
'619': 'cat',
'620': 'frog',
'621': 'frog',
'622': 'deer',
'623': 'deer',
'624': 'horse',
'625': 'dog',
'626': 'frog',
'627': 'airplane',
'628': 'ship',
'629': 'airplane',
'630': 'cat',
'631': 'bird',
'632': 'ship',
'633': 'deer',
'634': 'frog',
'635': 'truck',
'636': 'truck',
'637': 'horse',
'638': 'airplane',
'639': 'cat',
'640': 'cat',
'641': 'frog',
'642': 'horse',
'643': 'deer',
'644': 'truck',
'645': 'automobile',
'646': 'frog',
'647': 'bird',
'648': 'horse',
'649': 'bird',
'650': 'bird',
'651': 'airplane',
'652': 'frog',
'653': 'horse',
'654': 'dog',
'655': 'horse',
'656': 'frog',
'657': 'ship',
'658': 'truck',
'659': 'airplane',
'660': 'truck',
'661': 'deer',
'662': 'deer',
'663': 'horse',
'664': 'airplane',
'665': 'truck',
'666': 'deer',
'667': 'truck',
'668': 'frog',
'669': 'truck',
'670': 'deer',
'671': 'dog',
'672': 'horse',
'673': 'truck',
'674': 'bird',
'675': 'deer',
'676': 'dog',
'677': 'automobile',
'678': 'deer',
'679': 'cat',
'680': 'truck',
'681': 'frog',
'682': 'dog',
'683': 'frog',
'684': 'truck',
'685': 'cat',
'686': 'cat',
'687': 'dog',
'688': 'airplane',
'689': 'horse',
'690': 'bird',
'691': 'automobile',
'692': 'cat',
'693': 'frog',
'694': 'deer',
'695': 'airplane',
'696': 'airplane',
'697': 'bird',
'698': 'dog',
'699': 'airplane',
'700': 'automobile',
'701': 'airplane',
'702': 'bird',
'703': 'cat',
'704': 'truck',
'705': 'ship',
'706': 'deer',
'707': 'truck',
'708': 'ship',
'709': 'airplane',
'710': 'bird',
'711': 'frog',
'712': 'deer',
'713': 'deer',
'714': 'airplane',
'715': 'automobile',
'716': 'ship',
'717': 'ship',
'718': 'cat',
'719': 'frog',
'720': 'truck',
'721': 'frog',
'722': 'frog',
'723': 'horse',
'724': 'ship',
'725': 'bird',
'726': 'deer',
'727': 'dog',
'728': 'horse',
'729': 'frog',
'730': 'dog',
'731': 'cat',
'732': 'airplane',
'733': 'dog',
'734': 'airplane',
'735': 'dog',
'736': 'airplane',
'737': 'ship',
'738': 'bird',
'739': 'frog',
'740': 'horse',
'741': 'cat',
'742': 'ship',
'743': 'bird',
'744': 'automobile',
'745': 'horse',
'746': 'frog',
'747': 'horse',
'748': 'automobile',
'749': 'airplane',
'750': 'truck',
'751': 'dog',
'752': 'dog',
'753': 'airplane',
'754': 'automobile',
'755': 'horse',
'756': 'frog',
'757': 'truck',
'758': 'airplane',
'759': 'deer',
'760': 'horse',
'761': 'horse',
'762': 'automobile',
'763': 'dog',
'764': 'truck',
'765': 'deer',
'766': 'airplane',
'767': 'ship',
'768': 'dog',
'769': 'truck',
'770': 'truck',
'771': 'frog',
'772': 'horse',
'773': 'automobile',
'774': 'ship',
'775': 'cat',
'776': 'bird',
'777': 'cat',
'778': 'ship',
'779': 'bird',
'780': 'bird',
'781': 'deer',
'782': 'frog',
'783': 'airplane',
'784': 'airplane',
'785': 'dog',
'786': 'cat',
'787': 'ship',
'788': 'bird',
'789': 'cat',
'790': 'horse',
'791': 'bird',
'792': 'truck',
'793': 'cat',
'794': 'ship',
'795': 'horse',
'796': 'ship',
'797': 'bird',
'798': 'horse',
'799': 'truck',
'800': 'airplane',
'801': 'bird',
'802': 'cat',
'803': 'bird',
'804': 'bird',
'805': 'bird',
'806': 'cat',
'807': 'cat',
'808': 'frog',
'809': 'bird',
'810': 'cat',
'811': 'bird',
'812': 'ship',
'813': 'airplane',
'814': 'dog',
'815': 'dog',
'816': 'automobile',
'817': 'deer',
'818': 'dog',
'819': 'frog',
'820': 'frog',
'821': 'bird',
'822': 'horse',
'823': 'airplane',
'824': 'automobile',
'825': 'horse',
'826': 'horse',
'827': 'ship',
'828': 'bird',
'829': 'truck',
'830': 'bird',
'831': 'bird',
'832': 'deer',
'833': 'bird',
'834': 'automobile',
'835': 'automobile',
'836': 'automobile',
'837': 'frog',
'838': 'frog',
'839': 'frog',
'840': 'dog',
'841': 'automobile',
'842': 'automobile',
'843': 'horse',
'844': 'airplane',
'845': 'deer',
'846': 'cat',
'847': 'cat',
'848': 'horse',
'849': 'automobile',
'850': 'bird',
'851': 'cat',
'852': 'dog',
'853': 'dog',
'854': 'dog',
'855': 'frog',
'856': 'automobile',
'857': 'deer',
'858': 'cat',
'859': 'horse',
'860': 'ship',
'861': 'ship',
'862': 'cat',
'863': 'frog',
'864': 'frog',
'865': 'bird',
'866': 'cat',
'867': 'airplane',
'868': 'truck',
'869': 'deer',
'870': 'cat',
'871': 'ship',
'872': 'airplane',
'873': 'airplane',
'874': 'automobile',
'875': 'automobile',
'876': 'dog',
'877': 'deer',
'878': 'truck',
'879': 'cat',
'880': 'automobile',
'881': 'ship',
'882': 'truck',
'883': 'cat',
'884': 'truck',
'885': 'truck',
'886': 'bird',
'887': 'truck',
'888': 'deer',
'889': 'ship',
'890': 'bird',
'891': 'truck',
'892': 'ship',
'893': 'ship',
'894': 'automobile',
'895': 'dog',
'896': 'cat',
'897': 'frog',
'898': 'ship',
'899': 'horse',
'900': 'frog',
'901': 'truck',
'902': 'ship',
'903': 'airplane',
'904': 'frog',
'905': 'deer',
'906': 'airplane',
'907': 'airplane',
'908': 'bird',
'909': 'dog',
'910': 'ship',
'911': 'bird',
'912': 'airplane',
'913': 'bird',
'914': 'horse',
'915': 'frog',
'916': 'truck',
'917': 'horse',
'918': 'automobile',
'919': 'dog',
'920': 'dog',
'921': 'frog',
'922': 'frog',
'923': 'cat',
'924': 'frog',
'925': 'bird',
'926': 'deer',
'927': 'horse',
'928': 'airplane',
'929': 'dog',
'930': 'frog',
'931': 'deer',
'932': 'frog',
'933': 'dog',
'934': 'bird',
'935': 'deer',
'936': 'frog',
'937': 'automobile',
'938': 'frog',
'939': 'airplane',
'940': 'deer',
'941': 'airplane',
'942': 'cat',
'943': 'automobile',
'944': 'ship',
'945': 'dog',
'946': 'deer',
'947': 'deer',
'948': 'automobile',
'949': 'horse',
'950': 'cat',
'951': 'truck',
'952': 'deer',
'953': 'horse',
'954': 'truck',
'955': 'horse',
'956': 'cat',
'957': 'horse',
'958': 'bird',
'959': 'ship',
'960': 'deer',
'961': 'frog',
'962': 'frog',
'963': 'automobile',
'964': 'bird',
'965': 'truck',
'966': 'airplane',
'967': 'deer',
'968': 'ship',
'969': 'horse',
'970': 'cat',
'971': 'truck',
'972': 'ship',
'973': 'horse',
'974': 'horse',
'975': 'airplane',
'976': 'bird',
'977': 'deer',
'978': 'automobile',
'979': 'automobile',
'980': 'deer',
'981': 'automobile',
'982': 'dog',
'983': 'deer',
'984': 'airplane',
'985': 'dog',
'986': 'frog',
'987': 'bird',
'988': 'ship',
'989': 'dog',
'990': 'airplane',
'991': 'bird',
'992': 'automobile',
'993': 'cat',
'994': 'dog',
'995': 'horse',
'996': 'cat',
'997': 'dog',
'998': 'automobile',
'999': 'cat',
'1000': 'dog'}
将验证集从原始的训练集中拆分出来
# 在pytorch中有一个比较简单但很常用的加载数据的方式就是先将文件夹创建好,然后文件夹名字为label,然后将这个label的训练数据放进去
# 这个函数的作用就是创建子文件夹,然后将图片搬过去
def copyfile(filename, target_dir):
    """Copy ``filename`` into ``target_dir``, creating the directory on demand."""
    os.makedirs(target_dir, exist_ok=True)
    shutil.copy(filename, target_dir)
# 根目录:train_valid_test。下面有train文件夹,包含训练数据。valid包含验证数据,train_valid原始的train文件夹
def reorg_train_valid(data_dir, labels, valid_ratio):
    """Split the raw training images into ImageFolder-style directories.

    Copies every image under ``data_dir/train_valid_test/<split>/<label>/``
    where <split> is 'train_valid' (all images), 'valid' (up to
    ``n_valid_per_label`` images per label) or 'train' (the rest).

    Returns the number of validation examples reserved per label.
    """
    def _copy_into(fname, target_dir):
        # Local helper: copy one file, creating the target folder on demand.
        os.makedirs(target_dir, exist_ok=True)
        shutil.copy(fname, target_dir)

    # The rarest label bounds how many examples each label can give to valid.
    n = collections.Counter(labels.values()).most_common()[-1][1]
    n_valid_per_label = max(1, math.floor(n * valid_ratio))
    label_count = {}
    for train_file in os.listdir(os.path.join(data_dir, 'train')):
        label = labels[train_file.split('.')[0]]
        fname = os.path.join(data_dir, 'train', train_file)
        # Every image goes into the combined train_valid split.
        # BUG FIX: was 'train_valid_set', a folder no reader ever opens;
        # the ImageFolder code below expects 'train_valid_test/train_valid'.
        _copy_into(fname, os.path.join(data_dir, 'train_valid_test',
                                       'train_valid', label))
        if label not in label_count or label_count[label] < n_valid_per_label:
            _copy_into(fname, os.path.join(data_dir, 'train_valid_test',
                                           'valid', label))
            # BUG FIX: the counter was never incremented, so this branch was
            # always taken and EVERY image ended up in valid, none in train.
            label_count[label] = label_count.get(label, 0) + 1
        else:
            _copy_into(fname, os.path.join(data_dir, 'train_valid_test',
                                           'train', label))
    return n_valid_per_label
在预测期间整理测试集,以方便读取
def reorg_test(data_dir):
    """Stage all test images under one 'unknown' pseudo-label folder so
    ImageFolder can read them at prediction time."""
    src_dir = os.path.join(data_dir, 'test')
    target_dir = os.path.join(data_dir, 'train_valid_test', 'test', 'unknown')
    for test_file in os.listdir(src_dir):
        copyfile(os.path.join(src_dir, test_file), target_dir)
调用前面定义的函数
def reorg_cifar10_data(data_dir, valid_ratio):
    """Full CIFAR-10 reorganisation: read labels, split train/valid, stage test."""
    label_map = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
    reorg_train_valid(data_dir, label_map, valid_ratio)
    reorg_test(data_dir)
# Small batch for the tiny demo sample; the full dataset can afford 128.
batch_size = 32 if demo else 128
# Hold out 10% of each label's training images for validation.
valid_ratio = 0.1
reorg_cifar10_data(data_dir, valid_ratio)
图像增广
# Training-time augmentation for 32x32 CIFAR images.
transform_train = torchvision.transforms.Compose([
    torchvision.transforms.Resize(40),  # upscale 32x32 -> 40x40
    torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0),
        ratio=(1.0, 1.0)),  # random square crop covering 64%-100% of the area, resized to 32x32
    torchvision.transforms.RandomHorizontalFlip(),  # random left-right flip
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],
        [0.2023, 0.1994, 0.2010])])  # per-channel RGB mean / std
# Test time: no augmentation, only tensor conversion and normalization.
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],
        [0.2023, 0.1994, 0.2010])])
读取由原始图像组成的数据集
# ImageFolder datasets over the directories built by reorg_cifar10_data:
# 'train'/'train_valid' use the augmenting transform, 'valid'/'test' the
# deterministic one.
train_ds, train_valid_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder),
        transform=transform_train) for folder in ['train', 'train_valid']]
valid_ds, test_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder),
        transform=transform_test) for folder in ['valid', 'test']]
指定上面定义的所有图像增广操作
train_iter, train_valid_iter = [
    torch.utils.data.DataLoader(dataset, batch_size, shuffle=True,
        drop_last=True)  # drop_last: discard a final batch smaller than batch_size
    for dataset in (train_ds, train_valid_ds)]
valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False,
    drop_last=True)
# Keep every test example: predictions are needed for the full test set.
test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False,
    drop_last=False)
模型
def get_net():
    """Build a ResNet-18 for 10-way CIFAR classification on 3-channel RGB input."""
    return d2l.resnet18(10, 3)
loss = nn.CrossEntropyLoss(reduction="none") # reduction=‘none’表示不要加起来
训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,
          lr_decay):  # every lr_period epochs the LR is multiplied by lr_decay
    """Train ``net`` with momentum SGD + step LR decay, plotting metrics.

    If ``valid_iter`` is None, no validation accuracy is computed (used when
    retraining on the combined train+valid split).
    """
    trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9,
                              weight_decay=wd)
    # StepLR multiplies the optimizer's LR by lr_decay every lr_period epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)
    num_batches, timer = len(train_iter), d2l.Timer()
    legend = ['train loss', 'train acc']
    if valid_iter is not None:
        legend.append('valid acc')
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=legend)
    # Wrap for multi-GPU data parallelism; parameters live on devices[0].
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    for epoch in range(num_epochs):
        net.train()
        metric = d2l.Accumulator(3)  # sums of: loss, correct count, #examples
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = d2l.train_batch_ch13(net, features, labels, loss,
                                          trainer, devices)
            metric.add(l, acc, labels.shape[0])
            timer.stop()
            # Refresh the plot five times per epoch and at the last batch.
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(
                    epoch + (i + 1) / num_batches,
                    (metric[0] / metric[2], metric[1] / metric[2], None))
        if valid_iter is not None:
            valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)
            animator.add(epoch + 1, (None, None, valid_acc))
        scheduler.step()  # advance the LR schedule once per epoch
    measures = (f'train loss {metric[0] / metric[2]:.3f}, '
                f'train acc {metric[1] / metric[2]:.3f}')
    if valid_iter is not None:
        measures += f', valid acc {valid_acc:.3f}'
    print(measures + f'\n{metric[2] * num_epochs / timer.sum():.1f}'
          f' examples/sec on {str(devices)}')
训练和验证模型
# Hyperparameters: 20 epochs, lr 2e-4, weight decay 5e-4;
# decay the learning rate by 0.9 every 4 epochs.
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 20, 2e-4, 5e-4
lr_period, lr_decay, net = 4, 0.9, get_net()
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,
      lr_decay)
对测试集进行分类并提交结果
net, preds = get_net(), []
# Retrain from scratch on the combined train+valid split before predicting.
train(net, train_valid_iter, None, num_epochs, lr, wd, devices, lr_period,
      lr_decay)
# FIX: run inference in eval mode and without autograd. The original left
# batch-norm in training mode (batch-dependent predictions) and kept building
# gradient graphs for every test batch, wasting memory.
net.eval()
with torch.no_grad():
    for X, _ in test_iter:
        y_hat = net(X.to(devices[0]))
        preds.extend(y_hat.argmax(dim=1).type(torch.int32).cpu().numpy())
# Kaggle/ImageFolder order the ids lexicographically, so sort ids as strings
# to line up with the prediction order.
sorted_ids = list(range(1, len(test_ds) + 1))
sorted_ids.sort(key=lambda x: str(x))
df = pd.DataFrame({'id': sorted_ids, 'label': preds})
df['label'] = df['label'].apply(lambda x: train_valid_ds.classes[x])
df.to_csv('submission.csv', index=False)
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
我们提供完整数据集的小规模样本
# Register a small sample of the Kaggle dog-breed data for a quick demo run.
d2l.DATA_HUB['dog_tiny'] = (d2l.DATA_URL + 'kaggle_dog_tiny.zip',
                            '0cb91d09b814ecdc07b50f31f8dcad3e81d6a86d')

# Toggle between the tiny demo sample and the full Kaggle dataset on disk.
demo = True
data_dir = (d2l.download_extract('dog_tiny') if demo
            else os.path.join('..', 'data', 'dog-breed-identification'))
Downloading ../data/kaggle_dog_tiny.zip from http://d2l-data.s3-accelerate.amazonaws.com/kaggle_dog_tiny.zip...
整理数据
def reorg_dog_data(data_dir, valid_ratio):
    """Reorganise the dog-breed data into ImageFolder-style directories."""
    label_map = d2l.read_csv_labels(os.path.join(data_dir, 'labels.csv'))
    d2l.reorg_train_valid(data_dir, label_map, valid_ratio)
    d2l.reorg_test(data_dir)
# Small batch for the tiny demo sample; the full dataset can afford 128.
batch_size = 32 if demo else 128
valid_ratio = 0.1  # hold out 10% of each breed for validation
reorg_dog_data(data_dir, valid_ratio)
图片增广
# Training-time augmentation for 224x224 ImageNet-style input.
transform_train = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224, scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4,
        saturation=0.4),  # jitter brightness, contrast and saturation
    torchvision.transforms.ToTensor(),
    # ImageNet per-channel mean/std, matching the pretrained backbone.
    torchvision.transforms.Normalize([0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225])])
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),  # crop a 224x224 patch from the center
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225])])
# ImageFolder datasets over the directories built by reorg_dog_data:
# augmenting transform for training splits, deterministic one for valid/test.
train_ds, train_valid_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder),
        transform=transform_train) for folder in ['train', 'train_valid']]
valid_ds, test_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder),
        transform=transform_test) for folder in ['valid', 'test']]
train_iter, train_valid_iter = [
    torch.utils.data.DataLoader(dataset, batch_size, shuffle=True,
        drop_last=True)  # discard a final incomplete batch
    for dataset in (train_ds, train_valid_ds)]
valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False,
    drop_last=True)
# Keep every test example: predictions are needed for the full test set.
test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False,
    drop_last=False)
微调预训练模型
# 这个函数的作用就是将除了最后一层以外的参数不变,拿过来
def get_net(devices):
    """Pretrained ResNet-34 backbone (frozen) + a fresh 2-layer head.

    The head maps the backbone's 1000-way ImageNet logits to 120 dog breeds;
    only the head's parameters remain trainable.
    """
    finetune_net = nn.Sequential()
    finetune_net.features = torchvision.models.resnet34(pretrained=True)
    # Freeze the pretrained backbone so the optimizer never updates it.
    for p in finetune_net.features.parameters():
        p.requires_grad = False
    finetune_net.output_new = nn.Sequential(nn.Linear(1000, 256), nn.ReLU(),
                                            nn.Linear(256, 120))
    return finetune_net.to(devices[0])
计算损失
loss = nn.CrossEntropyLoss(reduction='none')
def evaluate_loss(data_iter, net, devices):
    """Mean per-example loss of ``net`` over ``data_iter``.

    Returns a 0-dim tensor (the caller formats it and calls ``.detach()``).
    """
    l_sum, n = 0.0, 0
    # FIX: evaluation must not build autograd graphs; without no_grad every
    # validation batch kept its gradient graph alive, wasting GPU memory.
    with torch.no_grad():
        for features, labels in data_iter:
            features, labels = features.to(devices[0]), labels.to(devices[0])
            outputs = net(features)
            l = loss(outputs, labels)
            l_sum += l.sum()
            n += labels.numel()
    return l_sum / n
训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,
          lr_decay):
    """Fine-tune ``net``: SGD on the still-trainable parameters, step LR decay."""
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    # Hand the optimizer only the parameters that require gradients (the new
    # head); the frozen backbone is excluded and never updated.
    trainer = torch.optim.SGD(
        (param for param in net.parameters() if param.requires_grad), lr=lr,
        momentum=0.9, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)
    num_batches, timer = len(train_iter), d2l.Timer()
    legend = ['train loss']
    if valid_iter is not None:
        legend.append('valid loss')
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=legend)
    for epoch in range(num_epochs):
        metric = d2l.Accumulator(2)  # running sums of: loss, #examples
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            features, labels = features.to(devices[0]), labels.to(devices[0])
            trainer.zero_grad()
            output = net(features)
            l = loss(output, labels).sum()
            l.backward()
            trainer.step()
            metric.add(l, labels.shape[0])
            timer.stop()
            # Refresh the plot five times per epoch and at the final batch.
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[1], None))
        measures = f'train loss {metric[0] / metric[1]:.3f}'
        if valid_iter is not None:
            valid_loss = evaluate_loss(valid_iter, net, devices)
            animator.add(epoch + 1, (None, valid_loss.detach()))
        scheduler.step()  # decay the learning rate once per epoch
    if valid_iter is not None:
        measures += f', valid loss {valid_loss:.3f}'
    print(measures + f'\n{metric[1] * num_epochs / timer.sum():.1f}'
          f' examples/sec on {str(devices)}')
训练和验证模型
# Hyperparameters: 10 epochs, lr 1e-4, weight decay 1e-4;
# decay the learning rate by 0.9 every 2 epochs.
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 10, 1e-4, 1e-4
lr_period, lr_decay, net = 2, 0.9, get_net(devices)
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,
      lr_decay)
对测试集分类
net = get_net(devices)
# Retrain on the combined train+valid split before producing the submission.
train(net, train_valid_iter, None, num_epochs, lr, wd, devices, lr_period,
      lr_decay)
preds = []
# Inference: eval mode (batch-norm uses running stats) and no autograd.
net.eval()
with torch.no_grad():
    for data, label in test_iter:
        # FIX: softmax must normalise over the class dimension (dim=1).
        # The original used dim=0, which normalised each class ACROSS the
        # batch, yielding invalid per-image probability rows.
        output = torch.nn.functional.softmax(net(data.to(devices[0])), dim=1)
        preds.extend(output.cpu().detach().numpy())
# File names sorted lexicographically match ImageFolder's iteration order.
ids = sorted(
    os.listdir(os.path.join(data_dir, 'train_valid_test', 'test', 'unknown')))
with open('submission.csv', 'w') as f:
    # Header: id plus one probability column per breed class.
    f.write('id,' + ','.join(train_valid_ds.classes) + '\n')
    for i, output in zip(ids, preds):
        f.write(
            i.split('.')[0] + ',' + ','.join([str(num)
                                              for num in output]) + '\n')