First, a note: Center Loss is mainly used for face recognition. Why not classify faces directly with softmax? Because face features are extremely similar to one another, samples near the boundary between classes are very hard to separate; in other words, softmax may assign probabilities of around 0.5 to each of two faces, making the classification unreliable. So how do we separate these boundary regions? There are two approaches: 1) enlarge the inter-class distance, or 2) shrink the intra-class distance. Center Loss takes the second approach.
Below are the plots I made on the ten-class MNIST digit task, comparing plain softmax loss against softmax loss + center loss:
[Figure: 2-D features trained with softmax loss only] I accidentally deleted the better-looking version of this figure, but it still roughly shows the weakness of softmax loss: the closer the features get to the center, the harder they are for softmax to separate.
[Figure: 2-D features trained with center loss + softmax loss] Clearly, all classes are now cleanly separated.
As in the figure above, look at the formulas directly: one is the softmax loss and the other is the center loss; their sum is the total loss.
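The formula image is missing here, so for reference this is the standard formulation from the Center Loss paper (Wen et al., ECCV 2016), with m samples in the batch, n classes, deep feature x_i, and learnable class centers c_{y_i}:

L = L_S + \lambda L_C = -\sum_{i=1}^{m} \log \frac{e^{W_{y_i}^{T} x_i + b_{y_i}}}{\sum_{j=1}^{n} e^{W_{j}^{T} x_i + b_j}} + \frac{\lambda}{2} \sum_{i=1}^{m} \lVert x_i - c_{y_i} \rVert_2^2

Note that the implementation below deviates slightly from the paper: it uses the plain (non-squared) Euclidean distance divided by each class's sample count in the batch instead of the paper's ½‖·‖², and it fixes λ = 0.1.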
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
import torchvision
import torch.utils.data as data
import matplotlib.pyplot as plt
import torch.optim.lr_scheduler as lr_scheduler
import torchvision.transforms as transforms
class Centerloss(torch.nn.Module):
    def __init__(self, feature_num, cls_num):
        super(Centerloss, self).__init__()
        self.cls_num = cls_num
        # One learnable center per class, initialized randomly
        self.center = nn.Parameter(torch.randn(cls_num, feature_num))

    def forward(self, xs, ys):
        # xs = F.normalize(xs)
        # Use the labels as row indices to pick out each sample's class center
        center_exp = self.center.index_select(dim=0, index=ys.long())
        # Count how many samples of each class (0 .. cls_num-1) are in the batch
        # (histc requires a float tensor)
        count = torch.histc(ys.float(), bins=self.cls_num, min=0, max=self.cls_num - 1)
        # Use the labels as indices into count to get each sample's class count
        count_dis = count.index_select(dim=0, index=ys.long())
        # Euclidean distance to the class center, averaged within each class
        return torch.sum(torch.sqrt(torch.sum((xs - center_exp) ** 2, dim=1)) / count_dis.float())
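# A quick sanity check of Centerloss above (the shapes are illustrative, not from
# the original post): each sample's feature is pulled toward its class center, and
# its distance is divided by how many samples of that class are in the batch, so
# every class contributes a per-class mean distance to the sum.
# >>> cl = Centerloss(feature_num=2, cls_num=3)
# >>> xs = torch.randn(4, 2)
# >>> ys = torch.tensor([0, 1, 1, 2])
# >>> cl(xs, ys)  # a scalar: sum over samples of dist-to-center / class count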
class ClsNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_layer = nn.Sequential(nn.Conv2d(1, 32, 3), nn.BatchNorm2d(32), nn.PReLU(),
                                        nn.Conv2d(32, 64, 3), nn.BatchNorm2d(64), nn.PReLU(),
                                        nn.MaxPool2d(3, 2))
        # 28x28 -> 26x26 -> 24x24 -> 11x11 after the two convs and the pool
        self.feature_layer = nn.Sequential(nn.Linear(11 * 11 * 64, 256), nn.BatchNorm1d(256), nn.PReLU(),
                                           nn.Linear(256, 128), nn.BatchNorm1d(128), nn.PReLU(),
                                           nn.Linear(128, 2), nn.PReLU())
        self.out_layer = nn.Sequential(nn.Linear(2, 10))
        self.loss_fn1 = Centerloss(2, 10)
        self.loss_fn2 = nn.CrossEntropyLoss()

    def forward(self, x):
        conv = self.conv_layer(x)
        conv = conv.reshape(x.size(0), -1)
        self.feature = self.feature_layer(conv)
        self.out = self.out_layer(self.feature)
        return self.feature

    def get_loss(self, ys):
        loss1 = self.loss_fn1(self.feature, ys)
        loss2 = self.loss_fn2(self.out, ys.long())
        return loss1, loss2
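# Why the feature layer ends in 2 dimensions (my reading, not stated in the post):
# a 2-D bottleneck lets the features be scattered directly on a plane (the figures
# above) without t-SNE/PCA; the 10-way classifier works on top of those same 2-D
# features, and Centerloss pulls each digit's features toward its learned center.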
if __name__ == '__main__':
    train_data = torchvision.datasets.MNIST(
        root='mnist',
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True
    )
    test_data = torchvision.datasets.MNIST(
        root='mnist',
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=False
    )
    train = data.DataLoader(dataset=train_data, batch_size=1024, shuffle=True, drop_last=True)
    test = data.DataLoader(dataset=test_data, batch_size=1024, shuffle=True)
    # transform = transforms.Compose([
    #     transforms.Resize((28, 28)),
    #     transforms.ToTensor(),
    #     transforms.Normalize((0.5,), (0.5,)),
    # ])
    net = ClsNet().cuda()
    path = r'params/weightnet2.pt'
    if os.path.exists(path):
        net.load_state_dict(torch.load(path))
        print('load successful')
    else:
        print('load failed')
    # optimism = optim.SGD(net.parameters(), lr=1e-3)
    optimism = optim.Adam(net.parameters(), lr=0.0005)
    # scheduler = lr_scheduler.StepLR(optimism, 10, gamma=0.8)
    os.makedirs('picture1.1', exist_ok=True)  # folders for the scatter snapshots
    os.makedirs('params', exist_ok=True)      # and the saved weights
    # Ten colors, one per digit class, for the 2-D feature scatter plot
    c = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',
         '#ff00ff', '#990000', '#999900', '#009900', '#009999']
    epoch = 10000
    d = 0
    net.train()  # ensure the BatchNorm layers are in training mode before training
    for i in range(epoch):
        # scheduler.step()
        print('epoch: {}'.format(i))
        tar = []
        out = []
        for j, (input, target) in enumerate(train):
            input = input.cuda()
            target = target.cuda()
            output = net(input)
            loss1, loss2 = net.get_loss(target)
            # Total loss: softmax loss plus center loss weighted by lambda = 0.1
            loss = 0.1 * loss1 + loss2
            # Clear gradients, backpropagate, update the weights and the centers
            optimism.zero_grad()
            loss.backward()
            optimism.step()
            output = output.cpu().detach().numpy()
            target = target.cpu().detach()
            out.extend(output)
            tar.extend(target)
            print('[epochs - {} - {} / {}] loss: {:.4f} loss1: {:.4f} loss2: {:.4f}'.format(
                i, j, len(train), loss.item(), loss1.item(), loss2.item()))
            outstack = np.stack(out)
            tarstack = torch.stack(tar)
            # After a few batches have accumulated, scatter the 2-D features by class
            if j == 3:
                d += 1
                plt.ion()
                for m in range(10):
                    index = torch.nonzero(tarstack == m)
                    plt.scatter(outstack[:, 0][index[:, 0]], outstack[:, 1][index[:, 0]], c=c[m], marker='.')
                plt.show()
                plt.pause(10)
                plt.savefig('picture1.1/{0}.jpg'.format(d))
                print('save success')
                plt.close()
        torch.save(net.state_dict(), r'params/weightnet2.pt')
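The test DataLoader is built above but never used. For completeness, here is a minimal evaluation sketch (this loop is my addition, not part of the original post); it relies on forward() storing the logits in net.out as a side effect:

    # Hypothetical evaluation loop (not in the original post): measures top-1
    # accuracy on the MNIST test set with the trained network.
    net.eval()
    correct = total = 0
    with torch.no_grad():
        for input, target in test:
            input, target = input.cuda(), target.cuda()
            net(input)                            # forward() stores logits in net.out
            pred = torch.argmax(net.out, dim=1)   # most likely digit per sample
            correct += (pred == target).sum().item()
            total += target.size(0)
    print('test accuracy: {:.4f}'.format(correct / total))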