使用vit预训练遥感数据得到分类模型

train.py

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from torch import optim
import os
import csv
from PIL import Image
import warnings
warnings.simplefilter('ignore')
from torchvision.models import resnet18
import glob
import random
import newmodel

class Pokemon(Dataset):
    def __init__(self, root, resize, mode):  # root是文件路径,resize是对原始图片进行裁剪,mode是选择模式(train、test、validation)
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        self.name2label = {}  # 给每个种类分配一个数字,以该数字作为这一类别的label
        # name是宝可梦的种类,e.g:pikachu
        for name in sorted(os.listdir(os.path.join(self.root))):  # listdir返回的顺序不固定,加上一个sorted使每一次的顺序都一样
            if not os.path.isdir(os.path.join(self.root, name)):  # os.path.isdir()用于判断括号中的内容是否是一个未压缩的文件夹
                continue
            self.name2label[name] = len(self.name2label.keys())#将所有的类和序号放在一个字典中

        print(self.name2label)

        self.images, self.labels = self.load_csv('images&labels.csv')#读图片路径和标签
        # 将全部数据分成train、validation、test
        if mode == 'train':  # 前60%作为训练集
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'val':  # 60%~80%作为validation
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # 后20%作为test set
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]

    def load_csv(self, filename):
        # 载入原始图片的路径,并保存到指定的CSV文件中,然后从该CSV文件中再次读入所有图片的存储路径和label。
        # 如果CSV文件已经存在,则直接读入该CSV文件的内容
        # 为什么保存的是图片的路径而不是图片?因为直接保存图片可能会造成内存爆炸

        if not os.path.exists(os.path.join(self.root, filename)):  # 如果filename这个文件不存在,那么执行以下代码,创建file
            images = []
            for name in self.name2label.keys():
                # glob.glob()返回的是括号中的路径中的所有文件的路径
                # += 是把glob.glob()返回的结果依次append到image中,而不是以一个整体append
                # 这里只用了png/jpg/jepg是因为本次实验的图片只有这三种格式,如果有其他格式请自行添加

                images += glob.glob(os.path.join(self.root, name, '*.jpg'))

            print(len(images))
            random.shuffle(images)  # 把所有图片路径顺序打乱
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:  # 将图片路径及其对应的数字标签写到指定文件中
                writer = csv.writer(f)
                for img in images:  # img e.g:'./pokemon/pikachu\\00000001.png'
                    name = img.split(os.sep)[-2]  # 即取出‘pikachu’
                    label = self.name2label[name]  # 根据name找到对应的数字标签
                    writer.writerow([img, label])  # 把每张图片的路径和它对应的数字标签写到指定的CSV文件中
                print('image paths and labels have been writen into csv file:', filename)

        # 把数据读入(如果filename存在就直接执行这一步,如果不存在就先创建file再读入数据)
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                label = int(label)

                images.append(img)
                labels.append(label)


        assert len(images) == len(labels)  # 确保它们长度一致

        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img, label = self.images[idx], self.labels[idx]  # 此时img还是路径字符串,要把它转化成tensor
        # 将图片resize成224*224,并转化成tensor,这个tensor的size是3*224*224(3是因为有RGB3个通道)
        trans = transforms.Compose((
            lambda x: Image.open(x).convert('RGB'),
            transforms.Resize((self.resize, self.resize)),  # 必须要把长宽都一起写上啊!!!
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        # 这个数据是根据resnet中的图片统计得到的,直接拿来用就好
        ))
        img = trans(img)
        label = torch.tensor(label)
        return img, label

batch_size = 32
lr = 1e-3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)

train_db = Pokemon('pokemon',224,'train') #将所有图片(顺序已打乱)的前60%作为train_set
val_db = Pokemon('pokemon',224,'val')  #60%~80%作为validation_set
test_db = Pokemon('pokemon',224,'test') #80%~100%作为test_set
train_loader = DataLoader(train_db,batch_size=batch_size,shuffle=True) #之后调用一次train_loader就会把train_db划分成很多batch
val_loader = DataLoader(val_db,batch_size=batch_size,shuffle=True)
test_loader = DataLoader(test_db,batch_size=batch_size,shuffle=True)

'''
class Flatten(nn.Module):
    def __init__(self):
        super(Flatten,self).__init__()
    def forward(self,x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.reshape(-1,shape)
'''

#初始化模型
#trained_model = resnet18(pretrained = True) #拿到已经训练好的resnet18模型
#model = nn.Sequential(*list(trained_model.children())[:-1], #拿出resnet18的前面17层,输出的size是b*512*1*1
#                      Flatten(), #经过flatten之后的size是b*512
#                      nn.Linear(512,33)).to(device)

model=newmodel.Model()


print('模型需要训练的参数共有{}个'.format(sum(map(lambda p:p.numel(),model.parameters()))))

loss_fn = nn.CrossEntropyLoss() #选择loss_function

optimizer = optim.Adam(model.parameters(),lr=lr) #选择优化方式


def evaluate(model, loader):
    correct_num = 0
    total_num = len(loader.dataset)
    for img, label in loader:  # lodaer中包含了很多batch,每个batch有32张图片
        img, label = img.to(device), label.to(device)
        with torch.no_grad():
            logits = model(img)
            pre_label = logits.argmax(dim=1)
        correct_num += torch.eq(pre_label, label).sum().float().item()

    return correct_num / total_num


# 开始训练
best_epoch, best_acc = 0, 0
for epoch in range(200):  # 时间关系,我们只训练10个epoch
    for batch_num, (img, label) in enumerate(train_loader):
        # img.size [b,3,224,224]  label.size [b]
        img, label = img.to(device), label.to(device)
        logits = model(img)
        loss = loss_fn(logits, label)
        if batch_num % 5 == 0:
            print('这是第{}次迭代的第{}个batch,loss是{}'.format(epoch + 1, batch_num + 1, loss.item()))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    val_acc = evaluate(model, val_loader)
    # 如果val_acc比之前的好,那么就把该epoch保存下来,并把此时模型的参数保存到指定txt文件里
    if val_acc > best_acc:
        print('验证集上的准确率是:{}'.format(val_acc))
        best_epoch = epoch
        best_acc = val_acc
        torch.save(model.state_dict(), 'best.pth')

print('best_acc:{},best_epoch:{}'.format(best_acc, best_epoch))
'''
model.load_state_dict(torch.load('best.pth'))
# 开始检验
print('模型训练完毕,已将参数设置成训练过程中的最优值,现在开始测试test_set')
test_acc = evaluate(model, test_loader)
print('测试集上的准确率是:{}'.format(test_acc))

'''

test.py

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from torch import optim
import os
import csv
from PIL import Image
import warnings
warnings.simplefilter('ignore')
from torchvision.models import resnet18
import glob
import random
import newmodel

class Pokemon(Dataset):
    def __init__(self, root, resize, mode):  # root是文件路径,resize是对原始图片进行裁剪,mode是选择模式(train、test、validation)
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        self.name2label = {}  # 给每个种类分配一个数字,以该数字作为这一类别的label
        # name是宝可梦的种类,e.g:pikachu
        for name in sorted(os.listdir(os.path.join(self.root))):  # listdir返回的顺序不固定,加上一个sorted使每一次的顺序都一样
            if not os.path.isdir(os.path.join(self.root, name)):  # os.path.isdir()用于判断括号中的内容是否是一个未压缩的文件夹
                continue
            self.name2label[name] = len(self.name2label.keys())#将所有的类和序号放在一个字典中

        print(self.name2label)

        self.images, self.labels = self.load_csv('images&labels.csv')#读图片路径和标签
        # 将全部数据分成train、validation、test
        if mode == 'train':  # 前60%作为训练集
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'val':  # 60%~80%作为validation
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # 后20%作为test set
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]

    def load_csv(self, filename):
        # 载入原始图片的路径,并保存到指定的CSV文件中,然后从该CSV文件中再次读入所有图片的存储路径和label。
        # 如果CSV文件已经存在,则直接读入该CSV文件的内容
        # 为什么保存的是图片的路径而不是图片?因为直接保存图片可能会造成内存爆炸

        if not os.path.exists(os.path.join(self.root, filename)):  # 如果filename这个文件不存在,那么执行以下代码,创建file
            images = []
            for name in self.name2label.keys():
                # glob.glob()返回的是括号中的路径中的所有文件的路径
                # += 是把glob.glob()返回的结果依次append到image中,而不是以一个整体append
                # 这里只用了png/jpg/jepg是因为本次实验的图片只有这三种格式,如果有其他格式请自行添加

                images += glob.glob(os.path.join(self.root, name, '*.jpg'))

            print(len(images))
            random.shuffle(images)  # 把所有图片路径顺序打乱
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:  # 将图片路径及其对应的数字标签写到指定文件中
                writer = csv.writer(f)
                for img in images:  # img e.g:'./pokemon/pikachu\\00000001.png'
                    name = img.split(os.sep)[-2]  # 即取出‘pikachu’
                    label = self.name2label[name]  # 根据name找到对应的数字标签
                    writer.writerow([img, label])  # 把每张图片的路径和它对应的数字标签写到指定的CSV文件中
                print('image paths and labels have been writen into csv file:', filename)

        # 把数据读入(如果filename存在就直接执行这一步,如果不存在就先创建file再读入数据)
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                label = int(label)

                images.append(img)
                labels.append(label)


        assert len(images) == len(labels)  # 确保它们长度一致

        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img, label = self.images[idx], self.labels[idx]  # 此时img还是路径字符串,要把它转化成tensor
        # 将图片resize成224*224,并转化成tensor,这个tensor的size是3*224*224(3是因为有RGB3个通道)
        trans = transforms.Compose((
            lambda x: Image.open(x).convert('RGB'),
            transforms.Resize((self.resize, self.resize)),  # 必须要把长宽都一起写上啊!!!
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        # 这个数据是根据resnet中的图片统计得到的,直接拿来用就好
        ))
        img = trans(img)
        label = torch.tensor(label)
        return img, label
def evaluate(model, loader):
    correct_num = 0
    total_num = len(loader.dataset)
    for img, label in loader:  # lodaer中包含了很多batch,每个batch有32张图片
        img, label = img.to(device), label.to(device)
        with torch.no_grad():
            logits = model(img)
            pre_label = logits.argmax(dim=1)
        correct_num += torch.eq(pre_label, label).sum().float().item()

    return correct_num / total_num

batch_size = 32
lr = 1e-3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)
# Step 1:准备数据集
#train_db = Pokemon('pokemon',224,'train') #将所有图片(顺序已打乱)的前60%作为train_set
#val_db = Pokemon('pokemon',224,'val')  #60%~80%作为validation_set
test_db = Pokemon('pokemon',224,'test') #80%~100%作为test_set
#train_loader = DataLoader(train_db,batch_size=batch_size,shuffle=True) #之后调用一次train_loader就会把train_db划分成很多batch
#val_loader = DataLoader(val_db,batch_size=batch_size,shuffle=True)
test_loader = DataLoader(test_db,batch_size=batch_size,shuffle=True)
#Step 2: 初始化网络
model=newmodel.Model()

# Step 3:加载训练好的权重
model.load_state_dict(torch.load('best.pth'))
print('模型训练完毕,已将参数设置成训练过程中的最优值,现在开始测试test_set')
test_acc = evaluate(model, test_loader)
print('测试集上的准确率是:{}'.format(test_acc))

newmodel.py

import torch
import torch.nn as nn
from torchvision.models.resnet import resnet18
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.vit = ViT(image_size = 256,patch_size = 16,num_classes = 33,dim = 1024,depth = 6,heads = 16,mlp_dim = 2048,dropout = 0.1,emb_dropout = 0.1)
        #for param in self.resnet.parameters():
            #param.requires_grad = finetune
        #self.linear = nn.Linear(in_features=512, out_features=33)

    def forward(self, x):
        '''x = self.resnet.conv1(img)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)
        f1 = self.resnet.layer1(x)
        f2 = self.resnet.layer2(f1)
        f3 = self.resnet.layer3(f2)
        f4 = self.resnet.layer4(f3)
        result = self.pool(f4)
        batch_size = result.shape[0]
        result=result.reshape(batch_size,512)
        final = self.linear(result)
        '''
        final=self.vit(x)#(b,33)
        return final



# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

#PreNorm是对层进行归一化
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

#FeedForward就是两层线性变换
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

#attention的输入和输出维度相同[1,num_patches+1,128]-->[num_patches+1,128],其目的是赋予不同patch不同的权重;
#给予不同的注意力  dim表示输入的维度,dim_head表示进入qkv每个头的维度,head表示有多少个头
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads#表示一共的维度,dim_head是每个头的维度,heads是一共有多少个头
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)#dim=-1表示取维度最后一层
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    # 获得三个维度相同的向量q,k,v,然后q,k相乘获得权重,乘以scale,再经过softmax之后,乘到v上
    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)#将qkv一起的向量 分成三块分别代表qkv
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

#Transformer就是将降维后的patches叠加上不同的系数(注意力机制),再加上两层线性传输
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        # 这里是对块进行编码,将patch_height*patch_width的大小输出维度变成隐层dim
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):

        x = self.to_patch_embedding(img) # 对图像进行分块和降维编码
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)#加上一个分类维度[1,1,128]叠加到输入上面
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]#加上位置编码
        x = self.dropout(x)

        x = self.transformer(x)#经过注意力机制和两层线性变换

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]#对num_patches+1这个维度求均值,x的维度由[1,num_patches+1,128]-->[1,1,128]

        x = self.to_latent(x)#再经过一层线性变换输出维度到num_classes #128->num_classes
        print(x.shape)#(1,1024)
        solo_feature=self.mlp_head(x)
        print(solo_feature.shape)
        return solo_feature

你可能感兴趣的:(python,vit,python,vit,transformer)