import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import sampler
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
import torchvision
import numpy as np
import os
cwd = os.getcwd()
from PIL import Image
import time
import copy
import random
import cv2
import re
import shutil
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
import skimage
import json
from tqdm import tqdm
import base64
## Define data augmentation and transforms
# Per-channel mean / std handed to transforms.Normalize in both pipelines.
# NOTE(review): presumably the ImageNet statistics matching the pretrained
# backbone used later in this file -- confirm.
mean_nums = [0.485, 0.456, 0.406]
std_nums = [0.229, 0.224, 0.225]

# Training pipeline: random crop / rotation / flips / colour jitter, then
# conversion to a tensor and per-channel normalization.
_train_steps = [
    transforms.RandomResizedCrop(size=227),
    transforms.RandomRotation(degrees=10),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.15, contrast=0.15),
    transforms.ToTensor(),
    transforms.Normalize(mean_nums, std_nums),
]

# Validation pipeline: resize + center crop to 227, no random augmentation.
_val_steps = [
    transforms.Resize(227),
    transforms.CenterCrop(227),
    transforms.ToTensor(),
    transforms.Normalize(mean_nums, std_nums),
]

# Pipelines keyed by phase name; 'val' is also reused for single-image inference.
chosen_transforms = {
    'train': transforms.Compose(_train_steps),
    'val': transforms.Compose(_val_steps),
}
def inverse_transform(tensor, mean=None, std=None):
    """Undo per-channel normalization and return the result as a new tensor.

    Computes ``tensor * std + mean`` channel-wise. Unlike the previous
    implementation, the input tensor is NOT modified in place, so a batch
    that is still needed (for example for a later forward pass) stays
    normalized.

    Args:
        tensor: Image tensor whose channel axis is third from the end,
            e.g. ``(C, H, W)`` or ``(N, C, H, W)``.
        mean: Per-channel means. Defaults to the module-level ``mean_nums``.
        std: Per-channel standard deviations. Defaults to the module-level
            ``std_nums``.

    Returns:
        A new tensor with the same shape as ``tensor``.
    """
    if mean is None:
        mean = mean_nums
    if std is None:
        std = std_nums

    # Keep the statistics on the input's device; integer inputs are promoted
    # to float so the fractional statistics are not truncated.
    dtype = tensor.dtype if tensor.is_floating_point() else torch.float32
    mean_t = torch.as_tensor(mean, dtype=dtype, device=tensor.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=dtype, device=tensor.device).view(-1, 1, 1)

    # Shape (C, 1, 1) broadcasts over the spatial (and optional batch) axes.
    return tensor * std_t + mean_t
class Crack_DataSet(Dataset):
    """Dataset over the negative / positive crack image folders.

    Every file in the negative folder is labelled 0 and every file in the
    positive folder is labelled 1. The two listings are concatenated
    (negatives first) and ``indexes`` selects which entries of that combined
    list belong to this dataset.

    Args:
        indexes: Positions into the combined image list.
        Type: Which split this dataset represents. Only ``'train'`` and
            ``'validation'`` can be read from.
        normalize: Optional callable applied to the PIL image before it is
            converted to a float32 array.
        negative_img_path: Folder holding the label-0 images.
        positive_img_path: Folder holding the label-1 images.
    """

    def __init__(self, indexes, Type, normalize=None,
                 negative_img_path='CrackDetection/Negative/',
                 positive_img_path='CrackDetection/Positive/'):
        super(Crack_DataSet, self).__init__()
        self.indexes = indexes
        self.Type = Type
        self.normalize = normalize
        self.negative_img_path = negative_img_path
        self.positive_img_path = positive_img_path
        # Built on first access so constructing the dataset touches no files.
        self._img_list = None

    def __len__(self):
        """Return the number of samples selected by ``indexes``."""
        return len(self.indexes)

    def _get_img_list(self):
        """Return the cached ``[path, label]`` list, building it on first use."""
        if self._img_list is None:
            # Listing the folders once (instead of on every __getitem__) avoids
            # a full directory scan per sample. Sorting makes the
            # index -> file mapping deterministic, since os.listdir returns
            # entries in arbitrary order.
            negatives = [[os.path.join(self.negative_img_path, name), 0]
                         for name in sorted(os.listdir(self.negative_img_path))]
            positives = [[os.path.join(self.positive_img_path, name), 1]
                         for name in sorted(os.listdir(self.positive_img_path))]
            self._img_list = negatives + positives
        return self._img_list

    def __getitem__(self, index):
        """Return ``(image, label)`` for the ``index``-th selected sample.

        Returns:
            A tuple of a float32 numpy array (the image after ``normalize``,
            if one was given) and the integer label.

        Raises:
            ValueError: If ``Type`` is neither ``'train'`` nor ``'validation'``.
        """
        if self.Type not in ('train', 'validation'):
            # Previously this fell through and returned None, which only
            # failed later inside the collate function with an unrelated error.
            raise ValueError(
                "Crack_DataSet only supports Type 'train' or 'validation', "
                "got {!r}".format(self.Type))

        path, label = self._get_img_list()[self.indexes[index]]
        # convert() loads the pixel data, so the file can be closed right away;
        # forcing RGB keeps the channel count at 3 for the normalization step.
        with Image.open(path) as img:
            image = img.convert('RGB')
        if self.normalize is not None:
            image = self.normalize(image)
        return np.asanyarray(image, dtype=np.float32), label
# Used as the collate_fn of the DataLoaders
def crack_dataset_collate(batch):
    """Merge a list of ``(image, label)`` samples into one batch.

    Args:
        batch: Sequence of ``(image, label)`` pairs.

    Returns:
        A tuple ``(images, labels)`` where ``images`` is a single numpy array
        built from all sample images and ``labels`` is a plain list of the
        labels in the same order.
    """
    stacked_images = np.array([sample_image for sample_image, _ in batch])
    label_list = [sample_label for _, sample_label in batch]
    return stacked_images, label_list
# Split configuration: total number of images and the share used for training.
NumberOfSamples = 40000
Train_ratio = 0.8

# Shuffle every sample index once, then cut the list at the
# train / validation boundary.
DataIndexes = list(range(NumberOfSamples))
random.shuffle(DataIndexes)
_split_at = int(NumberOfSamples * Train_ratio)
TrainIndexes = DataIndexes[:_split_at]
TrainDataset = Crack_DataSet(TrainIndexes, "train", chosen_transforms['train'])
ValidationIndexes = DataIndexes[_split_at:]
ValidationDataset = Crack_DataSet(ValidationIndexes, "validation", chosen_transforms['val'])

# Number of samples per phase, keyed like the dataloaders ('train' / 'val').
dataset_sizes = {
    'train': len(TrainIndexes),
    'val': len(ValidationIndexes),
}
print(dataset_sizes)
# Recorded notebook output: {'train': 32000, 'val': 8000}
## Set code to run on device
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# Recorded notebook output: cuda
# (kept as a comment: the bare name `cuda` would raise NameError when executed)
## Load pretrained model
# NOTE(review): newer torchvision releases prefer the `weights=` argument over
# `pretrained=True` -- confirm against the installed version before changing it.
resnet50 = models.resnet50(pretrained=True)

# Freeze model parameters so the pretrained backbone is not updated
for param in resnet50.parameters():
    param.requires_grad = False

## Change the final layer of the resnet model
# Replace the final fully-connected layer with a small head for transfer
# learning. The head is created after the freeze loop above, so its
# parameters are the only ones left trainable. It outputs 2 scores, one per class.
fc_inputs = resnet50.fc.in_features
resnet50.fc = nn.Sequential(
    nn.Linear(fc_inputs, 128),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(128, 2)
)

# Move the model to the selected device
resnet50 = resnet50.to(device)
# from torchsummary import summary
# print(summary(resnet50, (3, 227, 227)))

# Define Optimizer and Loss Function
criterion = nn.CrossEntropyLoss()
# Every parameter is handed to Adam; the frozen ones never receive a gradient.
optimizer = optim.Adam(resnet50.parameters())
# optimizer = optim.SGD(resnet50.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 3 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

num_workers = 0
batch_size = 64
dataloaders = {}
# Training batches: shuffled, and the final partial batch is dropped.
dataloaders["train"] = DataLoader(TrainDataset, shuffle=True, batch_size=batch_size,
                                  num_workers=num_workers, pin_memory=True,
                                  drop_last=True, collate_fn=crack_dataset_collate)
# Validation batches: drop_last is False so every validation sample is
# evaluated even when the set size is not a multiple of batch_size.
# Shuffling is kept so previews drawn from this loader vary between calls.
dataloaders["val"] = DataLoader(ValidationDataset, shuffle=True, batch_size=batch_size,
                                num_workers=num_workers, pin_memory=True,
                                drop_last=False, collate_fn=crack_dataset_collate)
def train_model(model, criterion, optimizer, scheduler, num_epochs=10):
    """Train ``model`` and return it loaded with its best validation weights.

    Each epoch runs a training phase followed by a validation phase over the
    module-level ``dataloaders['train']`` / ``dataloaders['val']``, moving the
    batches to the module-level ``device``. The weights with the highest
    validation accuracy seen so far are remembered and restored at the end.

    Args:
        model: Network to train (updated in place).
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch, after the
            training phase.
        num_epochs: Number of epochs to run.

    Returns:
        ``model`` with the best-performing weights loaded.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            current_loss = 0.0
            current_corrects = 0
            # Count the samples actually processed: a loader may drop a partial
            # batch, so the nominal dataset size is not a reliable denominator.
            samples_seen = 0

            print('Iterating through data...')
            for inputs, labels in tqdm(dataloaders[phase]):
                # Data shape is [batch, channel, width, height]
                inputs = torch.as_tensor(inputs).to(device)
                labels = torch.as_tensor(labels).to(device)

                # Gradients accumulate by default, so clear them every batch
                optimizer.zero_grad()

                # Track gradients only during the training phase
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # loss.item() is a per-batch value; weight it by the batch size
                # before accumulating it into the epoch total.
                batch_count = inputs.size(0)
                current_loss += loss.item() * batch_count
                current_corrects += torch.sum(preds == labels).item()
                samples_seen += batch_count

            # Guard against an empty loader to avoid a division by zero
            denominator = max(samples_seen, 1)
            epoch_loss = current_loss / denominator
            epoch_acc = current_corrects / denominator

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # Keep a copy of the weights whenever validation accuracy improves
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

            if is_train:
                scheduler.step()

    time_since = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_since // 60, time_since % 60))
    # '{:.4f}' (precision) instead of the former '{:4f}' (minimum width)
    print('Best val Acc: {:.4f}'.format(best_acc))

    # Load the best model weights and return the model
    model.load_state_dict(best_model_wts)
    return model
# Display names indexed by label: 0 -> 'normal', 1 -> 'crack'
class_names = ['normal', 'crack']

# Grab one batch from the validation loader for a visual sanity check
inputs, classes = next(iter(dataloaders['val']))
inputs = torch.as_tensor(inputs)

plt.figure(figsize=(15, 6))
# Show at most the first 16 images of the batch on a 4 x 8 subplot grid
for slot, (inp, cls) in enumerate(zip(inputs[:16], classes[:16]), start=1):
    plt.subplot(4, 8, slot)
    # Undo the normalization and move channels last, as imshow expects (H, W, C)
    plt.imshow(inverse_transform(inp).permute(1, 2, 0))
    plt.title(class_names[cls])
    plt.axis('off')
plt.tight_layout()  # adjust the spacing once all subplots are drawn
def visualize_model(model, num_images=6):
    """Plot validation images titled with the model's predicted class.

    Batches are drawn from the module-level ``dataloaders['val']`` and images
    are placed on a two-column subplot grid until ``num_images`` have been
    drawn. The model's training/eval mode is restored before returning, even
    if plotting raises.

    Args:
        model: Network used for prediction.
        num_images: Number of images to draw; nothing is drawn when it is
            zero or negative.
    """
    if num_images <= 0:
        return

    was_training = model.training
    model.eval()
    images_handeled = 0
    # Round up so an odd num_images still fits on the two-column grid
    n_rows = (num_images + 1) // 2
    plt.figure(figsize=(12, 8))

    try:
        with torch.no_grad():
            for inputs, _labels in dataloaders['val']:
                # No squeeze() here: it would also strip the batch dimension
                # whenever a batch holds a single image.
                inputs = torch.as_tensor(inputs).to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

                cpu_images = inputs.cpu()
                for j in range(cpu_images.size(0)):
                    images_handeled += 1
                    plt.subplot(n_rows, 2, images_handeled)
                    plt.axis('off')
                    plt.title('predicted: {}'.format(class_names[preds[j].item()]))
                    # Work on a copy so de-normalizing never alters the batch
                    inp = inverse_transform(cpu_images[j].clone())
                    # imshow expects channels last: (H, W, C)
                    plt.imshow(inp.permute(1, 2, 0))
                    if images_handeled == num_images:
                        return
    finally:
        # Restore whichever mode the model was in when the function was called
        model.train(mode=was_training)
# Training is skipped here; uncomment the next two lines to retrain and re-save.
# base_model = train_model(resnet50, criterion, optimizer, exp_lr_scheduler, num_epochs=10)
# torch.save(base_model,'base_model.pth')

# SECURITY NOTE: 'base_model.pth' stores a whole pickled model object, and
# unpickling can execute arbitrary code -- only load files you trust.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which rejects whole-model pickles; confirm against the installed version.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
base_model = torch.load('base_model.pth', map_location=device)
visualize_model(base_model)
plt.show()
def predict(model, test_image, print_class=False):
    """Classify a single image and return the predicted class name.

    The image goes through the module-level ``chosen_transforms['val']``
    pipeline, is given a batch dimension and is fed to ``model`` with
    gradients disabled. ``model`` is left in eval mode.

    Args:
        model: Network used for prediction.
        test_image: Image accepted by the validation transform (the callers
            in this file pass PIL images).
        print_class: When true, also print the predicted class name.

    Returns:
        The entry of the module-level ``class_names`` selected by the
        highest-scoring output.
    """
    transform = chosen_transforms['val']
    # unsqueeze(0) adds the batch dimension without hard-coding the image size
    test_image_tensor = torch.as_tensor(
        transform(test_image), dtype=torch.float32).unsqueeze(0)

    # Send the input to the device the model actually lives on; keying this on
    # CUDA availability alone breaks for a CPU model on a CUDA machine.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        test_image_tensor = test_image_tensor.to(first_param.device)
    elif torch.cuda.is_available():
        # Parameter-free model: fall back to the previous behaviour
        test_image_tensor = test_image_tensor.cuda()

    model.eval()
    with torch.no_grad():
        # The outputs are only compared against each other, so the highest
        # score decides the class regardless of their scale.
        out = model(test_image_tensor)
        _, preds = torch.max(out, 1)
    class_name = class_names[preds.item()]

    if print_class:
        print("Output class : ", class_name)
    return class_name
def predict_on_crops(input_image, height=227, width=227, save_crops=False, model=None):
    """Classify an image tile by tile and return the annotated image.

    The image is cut into ``height`` x ``width`` tiles (tiles on the right and
    bottom edges may be smaller). Each tile is classified with ``predict``,
    labelled with the predicted class and tinted red for 'crack' or green
    otherwise. The annotated image is also written to
    ``CrackDetection/predictions/out_<name>.jpg``.

    Args:
        input_image: Path of the image file to process.
        height: Tile height in pixels.
        width: Tile width in pixels.
        save_crops: When true, each annotated tile is also saved under
            ``real_images/out_<name>/``.
        model: Network used for prediction; defaults to the module-level
            ``base_model``.

    Returns:
        The annotated image as a BGR array shaped like the input image.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    if model is None:
        model = base_model

    im = cv2.imread(input_image)
    if im is None:
        # cv2.imread signals failure by returning None rather than raising
        raise FileNotFoundError('Could not read image: {}'.format(input_image))
    imgheight, imgwidth = im.shape[:2]

    # Output names depend only on the input path, so compute them once here
    file, _ext = os.path.splitext(input_image)
    image_name = os.path.basename(file)
    folder_name = 'out_' + image_name
    crops_dir = os.path.join('real_images', folder_name)
    if save_crops:
        os.makedirs(crops_dir, exist_ok=True)

    k = 0
    output_image = np.zeros_like(im)
    for i in range(0, imgheight, height):
        for j in range(0, imgwidth, width):
            # Copy the tile so drawing on it does not touch the source image
            a = im[i:i+height, j:j+width].copy()

            # NOTE: partial tiles at the edges are NOT discarded; they are
            # classified like any other tile.
            # OpenCV stores pixels as BGR while the transform pipeline is fed
            # RGB PIL images elsewhere in this file, so convert before predicting.
            rgb_tile = cv2.cvtColor(a, cv2.COLOR_BGR2RGB)
            predicted_class = predict(model, Image.fromarray(rgb_tile), print_class=False)

            # Colours are BGR: red for 'crack', green otherwise
            if predicted_class == 'crack':
                color = (0, 0, 255)
            else:
                color = (0, 255, 0)

            # Write the predicted class onto the tile
            cv2.putText(a, predicted_class, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 1, cv2.LINE_AA)
            # Blend the tile with a solid colour layer to tint it
            b = np.zeros_like(a, dtype=np.uint8)
            b[:] = color
            add_img = cv2.addWeighted(a, 0.9, b, 0.1, 0)

            ## Save crops
            if save_crops:
                filename = os.path.join(crops_dir, 'img_{}.png'.format(k))
                cv2.imwrite(filename, add_img)

            output_image[i:i+height, j:j+width, :] = add_img
            k += 1

    ## Save output image
    # cv2.imwrite does not create missing directories, so make sure it exists
    predictions_dir = os.path.join('CrackDetection', 'predictions')
    os.makedirs(predictions_dir, exist_ok=True)
    cv2.imwrite(os.path.join(predictions_dir, folder_name + '.jpg'), output_image)
    return output_image
# Run the tiled prediction on each sample photo and show the annotated result.
# Each entry pairs an image path with the positional tile-size arguments;
# an empty tuple means the function's default tile size is used.
for image_path, crop_args in (
    ('CrackDetection/real/p1.jpg', (128, 128)),
    ('CrackDetection/real/p2.jpg', ()),
    ('CrackDetection/real/p3.jpg', (128, 128)),
    ('CrackDetection/real/p4.jpg', (128, 128)),
):
    plt.figure(figsize=(10, 10))
    output_image = predict_on_crops(image_path, *crop_args)
    # The returned image is BGR; matplotlib expects RGB
    plt.imshow(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB))