from __future__ import print_function
from PIL import Image #从文件加载图像(python Image Library)
import os #文件操作
import sys #文件操作
import numpy as np #与torch混合使用搭建数据传输
import argparse #处理命令行参数的库
import torch.utils.data as data ##创建数据集
# Fruit dataset: serves (image, label) samples from pre-built .npy arrays.
class Fruit(data.Dataset):
    """Map-style dataset over the Fruits-360 arrays saved as ``.npy`` files.

    Args:
        root_dir: folder holding ``train_data.npy`` / ``train_labels.npy`` and
            ``validation_data.npy`` / ``validation_labels.npy``.
        train: when truthy, load the training pair; otherwise the validation pair.
        transform: optional callable applied to each image before it is returned.
    """

    def __init__(self, root_dir, train=True, transform=None):
        self.root_dir = os.path.abspath(root_dir)
        self.transform = transform
        self.train = train
        if self.train:
            data_file, label_file = "train_data.npy", "train_labels.npy"
        else:
            data_file, label_file = "validation_data.npy", "validation_labels.npy"
        stored = np.load(os.path.join(self.root_dir, data_file))
        self.labels = np.load(os.path.join(self.root_dir, label_file))
        # Move axis 1 to the end. Stored arrays are presumably (N, C, H, W), so each
        # item becomes channel-last (H, W, C), the ndarray layout ToTensor takes.
        self.data = stored.transpose((0, 2, 3, 1))

    def __getitem__(self, index):
        """Return ``(image, label)`` at *index*, passing the image through ``transform`` if set."""
        sample, label = self.data[index], self.labels[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

    def __len__(self):
        """Number of samples (length of the image array's first axis)."""
        return len(self.data)
## 引入函数库
import argparse
import os
import sys
import numpy as np
import cv2
import glob
print ("INFO: all the modules are imported.")
# Parse the command line. --dataset must point at the dataset root folder, the
# one that contains the "Training" and "Test" class sub-folders read below.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset",
    help='Path to the dataset folder',
    required=True,
    type=str,
)
args = parser.parse_args()
# Class names (sub-folder names under the dataset's split folders) used by this
# project: the first 64 of the 94 Fruits-360 classes. A class's label is its
# position in this list.
# Classes deliberately left out: PearAbate, PearKaiser, PearMonster,
# PearWilliams, Pepino, Physalis, PhysaliswithHusk, Pineapple, PineappleMini,
# PitahayaRed, Plum, Plum2, Plum3, Pomegranate, PomeloSweetie, Quince,
# Rambutan, Raspberry, Salak, Strawberry, StrawberryWedge, Tamarillo, Tangelo,
# Tomato1, Tomato2, Tomato3, Tomato4, TomatoCherryRed, TomatoMaroon, Walnut.
fruit_names = """
    AppleBraeburn AppleGolden1 AppleGolden2 AppleGolden3 AppleGrannySmith
    AppleRed1 AppleRed2 AppleRed3 AppleRedDelicious AppleRedYellow1
    AppleRedYellow2 Apricot Avocado Avocadoripe Banana BananaLadyFinger
    BananaRed Cactusfruit Cantaloupe1 Cantaloupe2 Carambula Cherry1 Cherry2
    CherryRainier CherryWaxBlack CherryWaxRed CherryWaxYellow Chestnut
    Clementine Cocos Dates Granadilla GrapeBlue GrapefruitPink GrapefruitWhite
    GrapePink GrapeWhite GrapeWhite2 GrapeWhite3 GrapeWhite4 Guava Hazelnut
    Huckleberry Kaki Kiwi Kumquats Lemon LemonMeyer Limes Lychee Mandarine
    Mango Mangostan Maracuja MelonPieldeSapo Mulberry Nectarine Orange Papaya
    PassionFruit Peach Peach2 PeachFlat Pear
""".split()
image_path = args.dataset
print ("INFO: Training image path is : {}".format(image_path))


def _load_split(root, split, names):
    """Load every image under ``<root>/<split>/<name>/`` for each class name.

    Args:
        root: dataset root folder (the value of --dataset).
        split: sub-folder to read, e.g. "Training" or "Test".
        names: class folder names; a class's label is its index in *names*.

    Returns:
        ``(images, labels)`` numpy arrays; each image is stored channel-first,
        i.e. OpenCV's ``(H, W, C)`` permuted to ``(C, H, W)``.

    Directory entries that are not regular files, or that OpenCV cannot decode,
    are reported and skipped instead of aborting the run.
    """
    images = []
    labels = []
    for label, fruit in enumerate(names):
        print (fruit)
        folder_path = os.path.join(root, split, fruit)
        for file_name in os.listdir(folder_path):
            final_path = os.path.join(folder_path, file_name)
            if not os.path.isfile(final_path):
                print ("This path doeesn't exist : {}".format(final_path))
                continue
            img = cv2.imread(final_path, cv2.IMREAD_COLOR)
            if img is None:
                # cv2.imread signals an unreadable/non-image file by returning None.
                print ("WARNING: could not decode image : {}".format(final_path))
                continue
            # Permute (H, W, C) -> (C, H, W). np.reshape would only reinterpret the
            # buffer with new dimensions and scramble the pixels.
            images.append(np.transpose(img, (2, 0, 1)))
            labels.append(label)
    return np.array(images), np.array(labels)


## Creation of training data.
train_data, train_labels = _load_split(image_path, "Training", fruit_names)
print (train_data.shape)
print (train_labels.shape)
print ("OK: Training data created.")
### saving the data into a file, then reloading it as a round-trip check.
np.save('train_data.npy', train_data)
check = np.load('train_data.npy')
np.save('train_labels.npy', train_labels)
check2 = np.load('train_labels.npy')
print (check.shape)
print (check2.shape)

## Creation of validation data (read from the "Test" split).
validation_data, validation_labels = _load_split(image_path, "Test", fruit_names)
print (validation_data.shape)
print (validation_labels.shape)
print ("OK: Validation data created.")
### saving the data into a file, then reloading it as a round-trip check.
np.save('validation_data.npy', validation_data)
check = np.load('validation_data.npy')
np.save('validation_labels.npy', validation_labels)
check2 = np.load('validation_labels.npy')
print (check.shape)
print (check2.shape)
print (len(fruit_names))
##网络搭建
#定义卷积神经网络
import torch
import torch.nn as nn
import torchvision
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader #迭代器,方便多线程读取数据
import argparse
import fruit_data
import torch.nn.functional as F
from torch.autograd import Variable #Variable是最核心的变量
from sklearn.metrics import accuracy_score
# Pick the compute device: the CUDA GPU when one is available, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print (device)
# Network definition: a small CNN classifier.
# Layer summary:
#   conv1:   Conv2d(3, 64, kernel_size=5)  -> ReLU -> MaxPool2d(2)
#   conv2:   Conv2d(64, 64, kernel_size=7) -> ReLU -> MaxPool2d(3)
#   conv3:   Conv2d(64, 64, kernel_size=7) -> ReLU -> MaxPool2d(5)
#   linear1: Linear(64 -> 120) -> ReLU
#   linear2: Linear(120 -> num_classes), raw logits (no activation)
class FruitNet(nn.Module):
    """Three conv/pool stages followed by two fully connected layers.

    ``linear1`` takes 64 features, so the last pooled feature map must be 1x1.
    For a 100x100 input: 100 -conv5-> 96 -pool2-> 48 -conv7-> 42 -pool3-> 14
    -conv7-> 8 -pool5-> 1. The input size is assumed to be 100x100 (confirm
    against the dataset); other sizes change the flatten width.

    Args:
        num_classes: number of output logits (default 64, the previous fixed value).
    """

    def __init__(self, num_classes=64):
        super(FruitNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 5)  # 3 input channels, 64 output channels, 5x5 kernel
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=7, stride=1)
        self.pool2 = nn.MaxPool2d(3)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=7)
        self.pool3 = nn.MaxPool2d(5)
        self.linear1 = nn.Linear(64, 120)  # affine map y = Wx + b
        self.linear2 = nn.Linear(120, num_classes)

    def forward(self, x):
        """Map a (batch, 3, H, W) tensor to (batch, num_classes) unnormalized logits."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))
        # Flatten everything but the batch axis for the fully connected layers.
        x = x.view(x.size(0), -1)
        x = F.relu(self.linear1(x))
        # No ReLU here: CrossEntropyLoss expects raw logits (it applies log-softmax
        # itself), and clamping negative logits to 0 would cripple training.
        return self.linear2(x)
## Train the network.
def train_network(dataloader_train, epochs=5):
    """Train a fresh FruitNet with SGD and cross-entropy, then save it to ``model/``.

    Args:
        dataloader_train: iterable of ``(images, labels)`` batches.
        epochs: number of full passes over ``dataloader_train``. Defaults to 5,
            the count that was previously hard-coded.

    Returns:
        ``(losses, net)``: ``losses[i]`` is the summed batch loss of epoch ``i``;
        ``net`` is the trained model (on ``device``).
    """
    import os

    net = FruitNet().to(device)
    net.train()
    # Loss and optimizer; the learning rate sets the SGD step size.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    losses = []
    for epoch in range(epochs):
        current_loss = 0.0
        print ("Epoch : {}".format(epoch + 1))
        for images, labels in dataloader_train:
            # Variable wrapping is unnecessary since PyTorch 0.4; move and cast directly.
            x = images.to(device).float()
            y = labels.to(device).long()
            optimizer.zero_grad()
            # forward + backward + optimize
            y_pred = net(x)
            correct = y_pred.max(1)[1].eq(y).sum()
            print ("INFO: Number of correct items classified : {}".format(correct.item()))
            loss = criterion(y_pred, y)
            print ("Loss : {}".format(loss.item()))
            current_loss += loss.item()
            loss.backward()
            optimizer.step()
        losses.append(current_loss)
    ## Save the network. torch.save does not create missing parent directories.
    os.makedirs("model", exist_ok=True)
    torch.save(net.state_dict(), "model/fruit_model_state_dict.pth")
    torch.save(optimizer.state_dict(), "model/fruit_model_optimizer_dict.pth")
    print ("OK: Finished training for {} epochs".format(epochs))
    return losses, net
def test_network(net, dataloader_test):
net.eval()
criterion = nn.CrossEntropyLoss()
accuracies = []
with torch.no_grad():
for feature, label in dataloader_test:
feature = feature.to(device)
label = label.to(device)
pred = net(feature)
accuracy = accuracy_score(label.cpu().data.numpy(), pred.max(1)[1].cpu().data.numpy()) * 100
print ("Accuracy : ", accuracy)
loss = criterion(pred, label)
print ("Loss : {}".format(loss.item()))
accuracies.append(accuracy)
total = 0.0
for j in range(len(accuracies)):
total = total + accuracies[j]
avg_acc = total / len(accuracies)
print ("OK: testing done with overall accuracy is : {}".format(avg_acc))
# DataLoader builds batches. Parameters used below:
#   dataset: the Dataset to load samples from
#   batch_size: number of samples per batch
#   shuffle: if True, reshuffle the data at every epoch
#   num_workers: number of worker subprocesses (0 = load in the main process)
# sampler and drop_last keep their defaults (no custom sampler; keep the last
# incomplete batch).
def main():
    """Build train/test loaders from the .npy files in --data-dir, train, then evaluate."""
    root_dir = args.data_dir
    # ToTensor scales uint8 pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them to [-1, 1].
    data_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
    transformed_dataset = fruit_data.Fruit(root_dir, train=True, transform=data_transform)
    dataloader_train = DataLoader(transformed_dataset, batch_size=64, shuffle=True, num_workers=4)
    transformed_test_dataset = fruit_data.Fruit(root_dir, train=False, transform=data_transform)
    # Evaluation order does not affect accuracy; a fixed order keeps runs reproducible.
    dataloader_test = DataLoader(transformed_test_dataset, batch_size=64, shuffle=False, num_workers=4)
    # Peek at one training batch to report shapes and types. DataLoader iterators
    # no longer provide `.next()` in recent PyTorch; the builtin next() works everywhere.
    images, labels = next(iter(dataloader_train))
    print ("INFO: image shape is {}".format(images.shape))
    print ("INFO: Tensor type is : {}".format(images.type()))
    print ("INFO: labels shape is : {}".format(labels.shape))
    losses, net = train_network(dataloader_train)
    test_network(net, dataloader_test)
if __name__ == '__main__':
    # Command-line options for the training script.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data-dir',
        help="Dataset directory where npy files are stored",
        required=True,
        type=str,
    )
    parser.add_argument(
        '--epochs',
        help="Number of epochs",
        default=10,
        required=False,
        type=int,
    )
    # `args` and `epochs` are deliberately module-level globals: main() reads
    # args.data_dir, and other code in this module may read `epochs`.
    args = parser.parse_args()
    epochs = args.epochs
    main()
# Link: https://pan.baidu.com/s/1z-RerDtL0ehzdEeVGIHIAA