[Deep Learning] The labels2Dto3D function

The labels2Dto3D function

Function code:

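import torch

# SpaceToDepth is the class explained in the article linked below (not defined here).

def labels2Dto3D(labels, cell_size, add_dustbin=True):
    '''
    Change the shape of a batch of 2D keypoint labels into 3D (per-cell) labels.

    :param labels:
        tensor [batch_size, 1, H, W], 1 at keypoint pixels and 0 elsewhere
    :param cell_size:
        side length of a cell, e.g. 8
    :return:
        labels: tensor [batch_size, cell_size**2 + 1, Hc, Wc]
                (65 channels for cell_size = 8 with the dustbin)
    '''
    batch_size, channel, H, W = labels.shape
    Hc, Wc = H // cell_size, W // cell_size
    # Fold each cell_size x cell_size cell of pixels into the channel dimension.
    space2depth = SpaceToDepth(cell_size)
    labels = space2depth(labels)
    if add_dustbin:
        # Dustbin channel: 1 where a cell contains no keypoint, 0 otherwise (0/1 labels).
        dustbin = labels.sum(dim=1)
        dustbin = 1 - dustbin
        dustbin[dustbin < 1.] = 0
        labels = torch.cat((labels, dustbin.view(batch_size, 1, Hc, Wc)), dim=1)
        ## norm: divide by the per-cell total so the channels of each cell sum to 1
        dn = labels.sum(dim=1)
        labels = labels.div(torch.unsqueeze(dn, 1))
    return labels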

Line-by-line walkthrough

batch_size, channel, H, W = labels.shape
Hc, Wc = H // cell_size, W // cell_size

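In the running example, batch_size = 64, channel = 1, H = 120 and W = 160.

With cell_size = 8, this gives Hc = 120 // 8 = 15 and Wc = 160 // 8 = 20.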
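The cell-grid arithmetic can be checked with a few lines of Python. This is a minimal sketch: `cell_grid` is a hypothetical helper, and the divisibility check is an added assumption (a reshape-based space-to-depth cannot split an image whose height or width is not a multiple of cell_size).

```python
def cell_grid(H, W, cell_size=8):
    """Hypothetical helper: return (Hc, Wc), the number of cell rows and columns."""
    # Assumption: every pixel must fall into exactly one cell for the reshape to work.
    if H % cell_size or W % cell_size:
        raise ValueError(f"H={H} and W={W} must be multiples of cell_size={cell_size}")
    return H // cell_size, W // cell_size


Hc, Wc = cell_grid(120, 160)
print(Hc, Wc)     # 15 20
print(Hc * Wc)    # 300 cells per image
```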

space2depth = SpaceToDepth(8)
labels = space2depth(labels)

SpaceToDepth is a class: this line creates a SpaceToDepth object, passing 8 to its constructor. Its implementation is explained here:

https://blog.csdn.net/weixin_44179561/article/details/128058411?csdn_share_tail=%7B%22type%22%3A%22blog%22%2C%22rType%22%3A%22article%22%2C%22rId%22%3A%22128058411%22%2C%22source%22%3A%22weixin_44179561%22%7D

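In other words, space2depth turns labels of shape (64, 1, 120, 160) into labels of shape (64, 64, 15, 20): each 8×8 cell of pixels is folded into 64 channels at a single spatial position.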
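To see what SpaceToDepth does to the shape, here is a minimal stand-in built from view/permute/reshape. It is a sketch, not the class from the linked article: the name `space_to_depth` is hypothetical, and it assumes the article's class orders the pixels of each cell row by row (channel r * 8 + s for the pixel at row r, column s of the cell).

```python
import torch


def space_to_depth(x, block):
    """Hypothetical stand-in for SpaceToDepth.

    Maps (B, C, H, W) -> (B, C * block * block, H // block, W // block).
    The pixel at row r, column s of a cell, in input channel c,
    lands in output channel c * block * block + r * block + s.
    """
    b, c, h, w = x.shape
    x = x.view(b, c, h // block, block, w // block, block)  # (B, C, Hc, r, Wc, s)
    x = x.permute(0, 1, 3, 5, 2, 4)                          # (B, C, r, s, Hc, Wc)
    return x.reshape(b, c * block * block, h // block, w // block)


labels = torch.zeros(64, 1, 120, 160)
labels[0, 0, 9, 21] = 1.0   # keypoint in cell (1, 2): row 1, column 5 inside that cell
out = space_to_depth(labels, 8)
print(out.shape)                                      # torch.Size([64, 64, 15, 20])
print(out[0, :, 1, 2].nonzero().flatten().tolist())   # [13], i.e. 1 * 8 + 5
```

The same layout is produced by PyTorch's built-in `torch.nn.functional.pixel_unshuffle(labels, 8)` (PyTorch 1.8 and later), which may be a convenient substitute.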

Since add_dustbin = True, execution continues into the if block:

dustbin = labels.sum(dim=1)
dustbin = 1 - dustbin
dustbin[dustbin < 1.] = 0
labels = torch.cat((labels, dustbin.view(batch_size, 1, Hc, Wc)), dim=1)

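These four lines are best read together. Overall, they append one extra channel, the "dustbin", after the 64 channels of labels; it marks cells that contain no keypoint. Summing the 64 values of a cell counts its keypoints; with 0/1 labels, 1 - count is 1 for an empty cell and at most 0 otherwise, and the third line sets everything below 1 to 0, so the dustbin is 1 exactly when the cell is empty. torch.cat then attaches it as channel 65: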

[Figure 1: the dustbin channel appended after the 64 label channels]
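The dustbin lines can be traced on a toy tensor. The input below is hypothetical and already in space-to-depth form, with cell_size = 2 (4 pixel channels) and three cells holding 0, 1 and 2 keypoints.

```python
import torch

# Hypothetical toy input, shape (batch_size=1, 4, Hc=1, Wc=3).
# Cell 0: no keypoint; cell 1: one keypoint; cell 2: two keypoints.
labels = torch.tensor([[[[0., 1., 1.]],
                        [[0., 0., 1.]],
                        [[0., 0., 0.]],
                        [[0., 0., 0.]]]])

dustbin = labels.sum(dim=1)        # keypoints per cell: [[[0., 1., 2.]]]
dustbin = 1 - dustbin              # [[[1., 0., -1.]]]
dustbin[dustbin < 1.] = 0          # [[[1., 0., 0.]]]: 1 only for the empty cell
labels = torch.cat((labels, dustbin.view(1, 1, 1, 3)), dim=1)

print(dustbin.flatten().tolist())  # [1.0, 0.0, 0.0]
print(labels.shape)                # torch.Size([1, 5, 1, 3])
```

For 0/1 labels, the three dustbin lines amount to computing `(count == 0).float()`, where count is the per-cell sum of the pixel channels.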

dn = labels.sum(dim=1)

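dn has shape (64, 15, 20): the sum of all 65 channels in each cell. Because every cell either contains at least one keypoint or has a dustbin value of 1, dn is never 0.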

labels = labels.div(torch.unsqueeze(dn, 1))

torch.unsqueeze(dn, 1) has shape (64, 1, 15, 20), so it broadcasts across the 65 channels.

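Every channel of labels is divided by this sum, so the 65 values of each cell sum to 1. A cell with two keypoints, for example, gets 0.5 in each of their two channels.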

The resulting labels has shape (64, 65, 15, 20).
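The normalization step can be traced on a hypothetical toy tensor: one image, 4 pixel channels plus the dustbin, and three cells holding 0, 1 and 2 keypoints.

```python
import torch

# Hypothetical toy tensor after the dustbin step, shape (1, 5, Hc=1, Wc=3).
# The last channel is the dustbin (1 only for the empty cell 0).
labels = torch.tensor([[[[0., 1., 1.]],
                        [[0., 0., 1.]],
                        [[0., 0., 0.]],
                        [[0., 0., 0.]],
                        [[1., 0., 0.]]]])

dn = labels.sum(dim=1)                   # (1, 1, 3): [[[1., 1., 2.]]], never 0
print(torch.unsqueeze(dn, 1).shape)      # torch.Size([1, 1, 1, 3]): broadcasts over 5 channels
labels = labels.div(torch.unsqueeze(dn, 1))

print(labels[0, :, 0, 2].tolist())           # [0.5, 0.5, 0.0, 0.0, 0.0]
print(labels.sum(dim=1).flatten().tolist())  # [1.0, 1.0, 1.0]
```

The cell with two keypoints splits its mass evenly, so every cell becomes a distribution over its cell_size² + 1 channels.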

return labels hands the tensor back to the caller, which converts it to floating point with .float() and assigns it to labels_3D.
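Putting it together, here is a self-contained sketch of the whole function plus the caller's .float() step. Assumptions: `torch.nn.functional.pixel_unshuffle` stands in for the article's SpaceToDepth (same row-by-row cell layout assumed), `labels_2D` is a hypothetical random batch of 0/1 keypoint maps, and the normalization uses `keepdim=True` instead of `torch.unsqueeze`, which yields the same (B, 1, Hc, Wc) divisor.

```python
import torch
import torch.nn.functional as F


def labels2Dto3D(labels, cell_size, add_dustbin=True):
    """Sketch of labels2Dto3D; F.pixel_unshuffle stands in for SpaceToDepth."""
    batch_size, channel, H, W = labels.shape
    Hc, Wc = H // cell_size, W // cell_size
    labels = F.pixel_unshuffle(labels, cell_size)          # (B, cell_size**2, Hc, Wc)
    if add_dustbin:
        dustbin = 1 - labels.sum(dim=1)                    # (B, Hc, Wc)
        dustbin[dustbin < 1.] = 0                          # 1 only for empty cells
        labels = torch.cat((labels, dustbin.view(batch_size, 1, Hc, Wc)), dim=1)
        labels = labels.div(labels.sum(dim=1, keepdim=True))  # per-cell distribution
    return labels


# Hypothetical batch of 0/1 keypoint maps with the article's shape.
torch.manual_seed(0)
labels_2D = (torch.rand(64, 1, 120, 160) > 0.99).float()
labels_3D = labels2Dto3D(labels_2D, cell_size=8, add_dustbin=True).float()
print(labels_3D.shape)   # torch.Size([64, 65, 15, 20])
print(labels_3D.dtype)   # torch.float32
```

Here .float() is a no-op because the input is already float32.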
