U-net医学图像分割

代码作者

一 数据集分割

from PIL import Image
import os

# To obtain the dataset, register at: http://brainiac2.mit.edu/isbi_challenge/
# and download the corresponding data files.
# Only 30 images are available with ground truth.
# 6 of them (every 5th frame) are used for validation and are put into a
# separate folder.

# Upper bound on the number of frames read from each multi-page TIFF.
MAX_FRAMES = 30
# Frames whose index is a multiple of VAL_EVERY go to the validation folder.
VAL_EVERY = 5


def _ensure_dir(directory):
    """Create *directory* (including parents) unless it already exists."""
    if not os.path.exists(directory):
        os.makedirs(directory)


def _save_frames(stack, prefix, main_dir, val_dir=None):
    """Write every frame of the opened multi-page image *stack* to its own file.

    Files are named ``<dir><prefix>-<i>.tif`` where ``i`` is the frame index.
    When *val_dir* is given, frames with ``i % VAL_EVERY == 0`` are written to
    *val_dir* and all others to *main_dir*; without *val_dir* every frame goes
    to *main_dir*.  At most ``MAX_FRAMES`` frames are read and the loop stops
    early when the stack runs out of frames.
    """
    _ensure_dir(main_dir)
    if val_dir is not None:
        _ensure_dir(val_dir)

    for i in range(MAX_FRAMES):
        try:
            stack.seek(i)
        except EOFError:
            # fewer frames than MAX_FRAMES in this stack
            break
        to_val = val_dir is not None and i % VAL_EVERY == 0
        target_dir = val_dir if to_val else main_dir
        stack.save('%s%s-%s.tif' % (target_dir, prefix, i))


# Each stack is opened in a ``with`` block so its file handle is released as
# soon as the frames have been written.
with Image.open('./train-volume.tif') as img:
    print('*************')
    _save_frames(img, 'train-volume',
                 './ISBI 2012/Train-Volume/', './ISBI 2012/Val-Volume/')

with Image.open('./train-labels.tif') as img:
    _save_frames(img, 'train-labels',
                 './ISBI 2012/Train-Labels/', './ISBI 2012/Val-Labels/')

# The test volume has no validation split: all frames go to one folder.
with Image.open('./test-volume.tif') as img:
    _save_frames(img, 'test-volume', './ISBI 2012/Test-Volume/')

数据集使用的是ISBI2012细胞检测数据集,共30张训练图像,选其中6张作为验证集,其余24张作为训练集。由于ISBI2012训练数据比较少,U-net通过图像扭曲对数据进行增强(augment)。图像扭曲增加数据:code

二 数据集加载

import glob
from torch.utils import data
from PIL import Image
import torchvision
import numpy as np

class ISBIDataset(data.Dataset):
    """Paired image/label dataset read from two glob patterns.

    The files matched by each pattern are ordered by the integer that stands
    in place of the ``*`` of the pattern (``train-volume-7.tif`` -> 7), and
    sample ``index`` is the ``index``-th file of each ordered list.

    The augmentation switches ``rand_vflip``, ``rand_hflip``, ``rand_rotate``
    and ``angle`` start out disabled; they are plain attributes meant to be
    changed from outside between epochs and are ignored when ``eval`` is set.
    """

    def __init__(self, gloob_dir_train, gloob_dir_label, length, is_pad, eval, totensor):
        """Store the configuration; no file is touched here.

        Args:
            gloob_dir_train: glob pattern of the input images.
            gloob_dir_label: glob pattern of the label images.
            length: value reported by ``len()``; it is not derived from the
                files on disk and has to be supplied by the caller.
            is_pad: when false, the label is center-cropped to 324x324.
            eval: when true, no augmentation is applied.
            totensor: when true, image and label are returned as tensors
                (the label additionally cast to long).
        """
        # augmentation state, all off at start
        self.rand_vflip = False
        self.rand_hflip = False
        self.rand_rotate = False
        self.angle = 0

        # transforms
        # center crop giving the expected 512x512 output image
        self.crop = torchvision.transforms.CenterCrop(512)
        # center crop giving the 324x324 output used when no padding is applied
        self.crop_nopad = torchvision.transforms.CenterCrop(324)
        self.changetotensor = torchvision.transforms.ToTensor()

        # configuration
        self.gloob_dir_train = gloob_dir_train
        self.gloob_dir_label = gloob_dir_label
        self.length = length
        self.is_pad = is_pad
        self.eval = eval
        self.totensor = totensor

    def __len__(self):
        """Return the number of samples as given to the constructor."""
        return self.length

    @staticmethod
    def _sorted_by_index(pattern):
        """Return the files matching *pattern*, ordered by their numeric part.

        The number is cut out of each file name at the position of the last
        ``*`` of the pattern and up to the extension (everything from the
        last ``.`` of the pattern on).
        """
        number_start = pattern.rfind('*')
        extension_len = len(pattern) - pattern.rfind('.')
        return sorted(glob.glob(pattern),
                      key=lambda path: int(path[number_start:-extension_len]))

    def __getitem__(self, index):
        """Load, optionally augment and convert the sample at *index*."""
        # for example : "./ISBI 2012/Train-Volume/train-volume-*.tif"
        image_paths = self._sorted_by_index(self.gloob_dir_train)
        label_paths = self._sorted_by_index(self.gloob_dir_label)

        image = Image.open(image_paths[index])
        label = Image.open(label_paths[index])

        if not self.eval:
            # NOTE: as written, rand_vflip mirrors left-right and rand_hflip
            # mirrors top-bottom.
            flips = []
            if self.rand_vflip:
                flips.append(Image.FLIP_LEFT_RIGHT)
            if self.rand_hflip:
                flips.append(Image.FLIP_TOP_BOTTOM)
            for flip in flips:
                label = label.transpose(flip)
                image = image.transpose(flip)

            if self.rand_rotate:
                # Reflect-pad by 107 pixels per side before rotating so that no
                # black borders appear; the result is cropped back afterwards.
                border = ((107, 107), (107, 107))
                image = Image.fromarray(np.pad(np.asarray(image), border, 'reflect'))
                label = Image.fromarray(np.pad(np.asarray(label), border, 'reflect'))

                label = self.crop(label.rotate(self.angle))
                image = self.crop(image.rotate(self.angle))

        # Only without padding is the label cropped (to 324x324).
        if not self.is_pad:
            label = self.crop_nopad(label)

        if self.totensor:
            label = self.changetotensor(label).long()
            image = self.changetotensor(image)

        return image, label

三 U-net模型

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


# 1 MODEL
class Unet(nn.Module):
    """U-Net with a BatchNorm layer after every convolution.

    The network expects a batch of single-channel images and returns a
    two-channel map normalised with ``Softmax2d`` (the two channels sum to 1
    at every pixel).  Attribute names and creation order of the layers are
    part of the checkpoint format and must not be changed.
    """

    def __init__(self):
        super(Unet, self).__init__()

        # All layers which have weights are created and initialized in init.
        # Parameterless operations (relu, pad, dropout) are used in functional
        # style (F.*) in forward.

        # Contracting path -------------------------------------------------
        # https://pytorch.org/docs/master/nn.html#conv2d
        # in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=0)

        # https://pytorch.org/docs/master/nn.html#batchnorm2d
        # num_features/channels, eps, momentum, affine, track_running_stats
        self.conv1_bn = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, 3, stride=1, padding=0)
        self.conv2_bn = nn.BatchNorm2d(64)

        # https://pytorch.org/docs/master/nn.html#maxpool2d
        # kernel_size, stride, padding, dilation, return_indices, ceil_mode
        self.maxPool1 = nn.MaxPool2d(2, stride=2, padding=0)

        self.conv3 = nn.Conv2d(64, 128, 3, stride=1, padding=0)
        self.conv3_bn = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, 3, stride=1, padding=0)
        self.conv4_bn = nn.BatchNorm2d(128)
        self.maxPool2 = nn.MaxPool2d(2, stride=2, padding=0)

        self.conv5 = nn.Conv2d(128, 256, 3, stride=1, padding=0)
        self.conv5_bn = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256, 256, 3, stride=1, padding=0)
        self.conv6_bn = nn.BatchNorm2d(256)
        self.maxPool3 = nn.MaxPool2d(2, stride=2, padding=0)

        self.conv7 = nn.Conv2d(256, 512, 3, stride=1, padding=0)
        self.conv7_bn = nn.BatchNorm2d(512)
        self.conv8 = nn.Conv2d(512, 512, 3, stride=1, padding=0)
        self.conv8_bn = nn.BatchNorm2d(512)
        self.maxPool4 = nn.MaxPool2d(2, stride=2, padding=0)

        # Bottleneck ---------------------------------------------------------
        self.conv9 = nn.Conv2d(512, 1024, 3, stride=1, padding=0)
        self.conv9_bn = nn.BatchNorm2d(1024)
        self.conv10 = nn.Conv2d(1024, 1024, 3, stride=1, padding=0)
        self.conv10_bn = nn.BatchNorm2d(1024)

        # Expanding path -----------------------------------------------------
        # https://pytorch.org/docs/master/nn.html#convtranspose2d
        # in_channels, out_channels, kernel_size, stride, padding, output_padding, groups, bias, dilation
        self.upsampconv1 = nn.ConvTranspose2d(1024, 512, 2, stride=2, padding=0)

        self.conv11 = nn.Conv2d(1024, 512, 3, stride=1, padding=0)
        self.conv11_bn = nn.BatchNorm2d(512)
        self.conv12 = nn.Conv2d(512, 512, 3, stride=1, padding=0)
        self.conv12_bn = nn.BatchNorm2d(512)

        self.upsampconv2 = nn.ConvTranspose2d(512, 256, 2, stride=2, padding=0)

        self.conv13 = nn.Conv2d(512, 256, 3, stride=1, padding=0)
        self.conv13_bn = nn.BatchNorm2d(256)
        self.conv14 = nn.Conv2d(256, 256, 3, stride=1, padding=0)
        self.conv14_bn = nn.BatchNorm2d(256)

        self.upsampconv3 = nn.ConvTranspose2d(256, 128, 2, stride=2, padding=0)

        self.conv15 = nn.Conv2d(256, 128, 3, stride=1, padding=0)
        self.conv15_bn = nn.BatchNorm2d(128)
        self.conv16 = nn.Conv2d(128, 128, 3, stride=1, padding=0)
        self.conv16_bn = nn.BatchNorm2d(128)

        self.upsampconv4 = nn.ConvTranspose2d(128, 64, 2, stride=2, padding=0)

        self.conv17 = nn.Conv2d(128, 64, 3, stride=1, padding=0)
        self.conv17_bn = nn.BatchNorm2d(64)
        self.conv18 = nn.Conv2d(64, 64, 3, stride=1, padding=0)
        self.conv18_bn = nn.BatchNorm2d(64)

        # 1x1 convolution mapping the 64 feature channels to the 2 classes
        self.conv19 = nn.Conv2d(64, 2, 1, stride=1, padding=0)
        self.conv19_bn = nn.BatchNorm2d(2)
        self.softmax = nn.Softmax2d()

        # Weight initialization for the Conv2d layers only; BatchNorm and
        # ConvTranspose2d layers keep their PyTorch defaults.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He initialization (https://arxiv.org/abs/1502.01852):
                # a rectifying linear unit is zero for half of its input, so
                # the weight variance is doubled to keep the signal variance
                # constant: std = sqrt(2 / fan_in), fan_in = kh * kw * in_channels.
                # http://andyljones.tumblr.com/post/110998971763/an-explanation-of-xavier-initialization
                kernel_h, kernel_w = m.kernel_size
                std = math.sqrt(2.0 / (kernel_h * kernel_w * m.in_channels))
                nn.init.normal_(m.weight, std=std)
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def _conv_bn_relu(x, conv, bn, pad):
        """Apply optional reflect padding, then conv -> batch norm -> ReLU."""
        if pad is not None:
            # https://pytorch.org/docs/master/nn.html#torch.nn.functional.pad
            x = F.pad(x, pad, 'reflect')
        return F.relu(bn(conv(x)))

    def forward(self, x, padding=False):
        """Run the network on *x*.

        Args:
            x: input batch with one channel (``conv1`` has ``in_channels=1``).
            padding: when true, one pixel is reflect-padded on every side
                before each 3x3 convolution so that the convolutions keep the
                spatial size; when false each 3x3 convolution shrinks the
                feature map by 2 pixels per dimension.

        Returns:
            Tensor with 2 channels holding the ``Softmax2d`` output.
        """
        pad = (1, 1, 1, 1) if padding else None
        block = self._conv_bn_relu

        x = block(x, self.conv1, self.conv1_bn, pad)
        x = block(x, self.conv2, self.conv2_bn, pad)
        # save result for combination later
        x_copy1_2 = x
        x = self.maxPool1(x)

        x = block(x, self.conv3, self.conv3_bn, pad)
        x = block(x, self.conv4, self.conv4_bn, pad)
        x_copy3_4 = x
        x = self.maxPool2(x)

        x = block(x, self.conv5, self.conv5_bn, pad)
        x = block(x, self.conv6, self.conv6_bn, pad)
        x_copy5_6 = x
        x = self.maxPool3(x)

        x = block(x, self.conv7, self.conv7_bn, pad)
        x = block(x, self.conv8, self.conv8_bn, pad)
        # F.dropout defaults to training=True; pass the module state so that
        # dropout is switched off after model.eval().
        x = F.dropout(x, 0.5, training=self.training)
        x_copy7_8 = x
        x = self.maxPool4(x)

        x = block(x, self.conv9, self.conv9_bn, pad)
        x = block(x, self.conv10, self.conv10_bn, pad)
        x = F.dropout(x, 0.5, training=self.training)
        x = F.relu(self.upsampconv1(x))

        x = self.crop_and_concat(x, x_copy7_8)

        x = block(x, self.conv11, self.conv11_bn, pad)
        x = block(x, self.conv12, self.conv12_bn, pad)

        x = F.relu(self.upsampconv2(x))

        x = self.crop_and_concat(x, x_copy5_6)

        x = block(x, self.conv13, self.conv13_bn, pad)
        x = block(x, self.conv14, self.conv14_bn, pad)

        x = F.relu(self.upsampconv3(x))

        x = self.crop_and_concat(x, x_copy3_4)

        x = block(x, self.conv15, self.conv15_bn, pad)
        x = block(x, self.conv16, self.conv16_bn, pad)

        x = F.relu(self.upsampconv4(x))

        x = self.crop_and_concat(x, x_copy1_2)

        x = block(x, self.conv17, self.conv17_bn, pad)
        x = block(x, self.conv18, self.conv18_bn, pad)

        # NOTE(review): the ReLU in front of the softmax clamps negative
        # scores to 0; kept as is so that existing checkpoints behave the same.
        x = F.relu(self.conv19_bn(self.conv19(x)))

        return self.softmax(x)

    # when no padding is used, the upsampled image gets smaller
    # to copy a bigger image to the corresponding layer, it needs to be cropped
    def crop_and_concat(self, upsampled, bypass):
        """Center-crop *bypass* to the size of *upsampled* and concatenate.

        The crop margin is computed from dimension 2 only and applied to both
        spatial dimensions, i.e. the feature maps are assumed to be square.
        The result has the bypass channels first, then the upsampled ones.
        """
        # // is integer division
        c = (bypass.size()[2] - upsampled.size()[2]) // 2
        d = c
        # If bypass has an odd size (e.g. for a 512 input without padding
        # x_copy5_6 is 121x121), one more row and column has to be cut.
        if (bypass.size()[2] & 1) == 1:
            d = c + 1
        # negative padding crops: (left, right, top, bottom)
        bypass = F.pad(bypass, (-c, -d, -c, -d))
        # For precise localization the contracting-path features (after copy
        # and crop) are combined with the upsampled output.
        return torch.cat((bypass, upsampled), 1)

四 训练

import model
import model_bn
import torch.utils.data.dataloader as dl
import ISBI2012Data as ISBI
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import os
import numpy as np
from PIL import Image
import random

import time
import shutil
import argparse


# Writes a checkpoint; the best one is additionally copied to a file whose
# name carries the learning rate and the weight decay.
def save_checkpoint(state, is_best, filename='checkpoint.pth'):
    """Serialize *state* to *filename* and keep a copy if it is the best so far.

    Args:
        state: object handed to ``torch.save``.
        is_best: when true, *filename* is copied to
            ``model_best_lr_<lr>_wd_<weight_decay>.pth.tar``.
        filename: target path of the checkpoint.

    Reads the module-level ``args`` for ``lr`` and ``weight_decay``.
    """
    torch.save(state, filename)
    if not is_best:
        return
    best_name = ''.join(['model_best_lr_', str(args.lr), "_wd_", str(args.weight_decay), '.pth.tar'])
    shutil.copyfile(filename, best_name)


def train(trainloader, model, criterion, optimizer, epoch):
    """Train *model* for one epoch and return the mean loss over all batches.

    Args:
        trainloader: iterable yielding ``(image, label)`` batches.
        model: network called as ``model(image, padding=args.pad)``.
        criterion: loss applied to ``torch.log(outputs)`` and the label.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch, only used in the names of saved images.

    Returns:
        Sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: if *trainloader* yields no batch.

    Reads the module-level ``args`` (``pad``, ``save_images``, ``lr``) and
    ``device``.
    """
    model.train()
    loss_sum = 0
    num_batches = 0
    for i, (images, label) in enumerate(trainloader):
        # put on gpu or cpu; label is a long tensor
        images = images.to(device)
        label = label.to(device)

        # The loss expects
        #   outputs: (minibatch, classes, dim 1, dim 2) -- channels are classes
        #   label  : (minibatch, dim 1, dim 2)
        # The label arrives as (minibatch, 1, width, height), so its channel
        # dimension is dropped here.
        label = label.view(label.size(0), label.size(2), label.size(3))

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        # with padding, one row and one column is added at left, right, top
        # and bottom before each convolution to maintain the image size
        outputs = model(images, padding=args.pad)
        # the model outputs probabilities; the log is needed for the NLL loss
        loss = criterion(torch.log(outputs), label)
        loss.backward()

        optimizer.step()

        loss_sum = loss_sum + loss.item()
        num_batches += 1

        # save the first minibatch image in each epoch;
        # to save every image, remove the "and i == 0" part
        if args.save_images and i == 0:
            save_images(outputs, './val/' + str(args.lr) + '/', epoch, index=i)

        # drop all references to the tensors of this batch:
        # https://discuss.pytorch.org/t/tensor-to-vari
        del outputs, images, label, loss

    if num_batches == 0:
        raise ValueError('trainloader yielded no batches')

    return loss_sum / num_batches


def eval(valloader, model, criterion, save_image):
    """Evaluate *model* on *valloader* and return the mean loss over all batches.

    Args:
        valloader: iterable yielding ``(image, label)`` batches.
        model: network called as ``model(image, padding=args.pad)``.
        criterion: loss applied to ``torch.log(outputs)`` and the label.
        save_image: when true, the output of every batch is passed to
            ``save_images`` and written below ``./eval/``.

    Returns:
        Sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: if *valloader* yields no batch.

    Reads the module-level ``args.pad`` and ``device``.
    """
    # switch the model to eval mode ( important for dropout layers or batchnorm layers )
    model.eval()
    loss_sum = 0
    num_batches = 0
    # no gradients are needed for validation; this avoids building the graph
    with torch.no_grad():
        for i, (val, label) in enumerate(valloader):
            # put on gpu or cpu; label is a long tensor
            val = val.to(device)
            label = label.to(device)

            # The loss expects
            #   outputs: (minibatch, classes, dim 1, dim 2) -- channels are classes
            #   label  : (minibatch, dim 1, dim 2)
            # The label arrives as (minibatch, 1, width, height), so its
            # channel dimension is dropped here.
            label = label.view(label.size(0), label.size(2), label.size(3))

            # forward only
            outputs = model(val, padding=args.pad)
            loss = criterion(torch.log(outputs), label)
            loss_sum = loss_sum + loss.item()
            num_batches += 1

            if save_image:
                # save_image is only the flag; the writer is save_images
                save_images(outputs, './eval/', 'eval', index=i)

            del outputs, val, label, loss

    if num_batches == 0:
        raise ValueError('valloader yielded no batches')

    return loss_sum / num_batches


def save_images(outputs, directory, epoch, index):
    """Save the two class channels of the first sample in *outputs* as JPEGs.

    The channels are scaled by 255, cast to uint8 and written to
    ``<directory>class1_<epoch>_image_<index>.jpg`` and
    ``<directory>class2_<epoch>_image_<index>.jpg``.  *directory* is created
    if it does not exist and is used as a plain string prefix.
    """
    # copy the first sample back to the cpu, one array per class channel
    first_sample = outputs[0]
    channels = [first_sample[k].cpu().detach().numpy() for k in (0, 1)]

    # convert to 8-bit so the images can be saved properly
    class1_arr, class2_arr = [(arr * 255).astype(np.uint8) for arr in channels]

    class1_img = Image.fromarray(class1_arr)
    if not os.path.exists(directory):
        os.makedirs(directory)

    name_tail = '_' + str(epoch) + '_image_' + str(index) + '.jpg'
    class1_img.save(directory + 'class1' + name_tail)
    Image.fromarray(class2_arr).save(directory + 'class2' + name_tail)


# Parameters can be set at command-line

parser = argparse.ArgumentParser()

# dataset selection (positional)
parser.add_argument('data', metavar='dataset', choices=['ISBI2012', 'CTC2015'],
                    help='ISBI2012 or CTC2015')

# training hyper-parameters
parser.add_argument('-mbs', '--mini-batch-size', dest='minibatchsize', type=int,
                    metavar='N', default=1, help='mini batch size (default: 1). '
                                                 'For 8k memory on gpu, minibatchsize of 2-3 possible')
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                    help='number of data loading workers (default: 2)')
parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run (default: 100)')
parser.add_argument('-lr', default=0.001, type=float,
                    metavar='LR', help='initial learning rate (default: 0.001)')
parser.add_argument('--momentum', default=0.99, type=float, metavar='M',
                    help='momentum (default: 0.99)')
parser.add_argument('-es', '--epochsave', default=1, type=int, metavar='M',
                    help='save model every M epoch (default: 1)')
parser.add_argument('--weight-decay', '-wd', default=0, type=float,
                    metavar='W', help='weight decay (L2 penalty) (default: 0)')

# checkpointing / run mode
parser.add_argument('-r', '--resume', default='', type=str, metavar='PATH',
                    help='relative path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')

# switches
parser.add_argument('-s', '--save-images', action='store_true',
                    help='save the first image of output each epoche')
parser.add_argument('-c', '--cpu', action='store_true',
                    help='use cpu instead of gpu')
parser.add_argument('-p', '--pad', action='store_true',
                    help='use padding at each 3x3 convolution to maintain image size')
parser.add_argument('-txt', action='store_true',
                    help='save console output in txt')
parser.add_argument('-bn', action='store_true',
                    help='use u-net with batchnorm layers added after each convolution')


# Parse the command line. start_epoch and best_loss are the values used for a
# fresh run; both are replaced when a checkpoint is resumed.
args = parser.parse_args()
args.start_epoch = 0
best_loss = 10

# use the same seed for torch and python's random for testing purposes
for seed_fn in (torch.manual_seed, random.seed):
    seed_fn(999)

print("***** Starting Programm *****")


# 1: design model

# cuda is used when it is available and the cpu was not requested explicitly
cuda_available = torch.cuda.is_available()
device = torch.device("cuda:0" if (cuda_available and not args.cpu) else "cpu")

# -bn selects the u-net from model_bn, otherwise the one from model is used;
# .to(device) sends the parameters to the given device ( cuda or cpu ).
# Note that the name `model` refers to the network instance from here on.
unet_source = model_bn if args.bn else model
model = unet_source.Unet().to(device)

# use cudnn for better speed, if available
if device.type == "cuda":
    cudnn.benchmark = True

# 2: Construct loss and optimizer

# The model ends with a softmax layer; taking the log of its output and
# applying NLLLoss() gives the same loss as CrossEntropyLoss() on a model
# without the softmax layer. The difference is in the output image of the model.
# To use CrossEntropyLoss() instead, remove the softmax layer and the
# torch.log() at the loss.
criterion = nn.NLLLoss().to(device)
optimizer = optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)

# A scheduler reducing the learning rate on a plateau could be added here;
# it would have to be stepped once per epoch as well:
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True)

# Load the ISBI 2012 training data.
# The CTC2015 dataset loader is not finished yet, so both choices currently
# read the ISBI 2012 folders.
# The length of each dataset has to be set by hand.
# signature: gloob_dir_train, gloob_dir_label, length, is_pad, eval, totensor
if args.data in ("ISBI2012", "CTC2015"):
    shared_opts = dict(is_pad=args.pad, totensor=True)

    trainset = ISBI.ISBIDataset(
        "./ISBI 2012/Train-Volume/train-volume-*.tif",
        "./ISBI 2012/Train-Labels/train-labels-*.tif",
        length=24, eval=False, **shared_opts)

    valset = ISBI.ISBIDataset(
        "./ISBI 2012/Val-Volume/train-volume-*.tif",
        "./ISBI 2012/Val-Labels/train-labels-*.tif",
        length=6, eval=True, **shared_opts)

# num_workers can represent the number of cores in the cpu; pinned memory is
# page-locked memory, disable it if the system freezes or swap is used a lot:
# https://discuss.pytorch.org/t/what-is-the-disadvantage-of-using-pin-memory/1702
loader_opts = dict(num_workers=args.workers, pin_memory=True)

trainloader = dl.DataLoader(trainset, batch_size=args.minibatchsize, **loader_opts)
# batch size is 1 for validation, to get a single output for loss and not a mean
valloader = dl.DataLoader(valset, batch_size=1, **loader_opts)

# 3: Training cycle forward, backward , update

# Restore model, optimizer, start epoch and best loss from a checkpoint when
# a path was given with -r/--resume.
if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        # map_location moves the stored tensors to the selected device, so a
        # checkpoint written on a gpu can also be resumed with --cpu
        checkpoint = torch.load(args.resume, map_location=device)
        args.start_epoch = checkpoint['epoch']
        best_loss = checkpoint['best_loss']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # scheduler.last_epoch = args.start_epoch
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
    else:
        # a missing file is only reported; training starts from scratch
        print("=> no checkpoint found at '{}'".format(args.resume))

# Summary of the run settings, shown on the console and, with -txt, appended
# to the info file of this learning rate / weight decay combination.
run_settings = [
    ('Dataset      : ', args.data),
    ('Start Epoch  : ', args.start_epoch),
    ('End Epoch    : ', args.epochs),
    ('Learning rate: ', args.lr),
    ('Momentum     : ', args.momentum),
    ('Weight decay : ', args.weight_decay),
    ('Use padding  : ', args.pad),
]
info_lines = [caption + str(value) for caption, value in run_settings]

# print some info for console
for info_line in info_lines:
    print(info_line)

#  save a txt file with the console info
if args.txt:
    info_path = "Info_lr_" + str(args.lr) + "_wd_" + str(args.weight_decay) + ".txt"
    # the with-block closes the file on exit
    with open(info_path, "a") as myfile:
        for info_line in info_lines:
            myfile.write(info_line)
            myfile.write('\n')

# Either evaluate once on the validation set (-e) or run the training loop.
if args.evaluate:
    avg_val_loss = eval(valloader, model, criterion, True)
    print(" avg loss: " + str(avg_val_loss))
else:
    for epoch in range(args.start_epoch, args.epochs):
        # one epoch = one pass over the training set plus one validation pass
        start_time = time.time()
        train_loss = train(trainloader, model, criterion, optimizer, epoch)
        val_loss = eval(valloader, model, criterion, False)
        elapsed = time.time() - start_time

        epoch_stats = (epoch + 1, train_loss, val_loss, elapsed)
        print('Epoch [%5d] train_loss: %.4f val_loss: %.4f loop time: %.5f' % epoch_stats)
        if args.txt:
            log_path = "Info_lr_" + str(args.lr) + "_wd_" + str(args.weight_decay) + ".txt"
            # the line in the file is spelled "Epoche", unlike the console line
            with open(log_path, "a") as myfile:
                myfile.write('Epoche [%5d] train_loss: %.4f val_loss: %.4f loop time: %.5f' % epoch_stats)
                myfile.write('\n')

        # see info at criterion above
        # scheduler.step(val_loss)

        # Data augmentation for the next epoch:
        # 50% chance for each flip, same for the whole epoch, re-drawn every
        # epoch; the first epoch runs without flipping and rotation.
        train_data = trainloader.dataset
        train_data.rand_vflip = random.random() < 0.5
        train_data.rand_hflip = random.random() < 0.5
        # rotation is currently disabled:
        #trainloader.dataset.rand_rotate = random.random() < 0.5
        #trainloader.dataset.angle = random.uniform(-180, 180)

        # track the best validation loss
        is_best_loss = val_loss < best_loss
        best_loss = min(val_loss, best_loss)

        # save the model every args.epochsave epochs
        if (epoch + 1) % args.epochsave != 0:
            continue
        checkpoint_name = 'checkpoint.' + str(args.lr) + "wd" + str(args.weight_decay) + '.pth.tar'
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_loss': best_loss,
            'optimizer': optimizer.state_dict(),
        }, is_best_loss, filename=checkpoint_name)
print("*****   End  Programm   *****")

How to use:

First download the ISBI 2012 dataset into your folder, then run ISBI_split.py, which generates the folder structure.
For all options run:
main.py -h

for a simple start use:
main.py ISBI2012 -s -p -txt

-s saves the first output image of each epoch
-p uses padding to stop the reduction of image size caused by 3x3 convolutions
-txt saves information about the used settings and the losses of each epoch

the txt file looks for example like this:
Dataset : ISBI2012
Start Epoch : 0
End Epoch : 100
Learning rate: 0.001
Momentum : 0.99
Weight decay : 0
Use padding : True
Epoche [ 1] train_loss: 0.4911 val_loss: 0.4643 loop time: 9.96429
Epoche [ 2] train_loss: 0.4630 val_loss: 0.5017 loop time: 5.41091
Epoche [ 3] train_loss: 0.4460 val_loss: 0.4637 loop time: 5.45516




你可能感兴趣的:(深度学习)