PyTorch 项目 Student-Teacher anomaly detection:修改训练方式,使大图可以先划分成块再载入

项目代码:

https://github.com/denguir/student-teacher-anomaly-detection

其实也可以先在大图上随机裁剪(crop)出一块区域,然后再从该区域中裁剪 65×65 的 patch

但是这里我们采用的做法是:将大图按 2×2 划分成 4 块区域,再分别送入网络

# The data directory is looked up relative to the parent of the current
# working directory.
before_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
print("before_dir", before_dir)

# Augmentation applied to every tile the dataset produces: random flips,
# tensor conversion, then per-channel normalisation with mean 0.5 / std 0.5.
# Grayscale conversion and an explicit Resize((imH, imW)) were tried here
# previously and are intentionally left out of the pipeline.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Keep only the CSV rows whose `type` column is 'train' and whose `label`
# column is 0.
dataset = AnomalyDataset(
    csv_file=os.path.join(before_dir, 'data/{}/{}.csv'.format(DATASET, DATASET)),
    root_dir=os.path.join(before_dir, 'data/{}/img'.format(DATASET)),
    transform=train_transform,
    type='train',
    label=0,
)
import os
import numpy as np
import pandas as pd
import torch
from PIL import Image
from einops import rearrange
from torchvision import transforms, utils
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
import cv2

def cut_image(image, rows=2, cols=2):
    '''Split an image into a ``rows`` x ``cols`` grid of equally sized tiles.

    Args:
        image: Object exposing a ``size`` attribute of ``(width, height)`` and
            a ``crop(box)`` method that accepts a
            ``(left, upper, right, lower)`` tuple, e.g. a PIL image.
        rows: Number of tiles along the vertical axis. Defaults to 2.
        cols: Number of tiles along the horizontal axis. Defaults to 2.

    Returns:
        A list of ``rows * cols`` crops in row-major order (left to right
        within a row, rows from top to bottom). With the defaults this is
        top-left, top-right, bottom-left, bottom-right.

    Raises:
        ValueError: If ``rows`` or ``cols`` is smaller than 1.

    Note:
        Tile dimensions are computed with floor division so that every tile
        has the same size. When the width or height is not a multiple of the
        grid size, the trailing pixels on the right/bottom edge are left out.
    '''
    if rows < 1 or cols < 1:
        raise ValueError(
            'rows and cols must be >= 1, got rows={}, cols={}'.format(rows, cols))

    width, height = image.size
    tile_width = width // cols
    tile_height = height // rows

    # Boxes follow the (left, upper, right, lower) convention.
    return [
        image.crop((col * tile_width,
                    row * tile_height,
                    (col + 1) * tile_width,
                    (row + 1) * tile_height))
        for row in range(rows)
        for col in range(cols)
    ]

class AnomalyDataset(Dataset):
    '''Anomaly detection dataset that serves each image as a list of tiles.

    Rows are read from a CSV file and filtered by the keyword constraints
    given to the constructor. Every item is loaded from disk, resized to a
    fixed 576 x 768 (height x width) and split with ``cut_image``; the
    transform is then applied to each tile independently.
    '''

    def __init__(self, csv_file, root_dir, transform=None, **constraint):
        '''Build the dataset index.

        Args:
            csv_file: Path of the CSV file describing the dataset. It must
                provide the ``image_name`` and ``label`` columns plus every
                column used as a constraint.
            root_dir: Directory that ``image_name`` entries are relative to.
            transform: Optional callable applied to every tile. When it is
                None the tiles are returned as PIL images.
            **constraint: ``column=value`` pairs; only rows matching all of
                them are kept.
        '''
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.frame_list = self._get_dataset(csv_file, constraint)

        # Fixed working resolution applied to the full image before tiling.
        imH = 576
        imW = 768
        self.resize = transforms.Compose([transforms.Resize((imH, imW))])

    def _get_dataset(self, csv_file, constraint):
        '''Return the CSV rows whose columns equal every value in ``constraint``.'''
        df = pd.read_csv(csv_file)
        # Compare the constrained columns against the requested values and
        # keep the rows where all of them match.
        df = df.loc[(df[list(constraint)] == pd.Series(constraint)).all(axis=1)]
        return df

    def __len__(self):
        '''Return the number of rows left after filtering.'''
        return len(self.frame_list)

    def __getitem__(self, idx):
        '''Load one image and return its tiles.

        Returns:
            dict: ``{'image': [tile, ...], 'label': [label, ...]}`` with one
            entry per tile produced by ``cut_image``; the row's label is
            repeated for every tile.

        Raises:
            OSError: If OpenCV cannot read the image file.
        '''
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.frame_list.iloc[idx]
        img_name = os.path.join(self.root_dir, row['image_name'])
        label = row['label']

        # -1 == cv2.IMREAD_UNCHANGED: keep the file's own channel layout.
        image_array = cv2.imread(img_name, -1)
        if image_array is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(
                'Image is missing or unreadable: {}'.format(img_name))

        # NOTE: values outside 0-255 (e.g. 16-bit images) wrap on this cast.
        image_array = image_array.astype('uint8')

        # OpenCV stores colour images as BGR(A) while PIL expects RGB(A), so
        # the channels have to be reordered before handing the array to PIL.
        if image_array.ndim == 3:
            conversion = {
                3: cv2.COLOR_BGR2RGB,
                4: cv2.COLOR_BGRA2RGB,
            }.get(image_array.shape[2])
            if conversion is not None:
                image_array = cv2.cvtColor(image_array, conversion)

        image = Image.fromarray(image_array).convert('RGB')
        image = self.resize(image)

        tiles = cut_image(image)
        if self.transform:
            tiles = [self.transform(tile) for tile in tiles]

        return {'image': tiles, 'label': [label] * len(tiles)}


if __name__ == '__main__':
    # Smoke test: load a few batches and display one randomly chosen tile.
    import matplotlib.pyplot as plt

    DATASET = "mydata"
    dataset = AnomalyDataset(csv_file=f'../data/{DATASET}/{DATASET}.csv',
                                   root_dir=f'../data/{DATASET}/img',
                                   transform=transforms.Compose([
                                       #transforms.Grayscale(num_output_channels=3),
                                       transforms.Resize((256, 256)),
                                       transforms.RandomCrop((256, 256)),
                                       transforms.ToTensor()]),
                                    type='train',
                                    label=0)

    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

    for i, batch in enumerate(dataloader):
        # The dataset returns per-tile lists, so after collation both
        # batch['image'] and batch['label'] are lists with one batched
        # tensor per tile rather than single tensors.
        print(i,
              [tile.size() for tile in batch['image']],
              [tile_labels.size() for tile_labels in batch['label']])
        # display the batch with index 3 (i.e. the fourth one)
        if i == 3:
            # Pick a random tile position first, then a random sample in it.
            tile_idx = np.random.randint(0, len(batch['image']))
            n = np.random.randint(0, len(batch['label'][tile_idx]))

            image = rearrange(batch['image'][tile_idx][n, :, :, :], 'c h w -> h w c')
            label = batch['label'][tile_idx][n]

            plt.title(f"Sample #{n} - {'Anomalous' if label else 'Normal'}")
            plt.imshow(image)
            plt.show()
            break
            for i, batch in tqdm(enumerate(dataloader)):
                # batch['image'] holds one entry per image tile; every tile
                # gets its own forward/backward pass and optimizer step.
                for tile in batch['image']:

                    inputs = tile.to(device)

                    # zero the parameters gradient before EVERY step:
                    # backward() adds to the existing gradients, so resetting
                    # only once per batch would make each later tile step on
                    # the gradients accumulated from the previous tiles.
                    optimizers[j].zero_grad()

                    # forward pass; the normalised teacher output is the
                    # regression target and needs no gradient.
                    with torch.no_grad():
                        targets = (teacher(inputs) - t_mu) / torch.sqrt(t_var)
                    outputs = student(inputs)
                    loss = student_loss(targets, outputs)

                    # backward pass
                    loss.backward()
                    optimizers[j].step()
                    running_loss += loss.item()

 

你可能感兴趣的:(深度学习)