实用程序类Accumulator

        前些日子发现动手学深度学习pytorch版李沐大佬是有课的,之前一直跟着另外一个GitHub项目在学,这里是对之前准确率中的一个实用程序类的解释,防止自己忘记。

        

class Accumulator:
    #在n个变量上累加
    def __init__(self, n):
        self.data = [0.0] * n
    def add(self, *args):
        for a, b in zip(self.data, args):
        self.data = [a + float(b) for a,b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

        首先在初始化的时候会根据传进来的n的大小来创建n个空间,且初始化全部为0.0。

        接着在使用.add()的时候情况下,虽然*args代表这里可以传入任意个参数,但是因为要和初始化的个数相同不然要报错。for a,b in zip(self.data,args)是把原来类中对应位置的data和新传入的args做 a + float(b)加法操作然后重新赋给该位置的data。从而达到累加器的累加效果。

        reset函数即重新设置空间大小并初始化。

        __getitem__实现类似数组的取操作。

下面是测试。

class Accumulator:
    #在n个变量上累加
    def __init__(self, n):
        self.data = [0.0] * n
    def add(self, *args):
        for a, b in zip(self.data, args):
            print(a, b)
        self.data = [a + float(b) for a,b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

metric = Accumulator(3)
print(metric[1])
metric.add(3,6,9)
print(metric[1])
metric.add(1,2)
print(metric[1])

输出:

实用程序类Accumulator_第1张图片

 

你可能感兴趣的:(《动手学深度学习》代码详解,深度学习,python,机器学习)