本文主要包含：官方数据集的导入、自定义数据集、自定义网络结构、模型训练，以及训练后模型的使用。
import torch
import torchvision
import torchsummary
import os
import numpy as np
import matplotlib.pyplot as plt
# Mini-batch size used by every DataLoader in this article.
BATCH_SIZE = 64

# Number of pixel rows / columns in one MNIST image.
IMAGE_ROW, IMAGE_COL = 28, 28

# Directory that holds the raw idx files read by the custom dataset.
DATA_SOURCE_DIR = "../datasets/MNIST/raw/"

# Image preprocessing pipeline. Normalize((0.,), (1.,)) subtracts 0 and divides
# by 1, so it leaves values unchanged.
# NOTE(review): swap in MNIST's usual stats (0.1307,), (0.3081,) if real
# normalisation is wanted.
TRANSFORM = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.,), (1.,)),
    ]
)

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
print(DEVICE)
数据集说明：
参考网址：https://www.cnblogs.com/xianhan/p/9145966.html
数据集网址：http://yann.lecun.com/exdb/mnist/
train-labels-idx1-ubyte
[offset] [type] [value] [description]
0000 32 bit integer 0x00000801(2049) magic number (MSB first)
0004 32 bit integer 60000 number of items
0008 unsigned byte ?? label
0009 unsigned byte ?? label
........
xxxx unsigned byte ?? label
train-images-idx3-ubyte
[offset] [type] [value] [description]
0000 32 bit integer 0x00000803(2051) magic number
0004 32 bit integer 60000 number of images
0008 32 bit integer 28 number of rows
0012 32 bit integer 28 number of columns
0016 unsigned byte ?? pixel
0017 unsigned byte ?? pixel
........
xxxx unsigned byte ?? pixel
t10k-labels-idx1-ubyte
[offset] [type] [value] [description]
0000 32 bit integer 0x00000801(2049) magic number (MSB first)
0004 32 bit integer 10000 number of items
0008 unsigned byte ?? label
0009 unsigned byte ?? label
........
xxxx unsigned byte ?? label
t10k-images-idx3-ubyte
[offset] [type] [value] [description]
0000 32 bit integer 0x00000803(2051) magic number
0004 32 bit integer 10000 number of images
0008 32 bit integer 28 number of rows
0012 32 bit integer 28 number of columns
0016 unsigned byte ?? pixel
0017 unsigned byte ?? pixel
........
xxxx unsigned byte ?? pixel
# Official torchvision MNIST splits, downloaded into ../datasets if missing.
TRAIN_DATASETS = torchvision.datasets.MNIST(root="../datasets", train=True, download=True, transform=TRANSFORM)
TRAIN_LOADER = torch.utils.data.DataLoader(TRAIN_DATASETS, shuffle=True, batch_size=BATCH_SIZE)
TEST_DATASETS = torchvision.datasets.MNIST(root="../datasets", train=False, download=True, transform=TRANSFORM)
# The test loader must iterate the held-out test split, not the training split.
TEST_LOADER = torch.utils.data.DataLoader(TEST_DATASETS, shuffle=False, batch_size=BATCH_SIZE)
查看图片
# Preview the first training sample: draw slice 0 of the image array and use
# the sample's label as the figure title.
img, label = TRAIN_DATASETS[0]
img = img.numpy()
plt.imshow(img[0])
plt.title(label)
torch 官方文档（英文）：https://pytorch.org/docs/stable/data.html
torch.utils.data.Dataset 源码：https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#Dataset
自定义 Dataset 的基本结构
class DataSets(torch.utils.data.Dataset):
    """Skeleton of a custom map-style dataset.

    A concrete subclass fills in ``__init__`` (load or index the raw data),
    ``__getitem__`` (return the sample at position ``idx``) and ``__len__``
    (return the number of samples).
    """

    def __init__(self):
        super(DataSets, self).__init__()
        # Load or index the raw data here.

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        # Fail loudly instead of silently returning None from an unfinished template.
        raise NotImplementedError("DataSets.__getitem__ must be implemented")

    def __len__(self):
        """Return the number of samples in the dataset."""
        raise NotImplementedError("DataSets.__len__ must be implemented")


# struct.unpack_from(fmt, buf, offset)
fmt：解析格式字符串，由字节序前缀 '>'（大端）或 '<'（小端）、可选的重复次数 str(number)，以及类型码组成，例如 'B'（无符号字节）、'b'（有符号字节）、'I'（无符号 32 位整数）、'i'（有符号 32 位整数）
buf：待解析的字节缓冲区（这里是读入的整个文件内容）
offset：开始解析的字节偏移量
import struct
def decode_idx1_ubyte(idx1_ubyte_file):
    """Decode an MNIST idx1-ubyte label file.

    Layout: a big-endian uint32 magic number, a big-endian uint32 item count,
    then one unsigned byte per label starting at byte offset 8.

    Args:
        idx1_ubyte_file: path of the label file.

    Returns:
        A list with one entry per label. Each entry is the 1-tuple returned by
        ``struct.unpack_from(">B", ...)``, e.g. ``(5,)``.
    """
    with open(idx1_ubyte_file, 'rb') as handle:
        raw = handle.read()

    # 8-byte header: magic number followed by the number of labels.
    header_size = 8
    magic, count = struct.unpack_from(">II", raw, 0)
    print("magic number:0x{:0>8x}({})\tlabel number:{}".format(magic, magic, count))

    # Labels are stored back to back, one byte each, right after the header.
    return [struct.unpack_from(">B", raw, header_size + i) for i in range(count)]
def decode_idx3_ubyte(idx3_ubyte_file):
    """Decode an MNIST idx3-ubyte image file into a NumPy array.

    Layout: four big-endian uint32 values (magic number, image count, rows,
    columns) followed by ``count * rows * cols`` unsigned pixel bytes, stored
    image after image in row-major order starting at byte offset 16.

    Args:
        idx3_ubyte_file: path of the image file.

    Returns:
        A writable ``uint8`` array of shape ``(count, rows, cols)``; for a file
        with zero images the shape is ``(0, rows, cols)``.

    Raises:
        struct.error: if the file is shorter than the 16-byte header.
        ValueError: if the pixel data is shorter than the header claims.
    """
    with open(idx3_ubyte_file, 'rb') as fp:
        bin_data = fp.read()
    # Parse the header.
    magic_number, image_number, rows, cols = struct.unpack_from(">IIII", bin_data, 0)
    offset = 16  # pixel data starts right after the header
    print("magic number:0x{:0>8x}({})\t image number:{}".format(magic_number, magic_number, image_number))
    print("rows:{}\t columns:{}".format(rows, cols))
    # Read all pixels in one vectorised call instead of one struct.unpack per
    # image. frombuffer returns a read-only view of the bytes object, so copy
    # it to hand callers (e.g. ToTensor) a writable array.
    pixels = np.frombuffer(bin_data, dtype=np.uint8, count=image_number * rows * cols, offset=offset)
    return pixels.reshape((image_number, rows, cols)).copy()
class MyMNISTDataSets(torch.utils.data.Dataset):
    """MNIST dataset read directly from the raw idx files in ``root``.

    Args:
        root: directory containing ``train-*``/``t10k-*`` ``images-idx3-ubyte``
            and ``labels-idx1-ubyte`` files.
        train: load the ``train`` files when True, the ``t10k`` files otherwise.
        transform: optional callable applied to each uint8 image array; when
            None the array is converted with ``torch.as_tensor``.
    """

    def __init__(self, root=DATA_SOURCE_DIR, train=True, transform=None):
        super(MyMNISTDataSets, self).__init__()
        self.root = root
        self.transform = transform
        self.train = train
        prefix = "train" if self.train else "t10k"
        image_path = os.path.join(self.root, prefix + "-images-idx3-ubyte")
        label_path = os.path.join(self.root, prefix + "-labels-idx1-ubyte")
        # Images are decoded before labels, keeping the printed header order.
        self.data = decode_idx3_ubyte(image_path)
        self.targets = decode_idx1_ubyte(label_path)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is a float32 tensor. The label is an int64 tensor built from
        the decoded label entry (decode_idx1_ubyte yields 1-tuples, so its
        shape is ``(1,)``).
        """
        data, label = self.data[idx], self.targets[idx]
        label = torch.as_tensor(label, dtype=torch.int64)
        if self.transform is not None:
            data = self.transform(data)
        # Without a transform ``data`` is still a NumPy array (which has no
        # ``.type``); as_tensor converts it and passes an existing tensor through.
        data = torch.as_tensor(data).type(torch.FloatTensor)
        return data, label

    def __len__(self):
        """Return the number of decoded images."""
        return len(self.data)
# Datasets and loaders backed by the hand-written idx-file reader.
TRAIN_DATASETS = MyMNISTDataSets(train=True, transform=TRANSFORM)
TRAIN_LOADER = torch.utils.data.DataLoader(TRAIN_DATASETS, shuffle=True, batch_size=BATCH_SIZE)
TEST_DATASETS = MyMNISTDataSets(train=False, transform=TRANSFORM)
# Evaluate on the t10k test split rather than re-using the training data.
TEST_LOADER = torch.utils.data.DataLoader(TEST_DATASETS, shuffle=False, batch_size=BATCH_SIZE)
查看图片
# Preview the first sample of the custom training set.
img, label = TRAIN_DATASETS[0]
img = img.numpy()
# Here label is a one-element int64 tensor; int() makes the title "5" instead
# of "tensor([5])", and it also accepts a plain int label.
plt.title(int(label))
plt.imshow(img[0])
class LinearNet(torch.nn.Module):
    """Fully connected classifier: (IMAGE_ROW*IMAGE_COL) -> 512 -> 256 -> 128 -> 64 -> 10.

    Each hidden layer is followed by ReLU; the last layer's output is returned
    without an activation (raw scores for the 10 classes).
    """

    def __init__(self):
        super(LinearNet, self).__init__()
        # Derive the input width from the shared image-size constants so it
        # always matches the flattening done in forward().
        self.l1 = torch.nn.Linear(IMAGE_ROW * IMAGE_COL, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        # Flatten every image into one row of IMAGE_ROW * IMAGE_COL values.
        x = x.view(-1, IMAGE_ROW * IMAGE_COL)
        for layer in (self.l1, self.l2, self.l3, self.l4):
            x = torch.nn.functional.relu(layer(x))
        return self.l5(x)
# Build the MLP, move it to DEVICE (Module.to returns the module itself) and
# print a per-layer summary for a 1 x IMAGE_ROW x IMAGE_COL input.
model = LinearNet().to(DEVICE)
torchsummary.summary(model, (1, IMAGE_ROW, IMAGE_COL))
class CNNNet(torch.nn.Module):
    """Small convolutional classifier.

    Two stages of (5x5 conv -> 2x2 max-pool -> ReLU) followed by one linear
    layer. For a 1x28x28 input the second stage yields a 20x4x4 map, i.e. the
    320 features ``fc`` expects.
    """

    def __init__(self):
        super(CNNNet, self).__init__()
        # Keep this registration order: it fixes parameter initialisation
        # order and the state_dict layout.
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = torch.nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.pooling = torch.nn.MaxPool2d(2)
        self.fc = torch.nn.Linear(in_features=320, out_features=10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        n = x.size(0)
        for conv in (self.conv1, self.conv2):
            x = self.relu(self.pooling(conv(x)))
        # One flat feature row per sample.
        return self.fc(x.view(n, -1))
# Build the CNN, move it to DEVICE (Module.to returns the module itself) and
# print a per-layer summary for a 1 x IMAGE_ROW x IMAGE_COL input.
model = CNNNet().to(DEVICE)
torchsummary.summary(model, (1, IMAGE_ROW, IMAGE_COL))
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
import sys

# Train for two epochs; after each epoch measure accuracy on TEST_LOADER.
# Progress is redrawn in place ("\r") every 50 batches and on the final batch,
# so the last line shown always covers the whole loader.
for epoch in range(2):
    model.train()
    running_loss = 0.0
    num_train_batches = len(TRAIN_LOADER)
    for batch_idx, data in enumerate(TRAIN_LOADER):
        inputs, target = data
        inputs, target = inputs.to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(inputs)
        # Custom-dataset labels arrive as (batch, 1). reshape(-1) flattens them
        # to (batch,), and unlike squeeze() it keeps a size-1 batch as (1,)
        # instead of collapsing it to a 0-d tensor that the loss rejects.
        target = target.reshape(-1)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 50 == 49 or batch_idx + 1 == num_train_batches:
            sys.stdout.write("epoch:{:2d}\t {}\t:{:.2%}\t loss:{:.2f}\t\r".format(epoch, "train", (batch_idx + 1) / num_train_batches, running_loss / (batch_idx + 1)))
            sys.stdout.flush()
    sys.stdout.write('\n')
    sys.stdout.flush()

    model.eval()
    correct = 0
    total = 0
    num_test_batches = len(TEST_LOADER)
    with torch.no_grad():
        for batch_idx, data in enumerate(TEST_LOADER):
            inputs, target = data
            inputs, target = inputs.to(DEVICE), target.to(DEVICE)
            outputs = model(inputs)
            target = target.reshape(-1)  # same flattening as in training
            predict = outputs.argmax(dim=1)  # index of the highest score per sample
            total += target.size(0)
            correct += (predict == target).sum().item()
            if batch_idx % 50 == 49 or batch_idx + 1 == num_test_batches:
                sys.stdout.write("epoch:{:2d}\t {}\t:{:.2%}\t accuracy:{:.2%}\t\r".format(epoch, "test", (batch_idx + 1) / num_test_batches, correct / total))
                sys.stdout.flush()
    sys.stdout.write('\n')
    sys.stdout.flush()
# Classify one randomly chosen test image and show it titled with the
# predicted digit.
model.eval()
with torch.no_grad():
    choice = np.random.randint(0, len(TEST_DATASETS))
    # The label is not needed on DEVICE (moving it failed for plain-int labels).
    inputs, target = TEST_DATASETS[choice]
    # Reshape into a batch of one single-channel image.
    inputs = inputs.reshape((1, 1, IMAGE_ROW, IMAGE_COL)).to(DEVICE)
    outputs = model(inputs)
    print(outputs)
    predict = outputs.argmax(dim=1)
    # .item() turns the one-element (possibly CUDA) tensor into a plain int so
    # the title reads e.g. "7" rather than the tensor's repr.
    plt.title(predict.item())
    plt.imshow(inputs.to("cpu").numpy()[0, 0])