手撕Pytorch源码#2.Dataset类 part2

写在前面

  1. 手撕Pytorch源码系列目的:

  • 通过手撕源码复习+了解高级python语法

  • 熟悉对pytorch框架的掌握

  • 在每一类完成源码分析后,会与常规深度学习训练脚本进行对照

  • 本系列预计先手撕python层源码,再进一步手撕c源码

  1. 版本信息

python:3.6.13

pytorch:1.10.2

  1. 本博文涉及python语法点

  • @staticmethod修饰器

  • super类的全新理解【大概率有你闻所未闻的华点!】

  • bisect二分法搜索方法

  • @property修饰器

目录

[TOC]

零、流程图

手撕Pytorch源码#2.Dataset类 part2_第1张图片

一、ConcatDataset类

  1. 在上篇博文中,Dataset类的__add__方法中使用了ConcatDataset类,因而本篇对其进行研究,并学习相关的python语法点

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])
  • Dataset[T_co],Type hint用法见上一篇博文

1.0 源代码
  • 注:在第一部分未进行精讲的代码已在下方源代码处做了详细的注释!

class ConcatDataset(Dataset[T_co]):
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ConcatDataset, self).__init__()
        # 将Iterable[Dataset]转化成为List(Dataset)
        self.datasets = list(datasets)
        # 对传入的datasets序列的长度进行合法性判断
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]
        # 仅接收Dataset类型,不接收IterableDataset类型
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        # cumulative_size的最后一个元素就是整个数据集序列的总长度
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # 考虑的是用负数进行索引的情况
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            # 对负数的索引值进行正向化,只需要用总长度+负数索引值即可
            # 例如:对最后一个元素的负数索引为-1,总长度为n,则n-1恰好为正向索引值
            # len(self)相当于调用__len__方法,获取总长度
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        # 如果判定就是第一个数据集,那么对数据集的索引就是总的索引
        if dataset_idx == 0:
            sample_idx = idx
        # 如果判定不是第一个数据集,那么对数据集的索引就不是总的索引
        # 例如,计算出来的cumulative_size为[100,300,600,1000],而总索引为569
        # 那么可以判定出来索引位于第三个数据集上(数据集大小为300)
        # 那么在第三个数据集上的索引应该是569-300=269
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
1.1 @staticmethod以及cumsum()函数
  • @符号表示修饰器,而staticmethod则声明了该函数为类的静态方法,具体解释见【2.1节@staticmethod】

  • cumsum(sequence)用于计算每个数据集累计长度的序列,为了方便__getitem__方法中,通过索引的序列号index定位序列号位于哪个数据集

  • 假定传入的数据集序列为datasets = [ds1,ds2,ds3,ds4]共有四个数据集,其长度分别为[100,200,300,400],则cumcum(sequence)函数则会生成列表cumulative = [100,300,600,1000]。如果__getitem__方法传入的序列号为596,则可以用cumulative列表判断该序号的数据属于第三个数据集,进而可以从ds3中取得相应的数据

1.2 super(ConcatDataset, self).__init__()
  • 上述语句的作用为调用ConcatDataset类的父类Dataset的初始化函数

  • super类的精讲见【2.2节super类】,关于其中涉及的mro链的精讲见下一篇博文

1.3 dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
  • bisect_right函数的精讲见【2.3节bisect二分法搜索】

  • 源代码利用bisect_right通过二分法找到目标索引值index在前文所述cumsum(sequence)函数生成的列表中的位置,并且通过该位置即可判定该索引位于数据集序列的第几个数据集中,从而计算出dataset_idx

1.4 @properpy以及cummulative_sizes(self)
  1. @properpy用于是函数可以像属性一样被直接调用,精讲见【2.4节@property】

  1. cummulative_sizes(self)函数其实为了版本兼容,通过warning.warn提示用户此方法已经改名,并将正确的值通过函数传递给该方法

  1. warnings.warn(message, category=None, stacklevel=1, source=None)函数用于抛出警告,而上述代码中的DeprecationWarning是代码被弃用的警告

二、相应的Python语法补充

2.1 @staticmethod
  • @符号表示修饰器,而staticmethod则声明了该函数为类的静态方法,静态方法直接属于类,不需要传入self参数,正如外部函数一样进行定义,并且可以直接通过类或对象进行调用

class Static():
    def __init__(self,x:int)->None:
        self.x = x
    @staticmethod
    def static()->None:
        print("Trying static")

Static.static()
s = Static(1)
s.static()

# 输出
# Trying static
# Trying static
2.2 super类
  1. 冷知识:super()其实并不是一个函数,更不是一个关键字,而是一个类,因此在程序中使用super()是创建了一个对象,而super的原型则是super(type,type or object)

  1. super类传入的两个参数分别代表什么:

  • 第一个type为一个类名,如上代码中的ConcatDataset,而第二个object则往往是一个M object,决定一个mro chain【mro链精讲见下一篇博文】

  • 而第一个type决定了在mro chain中的取值位置,即从第一个type类下一个类开始,在第二个object决定的mro链上寻找最近的函数进行调用

  • 当然,光说很抽象,直接上代码

from objprint import op
class Animal():
    def __init__(self,age):
        self.age = age
    
class Person(Animal):
    def __init__(self, age,name):
        super(Person,self).__init__(age)
        self.name = name

class Male(Person):
    def __init__(self,age,name):
        super(Person,self).__init__(age)
        self.gender = 'male'

class Female(Person):
    def __init__(self,age,name):
        super(Female,self).__init__(age,name)
        self.gender = 'female'

m = Male(50,"Tim")
fm = Female(40,"Lily")

op(m)
op(fm)

# 输出结果:
# 
# 
  • 根据上述代码结果可以发现:Male类与Female中super()的第二个参数均为各自类别的self,因此Male类与Female类对应的mro链为Male:Male->Person->Animal,Female:Female->Person->Animal

  • Male:而在Male类中super()传入的第一个参数为Person因此从mro链中Person的下一个类,即Animal类开始查找__init__初始化函数,因此最终Male类的初始化函数相当于仅调用了Animal.__init__函数,因而无法对name属性进行初始化定义

  • Female:而Female类中super()传入的第一个参数为Female因此从mro链中Female的下一个类,即Person类开始查找__init__初始化函数,因此首先调用Person.__init__对name属性进行初始化

  • Female:接着,对于Person类,其super()的第二个参数为Person类比的self,因此其mro链为Person:Person->Animal,故从Animal类开始寻找__init__函数进行初始化。最终Female类的初始化函数相当于同时调用了Person.__init__函数以及Animal.__init__函数,因而对name和age属性都进行了初始化定义

  1. super类可能存在的黑魔法

  • 首先上代码看例子

class A():
    def display(self):
        print("A")

class B(A):
    def display(self):
        super(B,self).display()

class C(A):
    def display(self):
        print("C")

class D(B,C):
    def display(self):
        super(B,self).display()

b = B()
d = D()
b.display()
d.display()

# 输出结果为:
# A
# C
  • 或者对上述代码中的D类换一种写法

    class D(B,C):
        def display(self):
            B.display(self)

    d = D()
    d.display()
# 输出结果为:
# C
  • 那么为什么调用B.display函数最后会输出C,而C则和B没有任何的继承与被继承关系呢?

  • 首先,由于D继承了B和C两个类,因此其mro链为D:D->B->C->A,而super()类输入的第一个参数为B,即从B的下一个类开始寻找display()函数,进而调用C的display()函数最终输出了C

  • 因此,虽然B与C类没有任何的继承与被继承关系,但两类通过D类的mro链被联系在一起

2.3 bisect二分法搜索
  1. python的bisect库有三种二分法搜索的API分别是:bisect.bisect(Sequence,x),bisect.bisect_left(Sequence,x),bisect.bisect_right(Sequence,x)

  • 其中,在二分法搜索上bisect.bisect(Sequence,x)与bisect.bisect_right(Sequence,x)是完全等效的

  • bisect.bisect(Sequence,x)与bisect.bisect_right(Sequence,x)是找到传入值x的最小下标值

  • 而bisect.bisect_left(Sequence,x)则是找到传入值x的最小下标值

  1. 话不多说,直接上代码:

import bisect
l = [100,255,512,1036]
print(bisect.bisect(l,500))
print(bisect.bisect_right(l,500))
print(bisect.bisect_left(l,500))

print(bisect.bisect(l,512))
print(bisect.bisect_right(l,512))
print(bisect.bisect_left(l,512))

# 输出值为:
# 2
# 2
# 2
# 3
# 3
# 2
2.4 @property
  1. @property目的是可以让函数像属性一样被调用

  1. 而将属性变化成函数则是可以避免类调用时,对象的属性被非法修改

  1. 如一个长方形类,其属性为长,宽与面积,而如果将面积作为属性进行定义,那么在类外,可以直接利用self.area对面积值进行修改,从而使得长宽与面积值的不匹配,而用@property修饰器则可以解决该问题

  1. 代码示例如下:

class Property():
    def __init__(self,height,width) -> None:
        self.height = height
        self.width = width

    @property
    def area(self):
        return self.height*self.width


p = Property(2,4)
print(p.area)

# 输出结果为:
# 8

你可能感兴趣的:(手撕Pytorch代码,python,深度学习,人工智能,神经网络,pytorch)