Module存储了模块类的函数
pytorch中模块非常容易使用,只需要派生自Module,重载两个函数就行了,那么Module都做了什么
class Module(object):
def __init__(self):
self._backend = thnn_backend
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._modules = OrderedDict()
self.training = True
构造函数生成一堆有序字典,用来存储各种参数,暂且不表,先说第一个结构self._backend是一个全局THNNFunctionBackend()类,存储一个一系列函数指针, 这个类派生类是FunctionBackend
class FunctionBackend(object):
def __init__(self):
self.function_classes = {}
def register_function(self, name, function_class):
self.function_classes[name] = function_class
其中这个类的function_classes字典的键是名称,值是函数,使用register_function添加注册,注册完毕后约有118个函数,本文的pytorch版本是0.4.1
RNN
RNNTanhCell
RNNReLUCell
LSTMCell
GRUCell
Dropout
Dropout2d
Dropout3d
MarginCriterion
MarginCriterionBackward
GatedLinear
GatedLinearBackward
SpatialFullConvolutionMap
SpatialFullConvolutionMapBackward
VolumetricFractionalMaxPooling
VolumetricFractionalMaxPoolingBackward
VolumetricFullDilatedConvolution
VolumetricFullDilatedConvolutionBackward
Col2Im
Col2ImBackward
DilatedConv2d
DilatedConv2dBackward
SpatialConvolutionLocal
SpatialConvolutionLocalBackward
FeatureLPPooling
FeatureLPPoolingBackward
VolumetricGridSamplerBilinear
VolumetricGridSamplerBilinearBackward
TemporalUpSamplingNearest
TemporalUpSamplingNearestBackward
SpatialUpSamplingNearest
SpatialUpSamplingNearestBackward
ReflectionPad1d
ReflectionPad1dBackward
SpatialConvolutionMap
SpatialConvolutionMapBackward
NLLLoss
NLLLossBackward
Softplus
SoftplusBackward
LogSigmoid
LogSigmoidBackward
SpatialUpSamplingBilinear
SpatialUpSamplingBilinearBackward
ReplicationPad3d
ReplicationPad3dBackward
MultiMarginLoss
MultiMarginLossBackward
ReplicationPad1d
ReplicationPad1dBackward
MultiLabelMarginLoss
MultiLabelMarginLossBackward
SpatialFullDilatedConvolution
SpatialFullDilatedConvolutionBackward
SoftMarginLoss
SoftMarginLossBackward
NLLLoss2d
NLLLoss2dBackward
MSELoss
MSELossBackward
Sigmoid
SigmoidBackward
VolumetricUpSamplingTrilinear
VolumetricUpSamplingTrilinearBackward
BCELoss
BCELossBackward
Square
SquareBackward
ReplicationPad2d
ReplicationPad2dBackward
L1Loss
L1LossBackward
SpatialGridSamplerBilinear
SpatialGridSamplerBilinearBackward
Sqrt
SqrtBackward
TemporalRowConvolution
TemporalRowConvolutionBackward
SpatialFractionalMaxPooling
SpatialFractionalMaxPoolingBackward
TemporalUpSamplingLinear
TemporalUpSamplingLinearBackward
VolumetricDilatedMaxPooling
VolumetricDilatedMaxPoolingBackward
Threshold
ThresholdBackward
Abs
AbsBackward
Softshrink
SoftshrinkBackward
LeakyReLU
LeakyReLUBackward
VolumetricUpSamplingNearest
VolumetricUpSamplingNearestBackward
VolumetricDilatedConvolution
VolumetricDilatedConvolutionBackward
Tanh
TanhBackward
TemporalSubSampling
TemporalSubSamplingBackward
ELU
ELUBackward
Hardtanh
HardtanhBackward
L1Cost
L1CostBackward
SpatialSubSampling
SpatialSubSamplingBackward
Im2Col
Im2ColBackward
KLDivLoss
KLDivLossBackward
SmoothL1Loss
SmoothL1LossBackward
ReflectionPad2d
ReflectionPad2dBackward
CrossMapLRN2d
EmbeddingBag
一不留神把pytorch支持的所有预定义模块都给展示出来了。本文稍后开始讲解这些预定义模块的实现。
其他有序字典
self._parameters = OrderedDict() # 模块网络参数
self._buffers = OrderedDict() # 驻留内存(不释放,不交换)
self._backward_hooks = OrderedDict() # 反向钩子函数字典,
self._forward_hooks = OrderedDict() # 正向钩子函数字典
self._forward_pre_hooks = OrderedDict() # 正向调用前钩子函数字典
self._modules = OrderedDict() # 模块列表
self.training = True # 训练还是验证
模块函数
模块的函数根据名称可以知道其作用,此处仅仅列举,不在详述
名称 | 作用 |
---|---|
forward | 前向计算虚函数 |
register_buffer | 注册驻留内存 |
register_parameter | 注册参数 |
add_module | 添加模块 |
_apply | 针对所有参数的操作 |
apply | 针对所有子模块的操作 |
cuda | 搬家到GPU上 |
cpu | 搬家到CPU上 |
type | 所有参数换类型喽 |
float | 统统换成浮点 |
double | 统统换成双精度浮点 |
half | 统统换成字(俩字节) |
to | 给用户一个换类型和CGPU的接口,其实还是调用_ |
register_backward_hook | 注册反向钩子 |
register_forward_pre_hook | 注册前向调用前钩子 |
register_forward_hook | 注册前向钩子 |
_slow_forward | 没有加速的前向函数 |
call | 给个参数就执行的前向调用 |
setstate | 快速设置所有字典状态 |
getattr | 获取属性 |
setattr | 设置属性 |
delattr | 删除属性 |
state_dict | 当前状态字典的输出 |
_load_from_state_dict | 从状态字典中装载的执行函数 |
load_state_dict | 装载状态的用户接口 |
children | 子模块 |
modules | 所有模块 |
train | 训练 |
eval | 评估 |
zero_grad | 参数梯度清零 |
share_memory | 使用共享内存 |
repr | 迭代器 |
dir | 列举 |