Python/PyTorch: Binary Image Classification on a Custom Dataset (with Detailed Data-Processing Code)

This is a write-up of a process I have personally verified: building my own dataset and running binary classification on it. I'm happy that I can feel my own progress. Without further ado, let's get started.

1. Data Processing

My dataset consists of two folders of images. The original file names were long and unwieldy, so I batch-renamed them with Python.

import os

file_path = "/Users/***/Tooth-Detection/data/"  # original directory
new_path = "/Users/***/Tooth-Detection/data/"   # target directory

def rename(path):
    i = 0
    filelist = os.listdir(path)
    for files in filelist[:500]:   # take the first 500 files in the folder
        oldDirpath = os.path.join(path, files)
        # if os.path.isdir(oldDirpath):  # recurse if the folder contains subfolders
        #     rename(oldDirpath)
        filename = os.path.splitext(files)[0]
        filetype = os.path.splitext(files)[1]
        # note: the 'teeth' subfolder must already exist
        newDirPath = os.path.join(new_path, 'teeth', 'Y' + str(i) + filetype)
        # print(filename, filetype)
        os.rename(oldDirpath, newDirPath)
        i += 1
    print("Total number of files:", i)

rename(file_path)

Result:
[Screenshot 1: files renamed to Y0, Y1, ...]
Then I noticed that the original files had inconsistent extensions, and my script only changed the base names, not the extensions. So I wrote another script to normalize the suffixes:

import os

file_path = "/Users/***/Tooth-Detection/data/teeth/"
filelist = os.listdir(file_path)
for files in filelist:
    portion = os.path.splitext(files)
    if portion[1] == '.JPG':
        newname = portion[0] + '.jpg'
        print(newname)
        os.rename(files, newname)
print("finished!")

This raised the following error:
[Screenshot: error traceback]
Searching Stack Overflow revealed the cause: os.listdir returns a list of file names without the path, so when renaming we must supply the full path!
Fix the second-to-last line of the code above as follows:

os.rename(os.path.join(file_path, files), os.path.join(file_path, newname))
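Putting it together, a corrected version of the extension-normalizing script looks like this (a minimal sketch under the same paths as above):

import os

file_path = "/Users/***/Tooth-Detection/data/teeth/"
for files in os.listdir(file_path):
    portion = os.path.splitext(files)
    if portion[1] == '.JPG':
        newname = portion[0] + '.jpg'
        # join with the directory: os.listdir only returns bare file names
        os.rename(os.path.join(file_path, files), os.path.join(file_path, newname))
print("finished!")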

After the fix, the result looks like this:
[Screenshot 2: extensions normalized]
The first step of the long march is done!

2. Dataset Preparation: Labeling

My goal is binary classification: files with teeth are named Y + index, and files without teeth are named N + index.
At the moment there are 210-odd images without teeth and 500 images with teeth.
Now I build the training, validation, and test sets (a minimal folder-split sketch follows at the end of this section). Since my images carry no labels, I have to label them myself:
I wrote a script that walks the file names in a folder and writes each name together with its label into a .txt file:

import os

def generate(dir, label):
    files = os.listdir(os.path.join(dir, 'teeth'))
    with open(os.path.join(dir, 'val.txt'), 'a+') as f:
        for file in files:
            filename = os.path.splitext(file)[0]
            filetype = os.path.splitext(file)[1]   # splitext, not split: we want the extension
            print(filename, filetype)
            if filetype == '.txt':
                continue
            name = '/teeth' + '/' + file + ' ' + str(int(label)) + '\n'
            f.write(name)
    print("finished!")

img_path = '/Users/***/Tooth-Detection/data/val/'

if __name__ == '__main__':
    i = 1
    generate(img_path, i)

The result:
[Screenshot 3: contents of val.txt]
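For reference, here is a minimal sketch of how the files can be distributed into train/val/test folders; the 8:1:1 ratio and the directory layout are my own assumptions, so adjust them to your setup:

import os
import random
import shutil

def split_dataset(src_dir, dst_root, ratios=(0.8, 0.1, 0.1)):
    # randomly split the files in src_dir into train/val/test folders under dst_root
    files = [f for f in os.listdir(src_dir) if not f.startswith('.')]
    random.shuffle(files)
    n_train = int(len(files) * ratios[0])
    n_val = int(len(files) * ratios[1])
    splits = {'train': files[:n_train],
              'val': files[n_train:n_train + n_val],
              'test': files[n_train + n_val:]}
    for split, names in splits.items():
        out_dir = os.path.join(dst_root, split, 'teeth')
        os.makedirs(out_dir, exist_ok=True)
        for name in names:
            shutil.copy(os.path.join(src_dir, name), os.path.join(out_dir, name))

split_dataset('/Users/***/Tooth-Detection/data/teeth', '/Users/***/Tooth-Detection/data')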

3. Dataset Loading

Since I'm using PyTorch, the data has to be served through a DataLoader. You can either subclass Dataset yourself or use ImageFolder; the latter is simpler.

Reference: training a gender binary-classification network
That author explains the process in great detail, so I won't repeat it here. A few errors came up along the way that I had to search for and resolve myself.
One I remember: list indices must be integers, not str, so the labels had to be changed from the strings 'label' and 'image' to 1 and 0.
Another was a dimension mismatch: four dimensions were expected but only two were given. When that happens, check whether you have swapped image and label.
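If you prefer the custom-Dataset route over ImageFolder, here is a minimal sketch that reads the val.txt file generated above; the class name and the path handling are my own assumptions:

import os
from PIL import Image
from torch.utils.data import Dataset

class TeethDataset(Dataset):
    # reads lines of the form '/teeth/Y0.jpg 1' produced by the labeling script
    def __init__(self, root, txt_file, transform=None):
        self.root = root
        self.transform = transform
        self.samples = []
        with open(txt_file) as f:
            for line in f:
                path, label = line.strip().rsplit(' ', 1)
                self.samples.append((path.lstrip('/'), int(label)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        image = Image.open(os.path.join(self.root, path)).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label   # image first, then label: the order expected below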

#%%

# Training set
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline

data_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),

])

train_dataset = datasets.ImageFolder(root='/Users/***/Tooth-Detection/data/train',transform=data_transform)
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=4,
                              shuffle=True,
                              num_workers=4)


def show_batch_images(sample_batch):
    labels_batch = sample_batch[1]
    images_batch = sample_batch[0]

    for i in range(len(labels_batch)):   # the last batch may hold fewer than 4 images
        label_ = labels_batch[i].item()
        image_ = np.transpose(images_batch[i], (1, 2, 0))
        ax = plt.subplot(1, 4, i + 1)
        ax.imshow(image_)
        ax.set_title(str(label_))
        ax.axis('off')
        plt.pause(0.01)


plt.figure()
for i_batch, sample_batch in enumerate(train_dataloader):
    show_batch_images(sample_batch)

    plt.show()
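Note that ImageFolder infers the integer labels from the subfolder names under root, not from the .txt files above. To see which folder maps to which label, print the mapping (the folder names in the output below are just an illustration):

print(train_dataset.class_to_idx)   # e.g. {'noteeth': 0, 'teeth': 1}, depending on your folders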


#%%

# Validation set
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline
# https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

# data_transform = transforms.Compose([
#     transforms.RandomResizedCrop(224),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                          std=[0.229, 0.224, 0.225])
# ])

data_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),

])

val_dataset = datasets.ImageFolder(root='/Users/***/Tooth-Detection/data/val',transform=data_transform)
val_dataloader = DataLoader(dataset=val_dataset,
                              batch_size=4,
                              shuffle=True,
                              num_workers=4)


def show_batch_images(sample_batch):
    labels_batch = sample_batch[1]
    images_batch = sample_batch[0]

    for i in range(len(labels_batch)):   # the last batch may hold fewer than 4 images
        label_ = labels_batch[i].item()
        image_ = np.transpose(images_batch[i], (1, 2, 0))
        ax = plt.subplot(1, 4, i + 1)
        ax.imshow(image_)
        ax.set_title(str(label_))
        ax.axis('off')
        plt.pause(0.01)


plt.figure()
for i_batch, sample_batch in enumerate(val_dataloader):
    show_batch_images(sample_batch)

    plt.show()


#%%

# Test set
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline
# https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

# data_transform = transforms.Compose([
#     transforms.RandomResizedCrop(224),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                          std=[0.229, 0.224, 0.225])
# ])

data_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),

])

test_dataset = datasets.ImageFolder(root='/Users/***/Tooth-Detection/data/test',transform=data_transform)
test_dataloader = DataLoader(dataset=test_dataset,
                              batch_size=4,
                              shuffle=True,
                              num_workers=4)


def show_batch_images(sample_batch):
    labels_batch = sample_batch[1]
    images_batch = sample_batch[0]

    for i in range(len(labels_batch)):   # the last batch may hold fewer than 4 images
        label_ = labels_batch[i].item()
        image_ = np.transpose(images_batch[i], (1, 2, 0))
        ax = plt.subplot(1, 4, i + 1)
        ax.imshow(image_)
        ax.set_title(str(label_))
        ax.axis('off')
        plt.pause(0.01)


plt.figure()
for i_batch, sample_batch in enumerate(test_dataloader):
    show_batch_images(sample_batch)

    plt.show()

4. Model Training and Testing

#%%

import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import copy
import matplotlib.pyplot as plt
import os

#%%

print('train_dataset: {}'.format(len(train_dataset)))
print('val_dataset: {}'.format(len(val_dataset)))
print('test_dataset: {}'.format(len(test_dataset)))

#%%

batch_size = 8
data_root = '/Users/***/Tooth-Detection/data/train'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')



#%%

# Load a pretrained SqueezeNet and adapt its classifier head for 2 classes
model = models.squeezenet1_1(pretrained=True)
model.classifier[1] = nn.Conv2d(in_channels=512, out_channels=2, kernel_size=(1, 1), stride=(1, 1))
model.num_classes = 2
print(model)
model.to(device)

# Optimizer, loss function and learning-rate scheduler
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
loss_fc = nn.CrossEntropyLoss()
scheduler = optim.lr_scheduler.StepLR(optimizer, 10, 0.1)  # decay the LR by a factor of 0.1 every 10 epochs
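As a quick sanity check on the schedule, you can preview the decay on throwaway objects so the real optimizer state is untouched (a sketch, not part of the training run):

# preview: lr is 0.001 for epochs 0-9, 1e-4 for 10-19, 1e-5 afterwards
_param = nn.Parameter(torch.zeros(1))
_opt = optim.SGD([_param], lr=0.001, momentum=0.9)
_sched = optim.lr_scheduler.StepLR(_opt, 10, 0.1)
for epoch in range(30):
    if epoch % 10 == 0:
        print(epoch, _opt.param_groups[0]['lr'])
    _opt.step()       # step the optimizer before the scheduler (PyTorch >= 1.1 convention)
    _sched.step()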


#%%

## Training
num_epoch = 2
# Directory where the training logs are saved
logfile_dir = '/Users/***/Tooth-Detection/log/'

acc_best_wts = model.state_dict()
best_acc = 0
iter_count = 0

for epoch in range(num_epoch):
    train_loss = 0.0
    train_acc = 0.0
    train_correct = 0
    train_total = 0

    val_loss = 0.0
    val_acc = 0.0
    val_correct = 0
    val_total = 0

    for i, sample_batch in enumerate(train_dataloader):
        # print(sample_batch)
        inputs = sample_batch[0].to(device)
        labels = sample_batch[1].to(device)
        #print(inputs,labels)
        

        # put the model in training mode
        model.train()

        # forward
        outputs = model(inputs)

        # print(labels)
        # loss
        loss = loss_fc(outputs, labels)

        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # statistics
        train_loss += loss.item()
        train_correct += (torch.max(outputs, 1)[1] == labels).sum().item()
        train_total += labels.size(0)

        print('iter:{}'.format(i))
        
        if i % 10 == 9:
            # every 10 iterations, evaluate on the whole validation set
            model.eval()
            with torch.no_grad():   # no gradients needed during validation
                for sample_batch in val_dataloader:
                    inputs = sample_batch[0].to(device)
                    labels = sample_batch[1].to(device)

                    outputs = model(inputs)
                    loss = loss_fc(outputs, labels)
                    _, prediction = torch.max(outputs, 1)
                    val_correct += ((labels == prediction).sum()).item()
                    val_total += inputs.size(0)
                    val_loss += loss.item()
            val_acc = val_correct / val_total
            print('[{},{}] train_loss = {:.5f} train_acc = {:.5f} val_loss = {:.5f} val_acc = {:.5f}'.format(
                epoch + 1, i + 1, train_loss / 100,train_correct / train_total, val_loss/len(val_dataloader),
                val_correct / val_total))
            if val_acc > best_acc:
                best_acc = val_acc
                acc_best_wts = copy.deepcopy(model.state_dict())

            with open(logfile_dir + 'train_loss.txt', 'a') as f:
                f.write(str(train_loss / 10) + '\n')
            with open(logfile_dir +'train_acc.txt','a') as f:
                f.write(str(train_correct / train_total) + '\n')
            with open(logfile_dir +'val_loss.txt','a') as f:
                f.write(str(val_loss/len(val_dataloader)) + '\n')
            with open(logfile_dir +'val_acc.txt','a') as f:
                f.write(str(val_correct / val_total) + '\n')

            iter_count += 200
            
            train_loss = 0.0
            train_total = 0
            train_correct = 0
            val_correct = 0
            val_total = 0
            val_loss = 0

    # step the LR scheduler once per epoch, after the optimizer updates
    scheduler.step()


print('Train finish!')
# Save the best weights
model_file = '/Users/***/Tooth-Detection/models/'
torch.save(acc_best_wts, model_file + '/model_squeezenet_utk_face_1.pth')
print('Model save ok!')
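The training loop appends one value per logging step to plain .txt files; here is a minimal sketch for plotting them afterwards (the file names match the ones written above, everything else is an assumption):

import matplotlib.pyplot as plt

def read_log(path):
    # each log file holds one float per line, appended every 10 iterations
    with open(path) as f:
        return [float(line) for line in f if line.strip()]

plt.plot(read_log(logfile_dir + 'train_loss.txt'), label='train loss')
plt.plot(read_log(logfile_dir + 'val_loss.txt'), label='val loss')
plt.xlabel('logging step (x10 iterations)')
plt.ylabel('loss')
plt.legend()
plt.show()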

#%%

# Testing

model = models.squeezenet1_1(pretrained=True)
model.classifier[1] = nn.Conv2d(in_channels=512, out_channels=2, kernel_size=(1, 1), stride=(1, 1))
model.num_classes = 2
model.load_state_dict(torch.load(model_file+'/model_squeezenet_utk_face_1.pth', map_location='cpu'))
print(model)
model.eval()


test_dataloader = DataLoader(dataset=test_dataset, batch_size=4, shuffle=False, num_workers=4)

correct = 0
total = 0
acc = 0.0
with torch.no_grad():   # inference only, no gradients needed
    for i, sample in enumerate(test_dataloader):
        inputs, labels = sample[0], sample[1]

        outputs = model(inputs)

        _, prediction = torch.max(outputs, 1)
        correct += (labels == prediction).sum().item()
        total += labels.size(0)

acc = correct / total
print('test finish, total:{}, correct:{}, acc:{:.3f}'.format(total, correct, acc))
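To try the saved model on a single image, a sketch like the following works; the example file path is hypothetical, and for real inference you would use a transform without the random flip:

from PIL import Image

img = Image.open('/Users/***/Tooth-Detection/data/test/teeth/Y0.jpg').convert('RGB')
x = data_transform(img).unsqueeze(0)   # add the batch dimension: (1, 3, 224, 224)
with torch.no_grad():
    out = model(x)
    pred = torch.max(out, 1)[1].item()
print('predicted class index:', pred)  # map back to a name via test_dataset.class_to_idx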
#%%

This was my first time building a dataset myself and completing the whole pipeline from start to finish. My patience improved, and so did my ability to solve problems. Keep it up!
