def labels2Dto3D(labels, cell_size, add_dustbin=True):
    '''
    Change the shape of labels into 3D. Batch of labels.
    :param labels:
        tensor [batch_size, 1, H, W]
    :param cell_size:
        8
    :return:
        labels: tensors[batch_size, 65, Hc, Wc]
    '''
    batch_size, channel, H, W = labels.shape
    Hc, Wc = H // cell_size, W // cell_size
    space2depth = SpaceToDepth(8)
    labels = space2depth(labels)
    if add_dustbin:
        dustbin = labels.sum(dim=1)
        dustbin = 1 - dustbin
        dustbin[dustbin < 1.] = 0
        labels = torch.cat((labels, dustbin.view(batch_size, 1, Hc, Wc)), dim=1)
        ## norm
        dn = labels.sum(dim=1)
        labels = labels.div(torch.unsqueeze(dn, 1))
    return labels
batch_size, channel, H, W = labels.shape
Hc, Wc = H // cell_size, W // cell_size
batch_size = 64, channel = 1, H = 120, W = 160
Hc = 15, Wc = 20
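A quick check of this integer division (toy snippet, not from the repo):

H, W, cell_size = 120, 160, 8
print(H // cell_size, W // cell_size)   # 15 20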
space2depth = SpaceToDepth(8)
labels = space2depth(labels)
SpaceToDepth is a class; here a SpaceToDepth object is constructed with 8 as the constructor argument. Let's take a look inside:
https://blog.csdn.net/weixin_44179561/article/details/128058411?csdn_share_tail=%7B%22type%22%3A%22blog%22%2C%22rType%22%3A%22article%22%2C%22rId%22%3A%22128058411%22%2C%22source%22%3A%22weixin_44179561%22%7D
In other words, the labels of shape (64, 1, 120, 160) are transformed by space2depth into labels of shape (64, 64, 15, 20).
Since add_dustbin = True, we continue:
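For reference, here is a minimal sketch of such a space-to-depth operation. It assumes the module is equivalent (up to channel ordering) to torch.nn.functional.pixel_unshuffle; the repo's actual SpaceToDepth class is shown in the linked post:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SpaceToDepthSketch(nn.Module):
    # hypothetical stand-in for the repo's SpaceToDepth, for illustration only
    def __init__(self, block_size):
        super().__init__()
        self.block_size = block_size

    def forward(self, x):
        # (B, C, H, W) -> (B, C * block_size**2, H // block_size, W // block_size)
        return F.pixel_unshuffle(x, self.block_size)

x = torch.zeros(64, 1, 120, 160)
print(SpaceToDepthSketch(8)(x).shape)   # torch.Size([64, 64, 15, 20])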
dustbin = labels.sum(dim=1)
dustbin = 1 - dustbin
dustbin[dustbin < 1.] = 0
labels = torch.cat((labels, dustbin.view(batch_size, 1, Hc, Wc)), dim=1)
These lines need to be understood together: overall, they append one more channel after the 64 channels of labels that marks cells containing no keypoint (the dustbin), as the small example below shows:
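A tiny hand-made example of the dustbin logic (the four values stand for four cells; the shape is simplified for illustration):

import torch

cell_sums = torch.tensor([0., 1., 2., 0.])   # keypoints per cell, i.e. labels.sum(dim=1)
dustbin = 1 - cell_sums                      # tensor([ 1.,  0., -1.,  1.])
dustbin[dustbin < 1.] = 0                    # tensor([1., 0., 0., 1.])
# dustbin is 1 exactly for the cells that contain no keypoint and 0 for cells
# with one or more keypoints; it is then concatenated as the 65th channel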
dn = labels.sum(dim=1)
dn: (64, 15, 20), the sum over the 65 channels
labels = labels.div(torch.unsqueeze(dn, 1))
torch.unsqueeze(dn, 1): (64, 1, 15, 20)
every channel of labels is divided by this sum
labels: (64, 65, 15, 20)
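So every cell ends up with a distribution over the 65 channels that sums to 1. A toy cell with two keypoints (only 4 channels shown, purely for illustration):

labels_cell = torch.tensor([0., 1., 1., 0.])   # two keypoints in this cell, dustbin = 0
dn = labels_cell.sum()                         # tensor(2.)
print(labels_cell / dn)                        # tensor([0.0000, 0.5000, 0.5000, 0.0000])
# an empty cell would instead keep 1.0 in its dustbin channel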
return labels returns the result; at the call site it is converted to floating point with .float() and assigned to labels_3D.
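A hedged usage sketch (the variable name labels_2D and the exact call site are assumptions based on the description above; SpaceToDepth must be defined, e.g. the repo's class or the sketch earlier):

labels_2D = torch.zeros(64, 1, 120, 160)
labels_2D[0, 0, 10, 17] = 1                      # mark one keypoint in the first image
labels_3D = labels2Dto3D(labels_2D, cell_size=8).float()
print(labels_3D.shape)                           # torch.Size([64, 65, 15, 20])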