主要分为以下几块：
1、数据集读取
2、hg-model
3、training
4、代码主要结合了 GitHub 上几位大佬的代码：@bearpaw 和 @roytseng-tw 的训练代码，以及 @anibali 的评估（evaluation）代码。其中 @bearpaw 和 @roytseng-tw 的代码与原作者基于 Lua/Torch7 的源码基本没有出入，是很好的复现。
5、同时采用了 Hourglass 原作者（@umich-vl）划分的训练集、验证集和测试集。
6、同时我也会在 GitHub 上放出 Caffe 版本的 Hourglass 实现，主要参考自 RMPE 论文的 GitHub 代码。
7、目前我的训练结果在 MPII 验证集上只能达到 89.3（阈值为 0.5）。
一、数据读取
1、数据增广
这里主要涉及 crop（裁剪）、scale（缩放）、flip（翻转）和 rotate（旋转）这几个操作；另外在读取数据时还做了颜色抖动（color jitter）。
""" Random """
def randn():
return random.gauss(0, 1)
def rand():
return random.random()
def rnd(x):
'''umich hourglass mpii random function'''
return max(-2 * x, min(2 * x, randn() * x))
""" Visualization """
def show_sample(img, label): # FIXME: color blending is not right, diff color for each joint
nJoints = label.shape[0]
white = np.ones((4,) + img.shape[1:3])
new_img = white.copy()
new_img[:3] = img * 0.5
for i in range(nJoints):
new_img += 0.5 * white * sktf.resize(label[i], img.shape[1:3], preserve_range=True)
# print(label[i].max())
# plt.subplot(121)
# plt.imshow(np.transpose(new_img, [1, 2, 0]))
# plt.subplot(122)
# plt.imshow(label[i])
# plt.show()
return np.transpose(new_img, [1, 2, 0])
""" Label """
def create_label(imsize, pt, sigma, distro_type='Gaussian'):
label = np.zeros(imsize)
# Check that any part of the distro is in-bounds
ul = np.math.floor(pt[0] - 3 * sigma), np.math.floor(pt[1] - 3 * sigma)
br = np.math.floor(pt[0] + 3 * sigma), np.math.floor(pt[1] + 3 * sigma)
# If not, return the blank label
if ul[0] >= imsize[1] or ul[1] >= imsize[0] or br[0] < 0 or br[1] < 0:
return label
# Generate distro
size = 6 * sigma + 1
x = np.arange(0, size, 1, float)
y = x[:, np.newaxis]
x0 = y0 = size // 2
'''Note:
original torch impl: `local g = image.gaussian(size)`
equals to `gaussian(size, sigma=0.25*size)` here
'''
if distro_type == 'Gaussian':
distro = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
elif distro_type == 'Cauchy': # IS THIS CORRECT ???
distro = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)
# distro = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) * np.pi)
# Usable distro range
distro_x = max(0, -ul[0]), min(br[0], imsize[1]) - ul[0]
distro_y = max(0, -ul[1]), min(br[1], imsize[0]) - ul[1]
assert (distro_x[0] >= 0 and distro_y[0] >= 0), '{}, {}'.format(distro_x, distro_y)
# label range
label_x = max(0, ul[0]), min(br[0], imsize[1])
label_y = max(0, ul[1]), min(br[1], imsize[0])
label[label_y[0]:label_y[1], label_x[0]:label_x[1]] = \
distro[distro_y[0]:distro_y[1], distro_x[0]:distro_x[1]]
return label
""" Flip """
def fliplr_labels(labels, matchedParts, joint_dim=1, width_dim=3):
"""fliplr the joint labels, defaults (B, C, H, W)
"""
# flip horizontally
labels = np.flip(labels, axis=width_dim)
# Change left-right parts
perm = np.arange(labels.shape[joint_dim])
for i, j in matchedParts:
perm[i] = j
perm[j] = i
labels = np.take(labels, perm, axis=joint_dim)
return labels
def fliplr_coords(pts, width, matchedParts):
    """Mirror joint coordinates horizontally and swap left/right joints.

    Args:
        pts: sequence of (x, y) joint coordinates.
        width: image width used for mirroring (``x -> width - x``).
        matchedParts: iterable of (i, j) joint-index pairs to exchange.

    Returns:
        NumPy array of the mirrored points with paired joints exchanged.
    """
    flipped = []
    for x, y in pts:
        # Only joints with a positive x are mirrored; x <= 0 marks an
        # unannotated joint and is left untouched.
        flipped.append((width - x, y) if x > 0 else (x, y))
    flipped = np.array(flipped)
    order = np.arange(flipped.shape[0])
    for left, right in matchedParts:
        order[left], order[right] = right, left
    return flipped[order]
""" Transform, Crop """
def get_transform(center, scale, rot, res, invert=False):
'''Prepare transformation matrix (scale, rot).
'''
h = 200 * scale
t = np.eye(3) # transformation matrix
# scale
t[0, 0] = res[1] / h
t[1, 1] = res[0] / h
# translation
t[0, 2] = res[1] * (-center[0] / h + .5)
t[1, 2] = res[0] * (-center[1] / h + .5)
# rotation
if rot != 0:
rot = -rot # To match direction of rotation from cropping
rot_mat = np.zeros((3, 3))
rot_rad = rot * np.pi / 180
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
rot_mat[:2, :2] = [[cs, -sn],
[sn, cs]]
rot_mat[2, 2] = 1
# Need to make sure rotation is around center
t_mat = np.eye(3)
t_mat[0, 2] = -res[1] / 2
t_mat[1, 2] = -res[0] / 2
t_inv = t_mat.copy()
t_inv[:2, 2] *= -1
t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
if invert:
t = np.linalg.inv(t)
return t
def transform(pts, center, scale, rot, res, invert=False):
    """Map points between original-image and ``res`` x ``res`` crop coordinates.

    Args:
        pts: a single (x, y) point, or a 2 x n array with one point per column.
        center, scale, rot: crop parameters forwarded to ``get_transform``.
        res: side length of the square crop.
        invert: if True, map crop coordinates back to the original image.

    Returns:
        Integer array of transformed points, shape (2,) for a single point or
        (2, n) otherwise. ``astype(int)`` truncates toward zero.
    """
    mat = get_transform(center, scale, rot, [res, res], invert)
    coords = np.array(pts)
    assert coords.shape[0] == 2, coords.shape
    if coords.ndim == 1:
        homogeneous = np.array([coords[0], coords[1], 1])
    else:
        homogeneous = np.vstack([coords, np.ones((1, coords.shape[1]))])
    projected = np.dot(mat, homogeneous)
    return projected[:2].astype(int)
def crop(img, center, scale, rot, res):
    '''Cut a square, person-centred patch out of ``img`` and resample it.

    Args:
        img: (H, W) or (H, W, C) image array.
        center: (x, y) centre of the patch in ``img`` pixel coordinates.
        scale: patch side is ``200 * scale`` pixels in the original image.
        rot: rotation in degrees (0 means no rotation).
        res: single value, the side length of the square output.

    Returns:
        A ``res`` x ``res`` (x C) array.
    '''
    ht, wd = img.shape[0], img.shape[1]
    # When the patch is much larger than the output, shrink the whole image
    # first so the crop below works on fewer pixels.
    sf = scale * 200.0 / res
    if sf < 2:
        sf = 1
    else:
        new_size = int(np.floor(max(ht, wd) / sf))
        new_ht = int(np.floor(ht / sf))
        new_wd = int(np.floor(wd / sf))
        if new_size < 2:
            # Zoomed out so much that the image is now a single pixel or less:
            # return a blank patch (np.zeros takes the shape as ONE tuple).
            out_shape = (res, res) if img.ndim == 2 else (res, res, img.shape[2])
            return np.zeros(out_shape)
        img = sktf.resize(img, [new_ht, new_wd], preserve_range=True)
        ht, wd = img.shape[0], img.shape[1]
    # Upper-left / bottom-right corners of the crop region in (scaled) image coords.
    center = center / sf
    scale = scale / sf
    ul = transform([0, 0], center, scale, 0, res, invert=True)
    br = transform([res, res], center, scale, 0, res, invert=True)
    if sf >= 2:
        # The image was pre-shrunk, so the region must be exactly res wide.
        br = ul + res
    # Padding so that enough context is included when the patch is rotated.
    pad = int(np.ceil(np.linalg.norm(br - ul) / 2 - (br[0] - ul[0]) / 2))
    if rot != 0:
        ul -= pad
        br += pad
    # Range of pixels to take from the source image ...
    old_x = max(0, ul[0]), min(br[0], wd)
    old_y = max(0, ul[1]), min(br[1], ht)
    # ... and where to put them in the new image (out-of-image area stays 0).
    new_x = max(0, -ul[0]), min(br[0], wd) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], ht) - ul[1]
    new_shape = [br[1] - ul[1], br[0] - ul[0]]
    if img.ndim > 2:
        new_shape += [img.shape[2]]
    new_img = np.zeros(new_shape)
    new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
    if rot != 0:
        # Rotate, then remove the padded border. Guard pad == 0: the slice
        # `0:-0` is `0:0` and would return an empty array.
        new_img = sktf.rotate(new_img, rot, preserve_range=True)
        if pad > 0:
            new_img = new_img[pad:-pad, pad:-pad]
    if sf < 2:
        new_img = sktf.resize(new_img, [res, res], preserve_range=True)
    return new_img
2、针对数据集读取数据 batch
结合这个脚本和上面的数据增广脚本，基本上就完成了全部的数据操作。
class MPII_Dataset(torch.utils.data.Dataset):
    """MPII human-pose dataset read from an HDF5 annotation file.

    Annotations come from ``{data_root}/mpii/{split}.h5`` (keys 'imgname',
    'part', 'center', 'scale'); images from ``{data_root}/mpii/images``.
    Each item is ``(image, heatmaps[, small_image][, meta])``:
    ``image`` is (3, inp_res, inp_res) float32, ``heatmaps`` is
    (nJoints, out_res, out_res) float32, ``small_image`` (if enabled) is
    (3, out_res, out_res) float32, and ``meta`` is ``[pts, c, s, r]``.
    The 'train' split applies random scale, rotation, horizontal flip and
    colour jitter.
    """

    def __init__(self, data_root, split,
                 inp_res=256, out_res=64, sigma=1,
                 scale_factor=0.25, rot_factor=30, return_meta=False, small_image=True):
        self.data_root = data_root
        self.split = split
        self.inp_res = inp_res
        self.out_res = out_res
        self.sigma = sigma
        self.scale_factor = scale_factor
        self.rot_factor = rot_factor
        self.return_meta = return_meta
        self.small_image = small_image
        self.nJoints = 16
        self.accIdxs = [0, 1, 2, 3, 4, 5, 10, 11, 14, 15]  # joint idxs for accuracy calculation
        self.flipRef = [[0, 5], [1, 4], [2, 3],  # noqa
                        [10, 15], [11, 14], [12, 13]]
        self.annot = {}
        tags = ['imgname', 'part', 'center', 'scale']
        # The context manager closes the HDF5 handle even if a key is missing.
        with h5py.File('{}/mpii/{}.h5'.format(data_root, split), 'r') as f:
            for tag in tags:
                self.annot[tag] = np.asarray(f[tag]).copy()

    def _getPartInfo(self, index):
        """Return COPIES of (joints, center, scale) for ``index``, with the
        centre shifted down by 15*scale and the scale enlarged by 1.25."""
        pts = self.annot['part'][index].copy()
        c = self.annot['center'][index].copy()
        s = self.annot['scale'][index].copy()
        # Small adjustment so cropping is less likely to take feet out
        c[1] = c[1] + 15 * s
        s = s * 1.25
        return pts, c, s

    def _loadImage(self, index):
        """Load image ``index`` as floats in [0, 1]."""
        impath = os.path.join(self.data_root, 'mpii/images', self.annot['imgname'][index].decode('utf-8'))
        im = skim.img_as_float(skio.imread(impath))
        return im

    @staticmethod
    def _ensure_rgb(im):
        """Return ``im`` as an (H, W, 3) array.

        Grayscale images are replicated across three channels and any extra
        (e.g. alpha) channels are dropped, so the flip / colour-jitter / crop
        pipeline always sees three channels. (``np.tile(im, [1, 1, 3])`` on an
        (H, W) array would give (1, H, 3W) instead.)
        """
        if im.ndim == 3 and im.shape[2] == 1:
            im = im[:, :, 0]
        if im.ndim == 2:
            return np.stack([im, im, im], axis=-1)
        return im[:, :, :3]

    def __getitem__(self, index):
        im = self._ensure_rgb(self._loadImage(index))
        pts, c, s = self._getPartInfo(index)
        r = 0
        if self.split == 'train':
            # scale and rotation
            s = s * (2 ** rnd(self.scale_factor))
            r = 0 if rand() < 0.6 else rnd(self.rot_factor)
            # flip LR
            if rand() < 0.5:
                im = im[:, ::-1, :]
                pts = fliplr_coords(pts, width=im.shape[1], matchedParts=self.flipRef)
                c[0] = im.shape[1] - c[0]  # flip center point also
            # Color jitter: independent random gain per RGB channel.
            im = np.clip(im * np.random.uniform(0.6, 1.4, size=3), 0, 1)
        # Prepare image
        im = crop(im, c, s, r, self.inp_res)
        if self.small_image:
            # small size image
            im_s = sktf.resize(im, [self.out_res, self.out_res], preserve_range=True)
        # (h, w, c) to (c, h, w)
        im = np.transpose(im, [2, 0, 1])
        if self.small_image:
            im_s = np.transpose(im_s, [2, 0, 1])
        # Prepare label: one heatmap per annotated joint (x > 0), in output-res coords.
        labels = np.zeros((self.nJoints, self.out_res, self.out_res))
        new_pts = transform(pts.T, c, s, r, self.out_res).T
        for i in range(self.nJoints):
            if pts[i, 0] > 0:
                labels[i] = create_label(
                    labels.shape[1:],
                    new_pts[i],
                    self.sigma)
        ret_list = [im.astype(np.float32), labels.astype(np.float32)]
        if self.small_image:
            # Same dtype as the other outputs so batches collate to float32.
            ret_list.append(im_s.astype(np.float32))
        if self.return_meta:
            meta = [pts, c, s, r]
            ret_list.append(meta)
        return tuple(ret_list)

    def __len__(self):
        return len(self.annot['imgname'])
二、模型代码
1、首先定义残差网络的基本模块
class HgResBlock(nn.Module):
    '''Hourglass residual block (pre-activation bottleneck).

    Main path: BN-ReLU-conv1x1 (to outplanes // 2) -> BN-ReLU-conv3x3 ->
    BN-ReLU-conv1x1 (to outplanes). The identity shortcut is replaced by a
    1x1 projection whenever the channel count or the stride changes.

    Args:
        inplanes: input channels.
        outplanes: output channels.
        stride: spatial stride. It is applied once, on the 3x3 conv, and
            matched by the projection shortcut. (Applying it to all three
            convs while the shortcut kept stride 1 made any stride != 1
            fail with a shape mismatch.)
    '''
    def __init__(self, inplanes, outplanes, stride=1):
        super().__init__()
        self.inplanes = inplanes
        self.outplanes = outplanes
        self.stride = stride
        midplanes = outplanes // 2
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, midplanes, 1)
        self.bn2 = nn.BatchNorm2d(midplanes)
        self.conv2 = nn.Conv2d(midplanes, midplanes, 3, stride, 1)
        self.bn3 = nn.BatchNorm2d(midplanes)
        self.conv3 = nn.Conv2d(midplanes, outplanes, 1)
        self.relu = nn.ReLU(inplace=True)
        if inplanes != outplanes or stride != 1:
            self.conv_skip = nn.Conv2d(inplanes, outplanes, 1, stride)
        else:
            # Plain attribute (not a submodule): state_dict keys stay unchanged.
            self.conv_skip = None

    def forward(self, x):
        residual = x
        out = self.conv1(self.relu(self.bn1(x)))
        out = self.conv2(self.relu(self.bn2(out)))
        out = self.conv3(self.relu(self.bn3(out)))
        if self.conv_skip is not None:
            residual = self.conv_skip(residual)
        out += residual
        return out
2、定义hourglass基本结构
class Hourglass(nn.Module):
    """A single hourglass module, built recursively ``depth`` levels deep.

    Every level ``self.hg[d]`` holds residual stacks: [0] the skip branch kept
    at the current resolution, [1] the stack applied after max-pooling and
    [2] the stack applied before nearest-neighbour upsampling. The innermost
    level has an extra stack [3] at the bottleneck. Each stack is
    ``nModules`` blocks of ``resBlock(nFeat, nFeat)``.
    """
    def __init__(self, depth, nFeat, nModules, resBlock):
        super().__init__()
        self.depth = depth
        self.nFeat = nFeat
        self.nModules = nModules  # residual modules per branch
        self.resBlock = resBlock
        self.hg = self._make_hour_glass()
        self.downsample = nn.MaxPool2d(2, 2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def _make_hour_glass(self):
        innermost = self.depth - 1
        levels = []
        for level in range(self.depth):
            branch_count = 4 if level == innermost else 3
            branches = [self._make_residual(self.nModules) for _ in range(branch_count)]
            levels.append(nn.ModuleList(branches))
        return nn.ModuleList(levels)

    def _make_residual(self, n):
        blocks = [self.resBlock(self.nFeat, self.nFeat) for _ in range(n)]
        return nn.Sequential(*blocks)

    def forward(self, x):
        return self._hour_glass_forward(0, x)

    def _hour_glass_forward(self, depth_id, x):
        branches = self.hg[depth_id]
        skip = branches[0](x)
        lower = branches[1](self.downsample(x))
        if depth_id == self.depth - 1:
            lower = branches[3](lower)
        else:
            lower = self._hour_glass_forward(depth_id + 1, lower)
        lower = branches[2](lower)
        return skip + self.upsample(lower)
class HourglassNet(nn.Module):
    '''Stacked hourglass model from Newell et al., ECCV 2016.

    Args:
        nStacks: number of stacked hourglasses. Each stack emits its own
            heatmap prediction (intermediate supervision).
        nModules: residual blocks per branch inside each hourglass.
        nFeat: channel width used inside the hourglasses.
        nClasses: number of output heatmaps (joints).
        resBlock: residual block class, called as ``resBlock(in, out)``.
        inplanes: channels of the input image.
        depth: recursion depth of each hourglass, i.e. number of 2x
            down-samplings. Defaults to 4, the previously hard-coded value.
            The spatial size after the head should be divisible by
            ``2 ** depth`` so each level's skip and upsampled branches align.

    ``forward`` returns a list of ``nStacks`` tensors, each with ``nClasses``
    channels, at the head's output resolution. The head uses a stride-2 conv
    plus a 2x max-pool, i.e. roughly 1/4 of the input.
    '''
    def __init__(self, nStacks, nModules, nFeat, nClasses, resBlock=HgResBlock, inplanes=3, depth=4):
        super().__init__()
        self.nStacks = nStacks
        self.nModules = nModules
        self.nFeat = nFeat
        self.nClasses = nClasses
        self.resBlock = resBlock
        self.inplanes = inplanes
        self.depth = depth
        self._make_head()
        hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
        for i in range(nStacks):
            hg.append(Hourglass(depth, nFeat, nModules, resBlock))
            res.append(self._make_residual(nModules))
            fc.append(self._make_fc(nFeat, nFeat))
            score.append(nn.Conv2d(nFeat, nClasses, 1))
            if i < (nStacks - 1):
                # Re-projections that feed features and predictions back into
                # the input of the next stack.
                fc_.append(nn.Conv2d(nFeat, nFeat, 1))
                score_.append(nn.Conv2d(nClasses, nFeat, 1))
        self.hg = nn.ModuleList(hg)
        self.res = nn.ModuleList(res)
        self.fc = nn.ModuleList(fc)
        self.score = nn.ModuleList(score)
        self.fc_ = nn.ModuleList(fc_)
        self.score_ = nn.ModuleList(score_)

    def _make_head(self):
        """Stem: 7x7 stride-2 conv, BN, ReLU, residual blocks and one 2x max-pool."""
        self.conv1 = nn.Conv2d(self.inplanes, 64, 7, 2, 3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.res1 = self.resBlock(64, 128)
        self.pool = nn.MaxPool2d(2, 2)
        self.res2 = self.resBlock(128, 128)
        self.res3 = self.resBlock(128, self.nFeat)

    def _make_residual(self, n):
        return nn.Sequential(*[self.resBlock(self.nFeat, self.nFeat) for _ in range(n)])

    def _make_fc(self, inplanes, outplanes):
        return nn.Sequential(
            nn.Conv2d(inplanes, outplanes, 1),
            nn.BatchNorm2d(outplanes),
            nn.ReLU(True))

    def forward(self, x):
        # head
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.res1(x)
        x = self.pool(x)
        x = self.res2(x)
        x = self.res3(x)
        out = []
        for i in range(self.nStacks):
            y = self.hg[i](x)
            y = self.res[i](y)
            y = self.fc[i](y)
            score = self.score[i](y)
            out.append(score)
            if i < (self.nStacks - 1):
                # Next stack's input = current input + re-projected features
                # + re-projected heatmaps.
                fc_ = self.fc_[i](y)
                score_ = self.score_[i](score)
                x = x + fc_ + score_
        return out
三、训练
初始化数据和网络
# Data: MPII training split with augmentation settings taken from FLAGS.
train_set = MPII_Dataset(
    FLAGS.dataDir,
    split='train',
    inp_res=FLAGS.inputRes,
    out_res=FLAGS.outputRes,
    scale_factor=FLAGS.scale,
    rot_factor=FLAGS.rotate,
    sigma=FLAGS.hmSigma,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=FLAGS.trainBatch,
    shuffle=True,
    num_workers=FLAGS.nThreads,
    pin_memory=True,
)

# Model: the number of output heatmaps is read from the dataset so the two
# always agree on the joint count.
netHg = nn.DataParallel(
    HourglassNet(
        nStacks=FLAGS.nStacks,
        nModules=FLAGS.nModules,
        nFeat=FLAGS.nFeats,
        nClasses=train_set.nJoints,
    )
)
criterion = nn.MSELoss()

# NOTE(review): `cuda` is defined outside this snippet, while run() checks
# FLAGS.cuda -- confirm both flags agree.
if cuda:
    torch.backends.cudnn.benchmark = True
    netHg.cuda()
    criterion.cuda()

optimHg = torch.optim.RMSprop(
    params=netHg.parameters(),
    lr=FLAGS.lr,
    alpha=FLAGS.alpha,
    eps=FLAGS.eps,
)
调用网络进行训练
def run(epoch, iter_start=0):
    """Train ``netHg`` for one epoch over ``train_loader``.

    Relies on the module-level ``netHg``, ``train_loader``, ``criterion``,
    ``optimHg``, ``train_set``, ``sumWriter``, ``accuracy``, ``tqdm`` and
    ``FLAGS``, and increments the module-level ``global_step`` once per batch.

    Args:
        epoch: epoch index, used only in the progress-bar label.
        iter_start: kept for backward compatibility; the loop never used the
            iteration counter it offset.
    """
    global global_step
    netHg.train()
    pbar = tqdm.tqdm(train_loader, desc='Epoch %02d' % epoch, dynamic_ncols=True)
    pbar_info = tqdm.tqdm(bar_format='{bar}{postfix}')
    avg_acc = 0
    try:
        for sample in pbar:
            global_step += 1
            # The dataset yields (image, heatmaps, small_image[, meta]); only
            # the first two are needed here, so the small image is not copied
            # to the GPU.
            image, label = sample[0], sample[1]
            if FLAGS.cuda:
                # `non_blocking` replaces the old `async` keyword, which has
                # been a reserved word (SyntaxError) since Python 3.7.
                image = image.cuda(non_blocking=True)
                label = label.cuda(non_blocking=True)

            outputs = netHg(image)
            # Intermediate supervision: sum the loss over every stack's output.
            loss_hg = sum(criterion(out, label) for out in outputs)

            optimHg.zero_grad()
            loss_hg.backward()
            optimHg.step()

            # Log plain Python numbers, not graph-attached tensors.
            loss_value = loss_hg.item()
            accs = accuracy(outputs[-1].detach().cpu(), label.detach().cpu(), train_set.accIdxs)
            sumWriter.add_scalar('loss_hg', loss_value, global_step)
            sumWriter.add_scalar('acc', accs[0], global_step)
            # TODO: learning rate scheduling
            pbar_info.set_postfix({
                'loss_hg': loss_value,
                'acc': accs[0]
            })
            pbar_info.update()
            avg_acc += accs[0] / len(train_loader)
        pbar_info.set_postfix_str('avg_acc: {}'.format(avg_acc))
    finally:
        # Close the progress bars even if a batch raises.
        pbar.close()
        pbar_info.close()