【Temporal Segment Networks】PyTorch 代码总结

代码写得特别好，很有借鉴意义。链接如下：
[github]: https://github.com/yjxiong/tsn-pytorch

1. @property

class VideoRecord(object):
    """Read-only wrapper around one parsed row of a video list.

    The row object is stored as given. Column 0 is exposed unchanged as
    ``path``; columns 1 and 2 are converted with ``int()`` on every access
    and exposed as ``num_frames`` and ``label``.
    """

    def __init__(self, row):
        self._data = row

    # @property turns a method into a read-only attribute: callers write
    # ``record.path`` (no parentheses), and assigning to it raises
    # AttributeError because no setter is defined.
    @property
    def path(self):
        """Column 0 of the row, returned unchanged."""
        return self._column(0)

    @property
    def num_frames(self):
        """Column 1 of the row converted with ``int()``."""
        return int(self._column(1))

    @property
    def label(self):
        """Column 2 of the row converted with ``int()``."""
        return int(self._column(2))

    def _column(self, index):
        # Single access point for the stored row.
        row = self._data
        return row[index]

2. np.multiply

# np.multiply multiplies element-wise and broadcasts a scalar operand, e.g.
# np.multiply([0, 1, 2], x) == [0, x, 2x]. Here it yields the multiples
# 0, d, 2d, ... of d = average_duration, one per segment (presumably the start
# frame of each segment). A random value per segment is then added; assuming
# randint is numpy.random.randint, each draw lies in [0, average_duration).
if average_duration > 0:
            offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, size=self.num_segments)

3. topk

# Compute top-1 / top-5 accuracy, e.g. accuracy(output, target, topk=(1, 5)).
def accuracy(output, target, topk=(1,)):
    """Compute precision@k, as a percentage, for each k in ``topk``.

    Args:
        output: score tensor; top-k is taken along dim 1, so it is assumed to
            be laid out as (batch, num_classes).
        target: ground-truth class indices, one per sample (``target.size(0)``
            is used as the batch size).
        topk: the k values to evaluate.

    Returns:
        A list with one 0-dim tensor per k: the percentage of samples whose
        target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk highest scores per sample (largest first), transposed
    # so that row i holds every sample's i-th best guess.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout of `pred`,
        # so `.view(-1)` fails on recent PyTorch; `reshape` copies when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

4. getattr & setattr

if 'resnet' in base_model or 'vgg' in base_model:
            # getattr(obj, name) returns the attribute called `name`, so the
            # model constructor is chosen from the `base_model` string.
            # torchvision.models ships PyTorch's reference implementations;
            # for base_model == 'resnet101' this is equivalent to
            # torchvision.models.resnet101(pretrained=True).
            # `pretrained` is passed by keyword: a bare positional True is
            # opaque, and newer torchvision builders expect keyword arguments.
            self.base_model = getattr(torchvision.models, base_model)(pretrained=True)

5. Freeze BN layers

# Number of BatchNorm2d layers seen so far; must exist before the loop
# increments it (the excerpt used `count` without initialising it).
count = 0
if self._enable_pbn:
    print("Freezing BatchNorm2D except the first one.")
    for m in self.base_model.modules():
        if isinstance(m, nn.BatchNorm2d):
            count += 1
            # _enable_pbn is always true inside this branch, so the former
            # `count >= (2 if self._enable_pbn else 1)` reduces to `count >= 2`:
            # the first BN layer stays trainable, every later one is frozen.
            if count >= 2:
                # eval() makes BN normalise with its running statistics
                # instead of batch statistics and stop updating them.
                m.eval()
                # shutdown update in frozen mode: no gradients for the affine
                # parameters (both are None when the layer has affine=False).
                for param in (m.weight, m.bias):
                    if param is not None:
                        param.requires_grad = False

你可能感兴趣的:(python,pytorch,计算机视觉,deep,learning,pytorch,temporal,segment,networks)