手撕Pytorch源码#1.Dataset类 part1

写在前面

  1. 手撕Pytorch源码系列目的:

  • 通过手撕源码复习+了解高级python语法

  • 熟悉对pytorch框架的掌握

  • 在每一类完成源码分析后,会与常规深度学习训练脚本进行对照

  • 本系列预计先手撕python层源码,再进一步手撕c源码

  1. 版本信息

python:3.6.13

pytorch:1.10.2

  1. 本博文涉及python语法点

  • Generic,TypeVar泛型编程知识点

  • Type hint知识点

  • typing.Dict,typing.Callable知识点

目录

[TOC]

零、流程图

手撕Pytorch源码#1.Dataset类 part1_第1张图片

一、Dataset类

1.0 源代码
class Dataset(Generic[T_co]):
    
    functions: Dict[str, Callable] = {}

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
1.1 Generic[T_co]
  1. typing.Generic类是泛型的声明,与C++语言中的泛型类似,可以用于创建类模板【见2.1节 泛型编程】

1.2functions: Dict[str, Callable] = {}
  1. :操作符:起到type hint类型提示的作用【见2.2节 Type hint】

  1. typing.Dict:字典类型说明,与dict()作用不同【见2.3节 typing.Dict】

  1. typing.Callable:可调用对象类型说明【见2.4节 typing.Callable】

1.3 __getitem__(self, index) -> T_co
  1. __getitem__方法用于直接索引对象的属性值,如list对象l = [1,2,3]可以直接使用l[0]进行索引,即是调用了__getitem__方法,而在由Dataset生成DataLoader的过程中,需要使用__getitem__

  1. ->操作符:同样起到type hint类型提示的作用,:操作符往往提示变量,或输入参数的类型,而->则提示函数返回值的类型【见2.2节 Type hint】

  1. T_co:T_co是由typing.TypeVar()生成的变量,起到泛型声明的作用【见2.1节 泛型编程】

  1. raise NotImplementedError:Dataset类需要其父类(自己定义的数据集类)实现__getitem__方法,否则就会报错(同时需要实现的还有__init__魔法方法和__len__魔法方法)

1.4 __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]'
  1. __add__魔法方法用于直接利用+号对Dataset进行拼接,生成ConcatDataset[T_co]类对象

  1. ConcatDataset类源码见下一期博文

二、相应的python语法补充

2.1 泛型编程
  1. 相当于C++中的template,即定义类模板时不显示指定参数类比,而是在实际实例化时再进行指定

  1. T_co是由API:typing.TypeVar()实例化的对象,其定义如下:

T_co = TypeVar('T_co', covariant=True)
  • 其中'T_co'为其名称

  • covariant = True的作用如下:

from typing import TypeVar
class Figure3D():
    pass
class Cubo(Figure3D):
    pass
class CaixaDePapelao(Cubo):
    pass

T = TypeVar("T")
T_co = TypeVar("T_co",covariant = True)

class Renderizador(Generic[T_co]):
    def __init__(self,x:T_co)->None:
        pass

def exe_render(render:Renderizator[Cubo]):
    pass

render_1 = Renderizador(CaixaDePapelao())
exe_render(render_1)
  • 在exe_render()函数的定义中,type hint指明输入参数为Renderizator[Cubo],而实际函数的参数是Renderizador[CaixaDePapelao]与type hint不符,但CaixaDePapelao是Cubo的子类,因而为了使得子类CaixaDePapelao也能满足type hint,需要将其TypeVar()类的covariant参数置为True。此外,默认情况下TypeVar()类参数invariant = True,即子类与父类均不满足。covariant = True仅有子类满足,countervariant = True仅有父类满足

  • 常用的typing.TypeVar()函数的用法如下:

from typing import TypeVar,Union,List
# 利用Union类声明T_1为int列表与float列表类型
T_1 = TypeVar("T_1",bound = Union[List[int],List[float]]) 

# 直接传参声明T_2为int类型和str类型
T_2 = TypeVar("T_2",str,int)

# 不传参说明T可能是任意类型
T = TypeVar("T")
  1. 常常在定义类模板时继承typing.Generic[T_co]类,相当于C++中进行泛型声明template

2.2 Type hint
  1. 虽然python为轻类型语言,在编程时可不需要标注其类型,但在定义函数和类时,为了方便可视化,避免传参以及返回值错误,python引入了Type hint机制

  1. 对于函数的声明,常见的type hint使用如下:

def add(a:int,b:int)->int:
    return a+b
  • a:int,b:int说明a,b两参数为int类型,->int表示函数的返回值是int类型

2.3 typing.Dict
  1. typing.py中定义了Dict,List,Set,Callable等类,但其与函数list(),dict()等作用不同

  1. typing.Dict[]用于类型声明,不可以进行赋值和初始化,而dict()则可以用于进行初始化

# Dict用于类型的声明Type hint
functions: typing.Dict[str, Callable] = {}
# dict()用于初始化实例化一个字典
dic = dict(1="1",2="2")
2.4 typing.Callable
  1. Callable类是可调用的类型,函数与类都是可调用的类型,可以用isinstance函数进行验证

from typing import Callable
class Addtion():
    pass

def func():
    pass

a = Addition()
b = 10
isinstance(a,Callable) # False
isinstance(Addition.Callable) # True
isinstance(func,Callable) # True
isinstance(b,Callable) # False
  1. functions: Dict[str, Callable] = {}说明该字典的key值为字符串str类型,而value值为Callable类型,即函数或类

你可能感兴趣的:(手撕Pytorch代码,pytorch,python,深度学习,人工智能,神经网络)