# Pytorch迁移学习版的裂缝检测(resnet50) — crack detection with PyTorch transfer learning (ResNet-50)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import sampler
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
import torchvision
import numpy as np
import os
cwd = os.getcwd()
from PIL import Image
import time
import copy
import random
import cv2
import re
import shutil
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
import skimage
import json
from tqdm import tqdm
import base64
## Define data augmentation and transforms

# Per-channel (R, G, B) mean / std consumed by transforms.Normalize below; these
# are the standard ImageNet statistics used by torchvision's pretrained models.
mean_nums = [0.485, 0.456, 0.406]
std_nums = [0.229, 0.224, 0.225]

# Common tail of both pipelines: PIL image -> float tensor -> normalised tensor.
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(mean_nums, std_nums),
]

# 'train' applies random augmentation and yields 227x227 crops; 'val' is
# deterministic (resize the shorter side to 227, then centre-crop 227x227).
chosen_transforms = dict(
    train=transforms.Compose([
        transforms.RandomResizedCrop(size=227),
        transforms.RandomRotation(degrees=10),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
    ] + _to_normalized_tensor),
    val=transforms.Compose([
        transforms.Resize(227),
        transforms.CenterCrop(227),
    ] + _to_normalized_tensor),
)

def inverse_transform(tensor, mean=None, std=None):
    """Undo ``transforms.Normalize`` on an image tensor, in place.

    Computes ``x * std + mean`` per channel so a normalised image can be
    displayed. The channel axis is the third from the end, so both a single
    image ``(C, H, W)`` and a batch ``(N, C, H, W)`` are handled.

    Args:
        tensor: float tensor with channels at dim -3; it is modified in place.
        mean: per-channel means; defaults to the module-level ``mean_nums``.
        std: per-channel standard deviations; defaults to ``std_nums``.

    Returns:
        ``tensor`` itself (same object), now de-normalised.
    """
    if mean is None:
        mean = mean_nums
    if std is None:
        std = std_nums
    # Shape (C, 1, 1) broadcasts over H and W (and over N for a batch), instead
    # of zipping over dim 0, which would treat a batch axis as channels.
    mean_t = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    return tensor.mul_(std_t).add_(mean_t)

# Create the PyTorch dataset

class Crack_DataSet(Dataset):
    """Binary crack / no-crack image dataset backed by two image folders.

    Samples come from ``<root>/Negative`` (label 0) and ``<root>/Positive``
    (label 1). ``indexes`` selects which entries of the combined file list
    (Negative files first, then Positive, each sorted by name) form this split.
    """

    def __init__(self, indexes, Type, normalize=None, root='CrackDetection'):
        """
        Args:
            indexes: positions into the combined Negative+Positive file list
                that make up this split.
            Type: split name; only 'train' and 'validation' can be loaded.
            normalize: optional callable applied to the PIL image (e.g. a
                torchvision transform pipeline).
            root: directory that contains the Negative/Positive folders.
        """
        super(Crack_DataSet, self).__init__()
        self.indexes = indexes
        self.Type = Type
        self.normalize = normalize
        self.root = root
        # Built on first access by _image_list(); listing both folders for
        # every sample would cost O(number of files) per __getitem__ call.
        self._img_list = None

    def __len__(self):
        return len(self.indexes)

    def _image_list(self):
        """Return the cached list of ``[path, label]`` pairs, building it once.

        Each folder's listing is sorted because ``os.listdir`` order is
        arbitrary, and the index -> file mapping must be deterministic.
        """
        if self._img_list is None:
            img_list = []
            for folder, label in (('Negative', 0), ('Positive', 1)):
                folder_path = os.path.join(self.root, folder)
                img_list.extend([os.path.join(folder_path, name), label]
                                for name in sorted(os.listdir(folder_path)))
            self._img_list = img_list
        return self._img_list

    def __getitem__(self, index):
        """Return ``(image as float32 numpy array, int label)`` for ``index``.

        Raises:
            NotImplementedError: for a ``Type`` other than 'train'/'validation'
                (previously such datasets silently returned None).
        """
        if self.Type not in ('train', 'validation'):
            raise NotImplementedError(
                "Crack_DataSet only supports Type 'train' or 'validation', "
                "got {!r}".format(self.Type))
        path, label = self._image_list()[self.indexes[index]]
        with Image.open(path) as img:
            # convert() loads the pixels (so the file handle can be closed) and
            # guarantees the 3 channels the 3-channel Normalize expects.
            image = img.convert('RGB')
        if self.normalize is not None:
            image = self.normalize(image)
        return np.asanyarray(image, dtype=np.float32), label

        
# Used as the DataLoader's collate_fn.
def crack_dataset_collate(batch):
    """Merge a list of ``(image, label)`` samples into one batch.

    Returns:
        ``(images, labels)``: ``images`` is ``np.array`` built from the
        per-sample image arrays (stacked on a new leading axis when shapes
        agree); ``labels`` is a plain Python list in the same order.
    """
    samples = list(batch)
    stacked = np.array([img for img, _ in samples])
    targets = [lbl for _, lbl in samples]
    return stacked, targets

# Create the training and validation datasets

# Total number of images across the Negative and Positive folders; the split
# below indexes into that combined list.
# NOTE(review): assumes the two folders hold exactly 40000 images in total —
# TODO confirm (an index past the real count raises IndexError when loaded).
NumberOfSamples = 40000
Train_ratio = 0.8
# Seeded private RNG so the train/validation split is the same on every run.
# With an unseeded shuffle, a model saved by one run would later be validated
# on images it was trained on. A dedicated Random instance also leaves the
# global `random` state untouched.
SPLIT_SEED = 0
DataIndexes = list(range(NumberOfSamples))
random.Random(SPLIT_SEED).shuffle(DataIndexes)

# Create the training dataset

# The first Train_ratio share of the shuffled indexes is the training split,
# loaded through the augmenting 'train' transform pipeline.
_train_end = int(NumberOfSamples * Train_ratio)
TrainIndexes = DataIndexes[:_train_end]
TrainDataset = Crack_DataSet(TrainIndexes, "train", normalize=chosen_transforms['train'])

# Create the validation dataset

# The remaining shuffled indexes form the validation split, loaded through the
# deterministic 'val' transform pipeline.
ValidationIndexes = DataIndexes[int(NumberOfSamples * Train_ratio):]
ValidationDataset = Crack_DataSet(ValidationIndexes, "validation", normalize=chosen_transforms['val'])

# Number of samples in each split, keyed by phase name.
dataset_sizes = {
    'train': len(TrainIndexes),
    'val': len(ValidationIndexes),
}
print(dataset_sizes)
# Printed output: {'train': 32000, 'val': 8000}
## Set code to run on device
# Use the current CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# Example output on a GPU machine: cuda
## Load pretrained model
resnet50 = models.resnet50(pretrained=True)

# Freeze the whole backbone: Module.requires_grad_(False) clears requires_grad
# on every existing parameter, so only layers created afterwards are trainable.
resnet50.requires_grad_(False)

## Change the final layer of the resnet model
# Replace the backbone's classifier with a small head that ends in 2 logits
# (one per class); its freshly created parameters require gradients.
fc_inputs = resnet50.fc.in_features
resnet50.fc = nn.Sequential(
    *[
        nn.Linear(fc_inputs, 128),
        nn.ReLU(),
        nn.Dropout(0.4),
        nn.Linear(128, 2),
    ]
)

# Move all parameters and buffers to the selected device.
resnet50 = resnet50.to(device)

# from torchsummary import summary
# print(summary(resnet50, (3, 227, 227)))

# Set up training for the pretrained model (loss, optimizer, LR schedule, data loaders)

# Define Optimizer and Loss Function
# Cross-entropy over the 2 logits produced by the classification head.
criterion = nn.CrossEntropyLoss()
# Give the optimizer only the parameters that can be trained (the new fc head,
# since the backbone is frozen). Frozen parameters never get gradients, so
# this updates exactly the same weights as passing all of them. It also states
# the intent, and stays correct if layers are unfrozen before this line.
trainable_params = [p for p in resnet50.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params)
# optimizer = optim.SGD(resnet50.fc.parameters(), lr=0.001, momentum=0.9)
# Decay LR by a factor of 0.1 every 3 epochs (scheduler.step() is called once
# per training epoch).
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
num_workers = 0
batch_size = 64
dataloaders = {}
# Training: shuffled, and a trailing partial batch is dropped.
dataloaders["train"] = DataLoader(TrainDataset, shuffle = True, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
                                    drop_last=True, collate_fn=crack_dataset_collate)
# Validation must see every sample once: with drop_last=True any remainder of
# len(ValidationDataset) % batch_size would be skipped silently, while the
# accuracy would still be reported over the full split.
dataloaders["val"] = DataLoader(ValidationDataset, shuffle = True, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
                                    drop_last=False, collate_fn=crack_dataset_collate)

# Start training

def _run_phase(model, criterion, optimizer, phase):
    """Run one pass over ``dataloaders[phase]``; return (mean loss, accuracy, samples seen).

    In the 'train' phase gradients are enabled and the optimizer steps after
    every batch; in any other phase only forward passes are run.
    """
    is_train = phase == 'train'
    model.train(is_train)  # train mode for 'train', eval mode otherwise

    loss_sum = 0.0
    correct = 0
    seen = 0

    # Here's where the training happens
    print('Iterating through data...')
    for inputs, labels in tqdm(dataloaders[phase]):
        # The collate_fn yields a numpy image batch and a list of int labels.
        inputs = torch.as_tensor(inputs).to(device)
        labels = torch.as_tensor(labels).to(device)

        optimizer.zero_grad()
        # Build the autograd graph only while training.
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()

        batch_n = inputs.size(0)
        # Assumes criterion averages over the batch (nn.CrossEntropyLoss
        # default), so re-weight by the batch size before summing.
        loss_sum += loss.item() * batch_n
        correct += (preds == labels).sum().item()
        seen += batch_n

    if seen == 0:
        return 0.0, 0.0, 0
    return loss_sum / seen, correct / seen, seen


def train_model(model, criterion, optimizer, scheduler, num_epochs=10):
    """Train ``model`` and return it loaded with its best validation weights.

    Each epoch runs a 'train' phase followed by a 'val' phase over the
    module-level ``dataloaders``. ``scheduler.step()`` is called once after
    every training phase. Loss and accuracy are averaged over the samples
    actually iterated, so they stay correct when a loader drops a partial batch.

    Args:
        model: network to train; it is updated in place.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over the model's trainable parameters.
        scheduler: LR scheduler stepped once per epoch.
        num_epochs: number of train+val epochs.

    Returns:
        ``model``, with the state_dict from the epoch with the highest
        validation accuracy loaded back into it.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            epoch_loss, epoch_acc, _ = _run_phase(model, criterion, optimizer, phase)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # Keep a copy of the weights whenever validation accuracy improves.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'train':
                scheduler.step()

    time_since = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_since // 60, time_since % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # Load the best weights back into the model before returning it.
    model.load_state_dict(best_model_wts)
    return model

# Label index -> display name; matches Crack_DataSet (Negative = 0, Positive = 1).
class_names = ['normal','crack']

# Take one validation batch and show up to 16 images with their true labels.
inputs, classes = next(iter(dataloaders['val']))
inputs = torch.as_tensor(inputs)
num_shown = min(16, len(inputs))
plt.figure(figsize=(15,6))
for ii in range(num_shown):
    # Undo the normalisation, go CHW -> HWC for imshow, and clamp to [0, 1]:
    # de-normalised floats can fall slightly outside the valid display range.
    inp = inverse_transform(inputs[ii]).permute(1, 2, 0).clamp(0, 1)
    plt.subplot(2, 8, ii + 1)  # a 2x8 grid fits exactly the 16 images shown
    plt.imshow(inp)
    plt.title(class_names[classes[ii]])
    plt.axis('off')
plt.tight_layout()  # adjust spacing once all subplots are drawn

# [Figure: Pytorch迁移学习版的裂缝检测(resnet50)_第1张图片 — output of the plotting code above]

def visualize_model(model, num_images=6):
    """Plot ``num_images`` validation images titled with the model's predictions.

    Draws into a new matplotlib figure with two images per row. The model
    runs in eval mode without gradients, and its previous train/eval mode is
    restored before returning (also if an error occurs).

    Args:
        model: classifier returning one score per entry of ``class_names``.
        num_images: how many images to show; an odd count is fine. Nothing
            is drawn when it is <= 0.
    """
    if num_images <= 0:
        return
    was_training = model.training
    model.eval()
    n_rows = (num_images + 1) // 2  # ceil, so an odd count still has a slot
    shown = 0
    plt.figure(figsize=(12,8))
    try:
        with torch.no_grad():
            for inputs, _labels in dataloaders['val']:
                # No squeeze(): it would remove the batch axis of a size-1 batch.
                inputs = torch.as_tensor(inputs).to(device)
                _, preds = torch.max(model(inputs), 1)
                # One device-to-host copy per batch rather than one per image.
                cpu_inputs = inputs.cpu()
                for img, pred in zip(cpu_inputs, preds.tolist()):
                    shown += 1
                    plt.subplot(n_rows, 2, shown)
                    plt.axis('off')
                    plt.title('predicted: {}'.format(class_names[pred]))
                    # De-normalise, CHW -> HWC, clamp to the valid display range.
                    plt.imshow(inverse_transform(img).permute(1, 2, 0).clamp(0, 1))
                    if shown == num_images:
                        return
    finally:
        model.train(mode=was_training)
# To retrain, uncomment these two lines; otherwise the model saved by an
# earlier run is loaded below.
# base_model = train_model(resnet50, criterion, optimizer, exp_lr_scheduler, num_epochs=10)
# torch.save(base_model,'base_model.pth')

# The checkpoint is a whole pickled nn.Module, so loading it unpickles
# arbitrary objects: only load files you trust. map_location places the
# weights on this machine's device, so a GPU-saved checkpoint also loads on a
# CPU-only machine.
try:
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses full modules.
    base_model = torch.load('base_model.pth', map_location=device, weights_only=False)
except TypeError:
    # PyTorch < 1.13 has no weights_only argument.
    base_model = torch.load('base_model.pth', map_location=device)

visualize_model(base_model)
plt.show()

# [Figure: Pytorch迁移学习版的裂缝检测(resnet50)_第2张图片 — output of visualize_model above]

# Inference

def predict(model, test_image, print_class = False):
    """Classify one PIL image with ``model`` and return the class name.

    The image goes through the deterministic 'val' pipeline (resize, centre
    crop to 227, normalise). It then gets a batch axis and is moved to the
    device that holds the model's parameters.

    Args:
        model: classifier returning one score per entry of ``class_names``.
            It is switched to eval mode and left there.
        test_image: a PIL image. It is converted to RGB, so grayscale/RGBA
            inputs still give the 3 channels the network expects.
        print_class: if True, also print the predicted class.

    Returns:
        The predicted entry of ``class_names``.
    """
    transform = chosen_transforms['val']
    # The pipeline already returns a (C, H, W) float tensor; unsqueeze adds the
    # batch axis instead of hard-coding view(1, 3, 227, 227).
    test_image_tensor = transform(test_image.convert('RGB')).unsqueeze(0)
    # Follow the model's own device rather than assuming CUDA whenever it
    # exists: a model kept on the CPU would otherwise receive a CUDA tensor.
    test_image_tensor = test_image_tensor.to(next(model.parameters()).device)

    model.eval()
    with torch.no_grad():
        # Raw class scores (logits); argmax picks the predicted class.
        out = model(test_image_tensor)
        _, preds = torch.max(out, 1)

    class_name = class_names[preds.item()]
    if print_class:
        print("Output class :  ", class_name)
    return class_name
def predict_on_crops(input_image, height=227, width=227, save_crops = False):
    """Tile an image, classify every tile with ``base_model`` and annotate it.

    The image at ``input_image`` is cut into ``height`` x ``width`` tiles.
    Tiles on the right/bottom edges may be smaller; they are classified too.
    Each tile gets the predicted class name written on it and a light tint:
    red for 'crack', green otherwise.

    The annotated mosaic is written to
    ``CrackDetection/predictions/out_<name>.jpg``. When ``save_crops`` is set,
    every tile is also written to ``real_images/out_<name>/img_<k>.png``.

    Args:
        input_image: path of the image to analyse.
        height: tile height in pixels.
        width: tile width in pixels.
        save_crops: also save every annotated tile separately.

    Returns:
        The annotated image as a BGR uint8 array with the input's shape.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``input_image``.
    """
    im = cv2.imread(input_image)
    if im is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise FileNotFoundError('Could not read image: {}'.format(input_image))
    imgheight, imgwidth = im.shape[:2]

    # Output names depend only on the input path, so compute them once.
    image_name = os.path.splitext(os.path.basename(input_image))[0]
    folder_name = 'out_' + image_name
    crops_dir = os.path.join('real_images', folder_name)
    if save_crops:
        os.makedirs(crops_dir, exist_ok=True)

    output_image = np.zeros_like(im)
    k = 0
    for i in range(0, imgheight, height):
        for j in range(0, imgwidth, width):
            a = im[i:i+height, j:j+width]  # one tile (a view into im)
            # cv2 loads BGR, but the model was trained on RGB images opened with
            # PIL, so swap the channels before classifying.
            tile_rgb = Image.fromarray(cv2.cvtColor(a, cv2.COLOR_BGR2RGB))
            predicted_class = predict(base_model, tile_rgb, print_class=False)
            # BGR colours: red for crack, green otherwise.
            color = (0, 0, 255) if predicted_class == 'crack' else (0, 255, 0)
            # Write the label onto the tile (in place, so im changes as well).
            cv2.putText(a, predicted_class, (50,50), cv2.FONT_HERSHEY_SIMPLEX , 0.7, color, 1, cv2.LINE_AA)
            # Blend in a 10% tint of the class colour.
            tint = np.zeros_like(a, dtype=np.uint8)
            tint[:] = color
            add_img = cv2.addWeighted(a, 0.9, tint, 0.1, 0)
            if save_crops:
                cv2.imwrite(os.path.join(crops_dir, 'img_{}.png'.format(k)), add_img)
            output_image[i:i+height, j:j+width,:] = add_img
            k += 1

    predictions_dir = os.path.join('CrackDetection', 'predictions')
    # cv2.imwrite fails silently (returns False) if the folder does not exist.
    os.makedirs(predictions_dir, exist_ok=True)
    cv2.imwrite(os.path.join(predictions_dir, folder_name + '.jpg'), output_image)
    return output_image
def _show_crop_predictions(image_path, height=227, width=227):
    """Run predict_on_crops on ``image_path``, show the result in a new figure, return it."""
    plt.figure(figsize=(10,10))
    annotated = predict_on_crops(image_path, height, width)
    # predict_on_crops returns a BGR (OpenCV-order) image; matplotlib wants RGB.
    plt.imshow(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
    return annotated


output_image = _show_crop_predictions('CrackDetection/real/p1.jpg', 128, 128)

# [Figure: Pytorch迁移学习版的裂缝检测(resnet50)_第3张图片 — predictions for p1.jpg]

output_image = _show_crop_predictions('CrackDetection/real/p2.jpg')

# [Figure: Pytorch迁移学习版的裂缝检测(resnet50)_第4张图片 — predictions for p2.jpg]

output_image = _show_crop_predictions('CrackDetection/real/p3.jpg', 128, 128)

# [Figure: Pytorch迁移学习版的裂缝检测(resnet50)_第5张图片 — predictions for p3.jpg]

output_image = _show_crop_predictions('CrackDetection/real/p4.jpg', 128, 128)

# [Figure: Pytorch迁移学习版的裂缝检测(resnet50)_第6张图片 — predictions for p4.jpg]

# 你可能感兴趣的:(pytorch案例,pytorch,迁移学习,python) — related tags from the original post: PyTorch examples, PyTorch, transfer learning, Python