代码写得特别好,很有借鉴意义。链接如下:
[github](https://github.com/yjxiong/tsn-pytorch)
class VideoRecord(object):
    """One row of a video list, exposed through read-only properties.

    ``row`` is any indexable object whose first three items are the video
    path, the frame count and the class label. The frame count and label
    are converted with ``int`` each time they are read.
    """

    def __init__(self, row):
        self._data = row

    # @property turns a method into an attribute: callers write
    # ``record.path`` instead of ``record.path()``. No setter is defined,
    # so assigning to the attribute raises AttributeError.
    @property
    def path(self):
        """Return the first field of the row unchanged."""
        return self._data[0]

    @property
    def num_frames(self):
        """Return the second field of the row converted to ``int``."""
        return int(self._data[1])

    @property
    def label(self):
        """Return the third field of the row converted to ``int``."""
        return int(self._data[2])
# Excerpt from inside a method (it reads self.num_segments); the original
# indentation was lost, so the assignment below belongs to the `if` body.
# Each segment i gets the start position i * average_duration plus one value
# drawn by randint(average_duration, size=...).
# NOTE(review): randint is presumably numpy.random.randint, which would draw
# from [0, average_duration) — confirm against the file's imports.
if average_duration > 0: # how np.multiply is used: np.multiply([0, 1, 2], x) = [0, x, 2x]
offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, size=self.num_segments)
# Compute top-k accuracy, e.g. top-1 and top-5 with topk=(1, 5).
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: Score tensor; the top ``max(topk)`` entries are taken along
            dimension 1, so it is expected to be (batch, num_classes).
        target: Tensor of ground-truth class indices, one per sample; its
            first dimension is used as the batch size.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one 0-dim float tensor per k, holding the percentage of
        samples whose target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk largest scores per sample, sorted best-first,
    # then transposed to (maxk, batch) so row j holds every sample's
    # (j + 1)-th guess.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout of
        # `pred`, so .view(-1) can raise for k > 1; .reshape(-1) copies
        # only when it has to and always succeeds.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
# Excerpt from inside a method (it assigns self.base_model); the original
# indentation was lost, so the assignment below belongs to the `if` body.
if 'resnet' in base_model or 'vgg' in base_model:
# getattr() returns the value of the named attribute of an object.
# torchvision.models holds the model implementations that ship with torchvision.
# For base_model == 'resnet101' this line amounts to models.resnet101(True).
# NOTE(review): the positional True is presumably the `pretrained` flag —
# confirm against the installed torchvision version.
self.base_model = getattr(torchvision.models, base_model)(True)
# Excerpt from inside a method; the original indentation was lost, so the
# nesting of the statements below has to be read from the syntax.
# When self._enable_pbn is set, walk every module of self.base_model and put
# each BatchNorm2d layer from the second one onwards into eval mode with its
# weight and bias excluded from gradient computation.
# NOTE(review): `count` is incremented here but not initialised within this
# excerpt — it is presumably set to 0 just before; confirm in the full file.
if self._enable_pbn:
print("Freezing BatchNorm2D except the first one.")
for m in self.base_model.modules():
if isinstance(m, nn.BatchNorm2d):
count += 1
# Inside the enclosing `if self._enable_pbn` this threshold is always 2,
# so the first BatchNorm2d encountered is left untouched.
if count >= (2 if self._enable_pbn else 1):
m.eval()
# shutdown update in frozen mode
m.weight.requires_grad = False
m.bias.requires_grad = False