# train.py
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from torch import optim
import os
import csv
from PIL import Image
import warnings
warnings.simplefilter('ignore')
from torchvision.models import resnet18
import glob
import random
import newmodel
class Pokemon(Dataset):
    """Image-folder dataset with a cached CSV index and a 60/20/20 split.

    Expects root/<class_name>/*.jpg; class names map to integer labels in
    sorted-directory order. The first instantiation shuffles the file list and
    caches it to a CSV so later instantiations (and the other splits) see the
    same ordering.
    """

    def __init__(self, root, resize, mode):
        """root: dataset directory; resize: output H=W; mode: 'train'|'val'|'test'."""
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        # Map each class directory to a stable integer label (sorted order).
        self.name2label = {}
        for name in sorted(os.listdir(os.path.join(self.root))):
            if not os.path.isdir(os.path.join(self.root, name)):
                continue
            self.name2label[name] = len(self.name2label.keys())
        print(self.name2label)
        # Built lazily on first __getitem__ (fix: the original rebuilt the
        # whole transforms.Compose pipeline on every single sample).
        self._transform = None
        self.images, self.labels = self.load_csv('images&labels.csv')
        if mode == 'train':  # first 60%
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'val':  # middle 20%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # last 20%; any unrecognized mode falls through to the test split
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]

    def load_csv(self, filename):
        """Return (image_paths, labels); build and cache the CSV on first use."""
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label.keys():
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
            print(len(images))
            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent directory name is the class name.
                    name = img.split(os.sep)[-2]
                    writer.writerow([img, self.name2label[name]])
            print('image paths and labels have been writen into csv file:', filename)
        images, labels = [], []
        with open(csv_path) as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return (image_tensor[3, resize, resize], label_tensor)."""
        img, label = self.images[idx], self.labels[idx]
        if self._transform is None:
            self._transform = transforms.Compose((
                lambda x: Image.open(x).convert('RGB'),
                transforms.Resize((self.resize, self.resize)),
                transforms.ToTensor(),
                # ImageNet normalization statistics.
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ))
        img = self._transform(img)
        return img, torch.tensor(label)
# Training configuration and data pipelines.
batch_size = 32
lr = 1e-3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)
# Fix: Pokemon.load_csv shuffles with the `random` module, which
# torch.manual_seed does not touch — seed it too so the first-run split
# (before the CSV cache exists) is reproducible.
random.seed(1234)
train_db = Pokemon('pokemon', 224, 'train')
val_db = Pokemon('pokemon', 224, 'val')
test_db = Pokemon('pokemon', 224, 'test')
train_loader = DataLoader(train_db, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_db, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_db, batch_size=batch_size, shuffle=True)
# NOTE(review): removed a commented-out Flatten helper that was parked here in
# a bare triple-quoted string; it was dead code (nn.Flatten covers that need).
# Fix: move the model onto `device` — the training loop sends batches to
# `device`, so a CPU-resident model would crash on CUDA machines.
model = newmodel.Model().to(device)
print('模型需要训练的参数共有{}个'.format(sum(map(lambda p: p.numel(), model.parameters()))))
loss_fn = nn.CrossEntropyLoss()  # expects raw logits plus integer class labels
optimizer = optim.Adam(model.parameters(), lr=lr)
def evaluate(model, loader):
    """Return top-1 accuracy of `model` over all batches of `loader`.

    Fixes vs. the original: runs the model in eval mode (the ViT uses dropout,
    which would otherwise perturb predictions) and infers the device from the
    model's own parameters instead of relying on the module-level `device`
    global. The model's previous train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    dev = next(model.parameters()).device
    correct_num = 0
    total_num = len(loader.dataset)
    with torch.no_grad():
        for img, label in loader:
            img, label = img.to(dev), label.to(dev)
            logits = model(img)
            pre_label = logits.argmax(dim=1)
            correct_num += torch.eq(pre_label, label).sum().float().item()
    if was_training:
        model.train()
    return correct_num / total_num
# Train for 200 epochs, checkpointing whenever validation accuracy improves.
best_epoch, best_acc = 0, 0
for epoch in range(200):
    for step, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        batch_loss = loss_fn(out, y)
        # Log every fifth batch.
        if step % 5 == 0:
            print('这是第{}次迭代的第{}个batch,loss是{}'.format(epoch + 1, step + 1, batch_loss.item()))
        batch_loss.backward()
        optimizer.step()
    val_acc = evaluate(model, val_loader)
    if val_acc > best_acc:
        print('验证集上的准确率是:{}'.format(val_acc))
        best_epoch, best_acc = epoch, val_acc
        # Persist the best weights seen so far.
        torch.save(model.state_dict(), 'best.pth')
print('best_acc:{},best_epoch:{}'.format(best_acc, best_epoch))
# NOTE(review): the dead code below was parked in a bare triple-quoted string,
# which still builds a throwaway str constant at runtime; kept as true comments
# instead. The live version of this evaluation is in test.py.
# model.load_state_dict(torch.load('best.pth'))
# # Start testing
# print('模型训练完毕,已将参数设置成训练过程中的最优值,现在开始测试test_set')
# test_acc = evaluate(model, test_loader)
# print('测试集上的准确率是:{}'.format(test_acc))
# test.py
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from torch import optim
import os
import csv
from PIL import Image
import warnings
warnings.simplefilter('ignore')
from torchvision.models import resnet18
import glob
import random
import newmodel
class Pokemon(Dataset):
    """Folder-per-class Pokemon dataset backed by a cached CSV index.

    Layout: root/<class_name>/*.jpg. Labels follow sorted directory order.
    The split is 60% train / 20% val / 20% test over a shuffled, cached list.
    """

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        # Assign an integer label to each class directory, in sorted order.
        self.name2label = {}
        for entry in sorted(os.listdir(os.path.join(self.root))):
            if os.path.isdir(os.path.join(self.root, entry)):
                self.name2label[entry] = len(self.name2label.keys())
        print(self.name2label)
        self.images, self.labels = self.load_csv('images&labels.csv')
        total = len(self.images)
        if mode == 'train':
            lo, hi = 0, int(0.6 * total)
        elif mode == 'val':
            lo, hi = int(0.6 * total), int(0.8 * total)
        else:
            lo, hi = int(0.8 * total), total
        self.images = self.images[lo:hi]
        self.labels = self.labels[lo:hi]

    def load_csv(self, filename):
        """Return parallel lists (paths, int labels), creating the CSV cache if absent."""
        csv_file = os.path.join(self.root, filename)
        if not os.path.exists(csv_file):
            all_paths = []
            for cls in self.name2label.keys():
                all_paths += glob.glob(os.path.join(self.root, cls, '*.jpg'))
            print(len(all_paths))
            random.shuffle(all_paths)
            with open(csv_file, mode='w', newline='') as handle:
                out = csv.writer(handle)
                for path in all_paths:
                    # Class name is the parent directory of the image file.
                    cls = path.split(os.sep)[-2]
                    out.writerow([path, self.name2label[cls]])
            print('image paths and labels have been writen into csv file:', filename)
        paths, labels = [], []
        with open(csv_file) as handle:
            for path, raw in csv.reader(handle):
                paths.append(path)
                labels.append(int(raw))
        assert len(paths) == len(labels)
        return paths, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load, resize, tensorize, and ImageNet-normalize the idx-th image."""
        path, label = self.images[idx], self.labels[idx]
        pipeline = transforms.Compose((
            lambda p: Image.open(p).convert('RGB'),
            transforms.Resize((self.resize, self.resize)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ))
        return pipeline(path), torch.tensor(label)
def evaluate(model, loader):
    """Return top-1 accuracy of `model` over all batches of `loader`.

    Fixes vs. the original (kept identical to the train.py copy): runs the
    model in eval mode (the ViT uses dropout, which would otherwise perturb
    predictions) and infers the device from the model's own parameters instead
    of relying on the module-level `device` global. The model's previous
    train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    dev = next(model.parameters()).device
    correct_num = 0
    total_num = len(loader.dataset)
    with torch.no_grad():
        for img, label in loader:
            img, label = img.to(dev), label.to(dev)
            logits = model(img)
            pre_label = logits.argmax(dim=1)
            correct_num += torch.eq(pre_label, label).sum().float().item()
    if was_training:
        model.train()
    return correct_num / total_num
# Stand-alone test-set evaluation using the checkpoint saved by train.py.
batch_size = 32
lr = 1e-3  # unused here; kept for symmetry with train.py
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)
test_db = Pokemon('pokemon', 224, 'test')
test_loader = DataLoader(test_db, batch_size=batch_size, shuffle=True)
# Fix: move the model to the same device the batches are sent to.
model = newmodel.Model().to(device)
# Fix: map_location lets a CPU-only machine load a checkpoint saved on GPU.
model.load_state_dict(torch.load('best.pth', map_location=device))
# Fix: disable dropout so the reported test accuracy is deterministic.
model.eval()
print('模型训练完毕,已将参数设置成训练过程中的最优值,现在开始测试test_set')
test_acc = evaluate(model, test_loader)
print('测试集上的准确率是:{}'.format(test_acc))
# newmodel.py
import torch
import torch.nn as nn
from torchvision.models.resnet import resnet18
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# NOTE(review): this module never reads `device`; the model runs on whatever
# device the caller moves it to. Kept for parity with train.py/test.py.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Model(nn.Module):
    """Thin wrapper that delegates classification to a ViT backbone.

    An earlier commented-out ResNet-18 forward pass was removed from the body
    of forward() — it sat in the docstring slot as dead code.
    """

    def __init__(self):
        super(Model, self).__init__()
        # NOTE(review): image_size=256 while train.py resizes inputs to 224.
        # This still runs because ViT slices only the first n+1 positional
        # embeddings, but the sizes should be unified — TODO confirm intent.
        # num_classes=33 is hard-coded to the dataset's class count.
        self.vit = ViT(image_size=256, patch_size=16, num_classes=33, dim=1024,
                       depth=6, heads=16, mlp_dim=2048, dropout=0.1, emb_dropout=0.1)

    def forward(self, x):
        """Return class logits of shape (batch, 33) for an image batch x."""
        return self.vit(x)
def pair(t):
    """Return t unchanged if it is already a tuple, else duplicate it as (t, t)."""
    if isinstance(t, tuple):
        return t
    return (t, t)
class PreNorm(nn.Module):
    """Apply LayerNorm to the input before delegating to the wrapped module."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
class FeedForward(nn.Module):
    """Two-layer MLP (dim -> hidden_dim -> dim) with GELU and dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
class Attention(nn.Module):
    """Multi-head self-attention with a fused qkv projection."""

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        # A single head whose width equals dim needs no output projection.
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _ = x.shape
        h = self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # (b, n, h*d) -> (b, h, n, d): the einops 'b n (h d) -> b h n d' layout.
        q, k, v = (t.view(b, n, h, -1).transpose(1, 2) for t in qkv)
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)
        out = torch.matmul(weights, v)
        # (b, h, n, d) -> (b, n, h*d): the 'b h n d -> b n (h d)' layout.
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)
class Transformer(nn.Module):
    """Stack of `depth` residual (pre-norm attention, pre-norm MLP) layers."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
            ])
            for _ in range(depth)
        )

    def forward(self, x):
        for attention, mlp in self.layers:
            x = x + attention(x)
            x = x + mlp(x)
        return x
class ViT(nn.Module):
    """Vision Transformer classifier: patch embedding + transformer + MLP head.

    Fix vs. the original: removed the two debug print() calls in forward(),
    which wrote the tensor shapes to stdout on every single batch.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads,
                 mlp_dim, pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        # Split the image into flattened patches, then project each to `dim`.
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, dim),
        )
        # One learned position per patch plus one for the CLS token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        """Return (batch, num_classes) logits for an image batch."""
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape
        # Prepend a learned CLS token to every sequence in the batch.
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        # Only the first n+1 positions are used, so inputs smaller than
        # image_size still work (e.g. 224 inputs with image_size=256).
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        x = self.transformer(x)
        # 'cls' pooling takes the CLS token; 'mean' averages all tokens.
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)