libtorch Study Notes (12): Initializing Weights and Biases

In libtorch, every neural-network module implements its own weight and bias initialization, so in general there is no need to initialize them explicitly.

The Conv2d convolution module¹

Conv2dImpl::Conv2dImpl(
    Conv2dOptions options_)
    : ConvNdImpl(
        detail::ConvNdOptions<2>(
          /*in_channels=*/options_.in_channels(),
          /*out_channels=*/options_.out_channels(),
          /*kernel_size=*/options_.kernel_size())
          .stride(options_.stride())
          .padding(options_.padding())
          .dilation(options_.dilation())
          .transposed(false)
          .output_padding(0)
          .groups(options_.groups())
          .bias(options_.bias())
          .padding_mode(options_.padding_mode())) {}

/// Base class for all (dimension-specialized) convolution modules.
template <size_t D, typename Derived>
class ConvNdImpl : public torch::nn::Cloneable<Derived> {
 public:
  explicit ConvNdImpl(detail::ConvNdOptions<D> options_) : options(std::move(options_)) {
    reset();
  }

void reset() override {
    // ......
    reset_parameters();
}

void reset_parameters() {
    init::kaiming_uniform_(weight, /*a=*/std::sqrt(5));

    if (bias.defined()) {
        int64_t fan_in, fan_out;
        std::tie(fan_in, fan_out) = init::_calculate_fan_in_and_fan_out(weight);
        auto bound = 1 / std::sqrt(fan_in);
        init::uniform_(bias, -bound, bound);
    }
}
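
A side note on the magic number a = √5 passed to kaiming_uniform_: the sampling bound is gain·√(3/fan_in) with gain = √(2/(1+a²)), so the weight ends up drawn from the same ±1/√fan_in interval as the bias:

$$bound = \sqrt{\frac{2}{1+a^2}} \cdot \sqrt{\frac{3}{fan\_in}} = \sqrt{\frac{2}{6}} \cdot \sqrt{\frac{3}{fan\_in}} = \frac{1}{\sqrt{fan\_in}}$$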

As the code shows, Conv2d initializes its weight with the Kaiming (MSRA) uniform distribution, while the bias, if present, is drawn from a uniform distribution whose bounds are derived from the number of input channels and the kernel size. For example, with 3 input channels and a 3×3 kernel,

$$fan\_in = 3 \times 3 \times 3 = 27$$

and the bias is sampled uniformly from the interval

$$-1/\sqrt{fan\_in} \sim 1/\sqrt{fan\_in}, \quad \text{i.e.} \quad -0.19245 \sim 0.19245$$
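
A minimal sketch to verify this (the 16 output channels are an arbitrary choice): construct a Conv2d and print the range of the freshly initialized bias, which should stay within ±0.19245 for fan_in = 27:

#include <torch/torch.h>
#include <iostream>

int main() {
    // 3 input channels, 16 output channels, 3x3 kernel => fan_in = 3*3*3 = 27
    torch::nn::Conv2d conv(torch::nn::Conv2dOptions(3, 16, 3));

    // Both values should fall inside [-1/sqrt(27), 1/sqrt(27)] ≈ ±0.19245
    std::cout << "bias min: " << conv->bias.min().item<float>()
              << ", bias max: " << conv->bias.max().item<float>() << std::endl;
    return 0;
}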

The Linear module

Linear initializes its weight and bias with the same method:

void LinearImpl::reset_parameters() {
  torch::nn::init::kaiming_uniform_(weight, std::sqrt(5)); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
  if (bias.defined()) {
    int64_t fan_in, fan_out;
    std::tie(fan_in, fan_out) =
      torch::nn::init::_calculate_fan_in_and_fan_out(weight);
    const auto bound = 1 / std::sqrt(fan_in);
    torch::nn::init::uniform_(bias, -bound, bound);
  }
}
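
For Linear the weight has shape [out_features, in_features], so fan_in equals in_features. A minimal sketch (the sizes 100 and 10 are arbitrary): the bias of a Linear(100, 10) is drawn uniformly from [-1/√100, 1/√100] = [-0.1, 0.1]:

torch::nn::Linear fc(100, 10);
// Both values should fall inside [-0.1, 0.1]
std::cout << "bias min: " << fc->bias.min().item<float>()
          << ", bias max: " << fc->bias.max().item<float>() << std::endl;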

The BatchNorm module

BatchNorm, on the other hand, sets all weight parameters to 1 and the bias to 0, as the following code shows:

void reset_parameters() {
  reset_running_stats();
  if (options.affine()) {
    torch::nn::init::ones_(weight);
    torch::nn::init::zeros_(bias);
  }
}
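
This is easy to confirm (the channel count 16 is arbitrary): with the default affine=true, a freshly constructed BatchNorm2d has all-one weights and all-zero biases, and reset_running_stats() sets running_mean to 0 and running_var to 1:

torch::nn::BatchNorm2d bn(16);
std::cout << bn->weight << std::endl;  // a tensor of 16 ones
std::cout << bn->bias << std::endl;    // a tensor of 16 zeros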

Initializing parameters of the whole network

Uniform initialization for modules of the same type

To initialize the weight and bias parameters in your own way, you can do the following inside the network that contains the submodules, after they have been created and registered:

for (auto& m : modules(false))
{
    if (m->name() == "torch::nn::Conv2dImpl")
    {
        printf("init the conv2d parameters.\n");
        auto spConv2d = std::dynamic_pointer_cast<torch::nn::Conv2dImpl>(m);
        spConv2d->reset_parameters();
        torch::nn::init::xavier_normal_(spConv2d->weight);

        //torch::nn::init::kaiming_normal_(spConv2d->weight, 0.0, torch::kFanOut, torch::kReLU);
        //torch::nn::init::constant_(spConv2d->weight, 1);
        if (spConv2d->options.bias())
            torch::nn::init::constant_(spConv2d->bias, 0);
    }
    else if (m->name() == "torch::nn::BatchNorm2dImpl")
    {
        printf("init the batchnorm2d parameters.\n");
        auto spBatchNorm2d = std::dynamic_pointer_cast<torch::nn::BatchNorm2dImpl>(m);
        torch::nn::init::constant_(spBatchNorm2d->weight, 1);
        torch::nn::init::constant_(spBatchNorm2d->bias, 0);
    }
    //else if (m->name() == "torch::nn::LinearImpl")
    //{
    //  auto spLinear = std::dynamic_pointer_cast<torch::nn::LinearImpl>(m);
    //  torch::nn::init::constant_(spLinear->weight, 1);
    //  torch::nn::init::constant_(spLinear->bias, 0);
    //}
}

Initializing parameters by module name

The approach is to traverse all submodules and check whether each one is a module whose weights should be changed. To change parameters by module name, use the following:

for (auto& m : named_modules(std::string(), false))
{
    if (m.key() == "C1")
    {
        printf("init the conv2d parameters.\n");
        auto spConv2d = std::dynamic_pointer_cast<torch::nn::Conv2dImpl>(m.value());
        spConv2d->reset_parameters();
        torch::nn::init::xavier_normal_(spConv2d->weight);

        //torch::nn::init::kaiming_normal_(spConv2d->weight, 0.0, torch::kFanOut, torch::kReLU);
        //torch::nn::init::constant_(spConv2d->weight, 1);
        if (spConv2d->options.bias())
            torch::nn::init::constant_(spConv2d->bias, 0);
    }
}
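
Note that m.key() returns the name a submodule was registered under, so "C1" only matches if the parent module registered it with that name. A hypothetical network where the comparison above would match:

struct Net : torch::nn::Module {
    Net() {
        // Registered as "C1", so named_modules() yields it under that key
        C1 = register_module("C1",
            torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 6, 5)));
    }
    torch::nn::Conv2d C1{nullptr};
};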

Recursive initialization

If submodules nest further submodules, you may need a recursive function that keeps descending, checking, and modifying. This is especially true for complex networks such as ResNet, which may also nest Sequential containers:

void VisitAllModules(std::shared_ptr<torch::nn::Module>& m)
{
    for (auto& sm : m->named_modules(std::string(), false))
    {
        if (sm.key() == "C1")
        {
            printf("init the conv2d parameters.\n");
            auto spConv2d = std::dynamic_pointer_cast<torch::nn::Conv2dImpl>(sm.value());
            spConv2d->reset_parameters();
            torch::nn::init::xavier_normal_(spConv2d->weight);

            //torch::nn::init::kaiming_normal_(spConv2d->weight, 0.0, torch::kFanOut, torch::kReLU);
            //torch::nn::init::constant_(spConv2d->weight, 1);
            if (spConv2d->options.bias())
                torch::nn::init::constant_(spConv2d->bias, 0);
        }

        VisitAllModules(sm.value());
    }
}
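
A usage sketch, assuming a root module such as the hypothetical Net above: VisitAllModules takes a shared_ptr to the base torch::nn::Module by reference, so upcast first:

auto net = std::make_shared<Net>();
std::shared_ptr<torch::nn::Module> root = net;  // upcast to the base module type
VisitAllModules(root);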

  1. For the code itself, see the PyTorch 1.6 source code. ↩︎
