6 Module - Dissecting PyTorch

The Module class holds the machinery shared by every module class.

Modules in PyTorch are very easy to use: you only need to derive from Module and override two functions. So what does Module itself actually do? Its constructor looks like this:

class Module(object):
    def __init__(self):
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True

The constructor creates a set of ordered dictionaries for storing various kinds of state; we will set those aside for a moment. Let us first look at self._backend: it is a global THNNFunctionBackend() instance that stores a series of function pointers, and THNNFunctionBackend derives from FunctionBackend:

class FunctionBackend(object):
    def __init__(self):
        self.function_classes = {}  # maps a function name to its function class

    def register_function(self, name, function_class):
        self.function_classes[name] = function_class
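
As a quick, purely illustrative sketch of how such a backend gets filled in (the SigmoidFn class below is a made-up stand-in, not the real THNN function class; in 0.4.1 the actual registration is performed by the THNN backend module at import time):

# Illustration only: populate a FunctionBackend (the class shown above)
# with a fake function class.
class SigmoidFn(object):
    """Stand-in for a THNN autograd function class."""
    pass

backend = FunctionBackend()
backend.register_function('Sigmoid', SigmoidFn)

print(backend.function_classes['Sigmoid'])   # -> SigmoidFn
print(len(backend.function_classes))         # -> 1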

The keys of this class's function_classes dictionary are function names and the values are the corresponding function classes; entries are added through register_function. Once registration completes there are roughly 118 functions (the PyTorch version used in this article is 0.4.1):

RNN                                      
RNNTanhCell                              
RNNReLUCell                              
LSTMCell                                 
GRUCell                                  
Dropout                                  
Dropout2d                                
Dropout3d                                
MarginCriterion                          
MarginCriterionBackward                  
GatedLinear                              
GatedLinearBackward                      
SpatialFullConvolutionMap                
SpatialFullConvolutionMapBackward        
VolumetricFractionalMaxPooling           
VolumetricFractionalMaxPoolingBackward   
VolumetricFullDilatedConvolution         
VolumetricFullDilatedConvolutionBackward 
Col2Im                                   
Col2ImBackward                           
DilatedConv2d                            
DilatedConv2dBackward                    
SpatialConvolutionLocal                  
SpatialConvolutionLocalBackward          
FeatureLPPooling                         
FeatureLPPoolingBackward                 
VolumetricGridSamplerBilinear            
VolumetricGridSamplerBilinearBackward    
TemporalUpSamplingNearest                
TemporalUpSamplingNearestBackward        
SpatialUpSamplingNearest                 
SpatialUpSamplingNearestBackward         
ReflectionPad1d                          
ReflectionPad1dBackward                  
SpatialConvolutionMap                    
SpatialConvolutionMapBackward            
NLLLoss                                  
NLLLossBackward                          
Softplus                                 
SoftplusBackward                         
LogSigmoid                               
LogSigmoidBackward                       
SpatialUpSamplingBilinear                
SpatialUpSamplingBilinearBackward        
ReplicationPad3d                         
ReplicationPad3dBackward                 
MultiMarginLoss                          
MultiMarginLossBackward                  
ReplicationPad1d                         
ReplicationPad1dBackward                 
MultiLabelMarginLoss                     
MultiLabelMarginLossBackward             
SpatialFullDilatedConvolution            
SpatialFullDilatedConvolutionBackward    
SoftMarginLoss                           
SoftMarginLossBackward                   
NLLLoss2d                                
NLLLoss2dBackward                        
MSELoss                                  
MSELossBackward                          
Sigmoid                                  
SigmoidBackward                          
VolumetricUpSamplingTrilinear            
VolumetricUpSamplingTrilinearBackward    
BCELoss                                  
BCELossBackward                          
Square                                   
SquareBackward                           
ReplicationPad2d                         
ReplicationPad2dBackward                 
L1Loss                                   
L1LossBackward                           
SpatialGridSamplerBilinear               
SpatialGridSamplerBilinearBackward       
Sqrt                                     
SqrtBackward                             
TemporalRowConvolution                   
TemporalRowConvolutionBackward           
SpatialFractionalMaxPooling              
SpatialFractionalMaxPoolingBackward      
TemporalUpSamplingLinear                 
TemporalUpSamplingLinearBackward         
VolumetricDilatedMaxPooling              
VolumetricDilatedMaxPoolingBackward      
Threshold                                
ThresholdBackward                        
Abs                                      
AbsBackward                              
Softshrink                               
SoftshrinkBackward                       
LeakyReLU                                
LeakyReLUBackward                        
VolumetricUpSamplingNearest              
VolumetricUpSamplingNearestBackward      
VolumetricDilatedConvolution             
VolumetricDilatedConvolutionBackward     
Tanh                                     
TanhBackward                             
TemporalSubSampling                      
TemporalSubSamplingBackward              
ELU                                      
ELUBackward                              
Hardtanh                                 
HardtanhBackward                         
L1Cost                                   
L1CostBackward                           
SpatialSubSampling                       
SpatialSubSamplingBackward               
Im2Col                                   
Im2ColBackward                           
KLDivLoss                                
KLDivLossBackward                        
SmoothL1Loss                             
SmoothL1LossBackward                     
ReflectionPad2d                          
ReflectionPad2dBackward                  
CrossMapLRN2d                            
EmbeddingBag                             

Without quite meaning to, we have just listed every predefined module that PyTorch supports. We will start explaining the implementation of these predefined modules later.
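
If you want to reproduce this list on your own installation, a minimal sketch is shown below (it assumes the 0.4.1 internals described above; _backend and function_classes are private attributes and disappear in later releases):

import torch.nn as nn

m = nn.Linear(3, 4)                  # any module will do
backend = m._backend                 # the global thnn_backend instance (0.4.1)
names = sorted(backend.function_classes.keys())

print(len(names))                    # roughly 118 in 0.4.1
for name in names[:5]:
    print(name)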

The other ordered dictionaries

        self._parameters = OrderedDict()         # the module's learnable parameters
        self._buffers = OrderedDict()            # persistent buffers (kept as module state, not trained)
        self._backward_hooks = OrderedDict()     # dictionary of backward hook functions
        self._forward_hooks = OrderedDict()      # dictionary of forward hook functions
        self._forward_pre_hooks = OrderedDict()  # dictionary of hooks run before forward is called
        self._modules = OrderedDict()            # child modules
        self.training = True                     # training vs. evaluation mode
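
To make the role of these dictionaries concrete, here is a small sketch (the Toy module and its attribute names are invented for illustration) showing how attribute assignment and register_buffer route state into them:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super(Toy, self).__init__()
        self.fc = nn.Linear(2, 3)                     # goes into self._modules
        self.scale = nn.Parameter(torch.ones(1))      # goes into self._parameters
        self.register_buffer('step', torch.zeros(1))  # goes into self._buffers

    def forward(self, x):
        return self.fc(x) * self.scale

m = Toy()
print(list(m._modules.keys()))      # ['fc']
print(list(m._parameters.keys()))   # ['scale']
print(list(m._buffers.keys()))      # ['step']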

Module methods

The purpose of each Module method can be inferred from its name; they are simply listed here without going into detail.

Name                        Purpose
forward                     virtual function for the forward computation
register_buffer             register a persistent buffer
register_parameter          register a parameter
add_module                  add a child module
_apply                      apply an operation to all parameters
apply                       apply a function to all submodules
cuda                        move the module onto the GPU
cpu                         move the module onto the CPU
type                        cast all parameters to the given type
float                       cast everything to single-precision float
double                      cast everything to double-precision float
half                        cast everything to half precision (two bytes)
to                          user-facing interface for changing dtype and CPU/GPU device; internally it still calls _apply
register_backward_hook      register a backward hook
register_forward_pre_hook   register a hook that runs before forward
register_forward_hook       register a forward hook
_slow_forward               the un-accelerated forward path
__call__                    run the forward pass on the given inputs
__setstate__                restore the state of all dictionaries at once
__getattr__                 get an attribute
__setattr__                 set an attribute
__delattr__                 delete an attribute
state_dict                  export the current state dictionary
_load_from_state_dict       worker function that loads from a state dictionary
load_state_dict             user-facing interface for loading state
children                    iterate over child modules
modules                     iterate over all modules
train                       switch to training mode
eval                        switch to evaluation mode
zero_grad                   zero the gradients of all parameters
share_memory                move storage to shared memory
__repr__                    printable representation of the module
__dir__                     list the module's attributes
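
To close, a short sketch exercising a few of the methods listed above (the log_shape hook and the net layout are invented for illustration):

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# register_forward_hook: runs after the hooked module's forward pass
def log_shape(module, inputs, output):
    print(type(module).__name__, tuple(output.shape))

handle = net[0].register_forward_hook(log_shape)
out = net(torch.randn(1, 4))        # __call__ runs forward and fires the hooks
handle.remove()

# state_dict / load_state_dict: export and restore parameters and buffers
state = net.state_dict()
net.load_state_dict(state)

# train / eval flip the `training` flag on the module and all its children
net.eval()
print(net.training)                 # False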
