Weight and bias initialization is already built into each of libtorch's neural-network modules, so in general there is no need to initialize them yourself. Take Conv2d as an example:
Conv2dImpl::Conv2dImpl(Conv2dOptions options_)
    : ConvNdImpl(detail::ConvNdOptions<2>(
                     /*in_channels=*/options_.in_channels(),
                     /*out_channels=*/options_.out_channels(),
                     /*kernel_size=*/options_.kernel_size())
                     .stride(options_.stride())
                     .padding(options_.padding())
                     .dilation(options_.dilation())
                     .transposed(false)
                     .output_padding(0)
                     .groups(options_.groups())
                     .bias(options_.bias())
                     .padding_mode(options_.padding_mode())) {}
/// Base class for all (dimension-specialized) convolution modules.
template <size_t D, typename Derived>
class ConvNdImpl : public torch::nn::Cloneable<Derived> {
 public:
  explicit ConvNdImpl(detail::ConvNdOptions<D> options_)
      : options(std::move(options_)) {
    reset();
  }

  void reset() override {
    ......
    reset_parameters();
  }

  void reset_parameters() {
    init::kaiming_uniform_(weight, /*a=*/std::sqrt(5));

    if (bias.defined()) {
      int64_t fan_in, fan_out;
      std::tie(fan_in, fan_out) = init::_calculate_fan_in_and_fan_out(weight);
      auto bound = 1 / std::sqrt(fan_in);
      init::uniform_(bias, -bound, bound);
    }
  }
As you can see, for conv2d the weight is initialized with the Kaiming (MSRA) uniform scheme, while the bias, if defined, is drawn from a uniform distribution whose bounds are computed from the number of input channels and the kernel size. For example, with 3 input channels and a 3x3 kernel:
$fan\_in = 3 \times 3 \times 3 = 27$
so the range of the uniform distribution for the bias is
$-1/\sqrt{fan\_in} \sim 1/\sqrt{fan\_in}$, i.e. approximately $-0.19245 \sim 0.19245$.
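To see these defaults in action, here is a minimal sketch (the 16 output channels are an arbitrary choice; the 3 input channels and 3x3 kernel match the example above) that constructs a Conv2d and prints the actual range of its freshly initialized bias, which should stay within roughly ±0.19245:

#include <torch/torch.h>
#include <cstdio>

int main() {
  // 3 input channels, 16 output channels, 3x3 kernel -> fan_in = 3 * 3 * 3 = 27
  torch::nn::Conv2d conv(torch::nn::Conv2dOptions(3, 16, 3));

  // Default init: weight ~ Kaiming uniform, bias ~ uniform(-1/sqrt(27), 1/sqrt(27))
  std::printf("bias min = %f, max = %f\n",
              conv->bias.min().item<float>(),
              conv->bias.max().item<float>());
  return 0;
}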
The Linear module initializes its weight and bias in the same way:
void LinearImpl::reset_parameters() {
  torch::nn::init::kaiming_uniform_(weight, std::sqrt(5)); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
  if (bias.defined()) {
    int64_t fan_in, fan_out;
    std::tie(fan_in, fan_out) =
        torch::nn::init::_calculate_fan_in_and_fan_out(weight);
    const auto bound = 1 / std::sqrt(fan_in);
    torch::nn::init::uniform_(bias, -bound, bound);
  }
}
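For Linear, fan_in is simply the number of input features, so the bound is easy to work out by hand. As a quick sketch in the same vein as above (the 784 and 100 sizes are arbitrary), a Linear(784, 100) has fan_in = 784, so its default bias should lie within about ±1/√784 ≈ ±0.0357:

torch::nn::Linear fc(784, 100);            // fan_in = in_features = 784
printf("bias min = %f, max = %f\n",        // expect values within about ±0.0357
       fc->bias.min().item<float>(),
       fc->bias.max().item<float>());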
BatchNorm, by contrast, sets all of its weights to 1 and its biases to 0; see the following code:
void reset_parameters() {
  reset_running_stats();
  if (options.affine()) {
    torch::nn::init::ones_(weight);
    torch::nn::init::zeros_(bias);
  }
}
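A quick way to confirm this (the feature count 16 is arbitrary): with affine left at its default of true, the weight of a freshly constructed BatchNorm2d is all ones and its bias is all zeros:

torch::nn::BatchNorm2d bn(torch::nn::BatchNorm2dOptions(16));
printf("weight sum = %f (expect 16), bias sum = %f (expect 0)\n",
       bn->weight.sum().item<float>(),
       bn->bias.sum().item<float>());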
If you want to initialize the weight and bias parameters your own way, you can do it in the network that contains these submodules: after creating and registering the submodules, initialize them as follows:
for (auto& m : modules(false))
{
  if (m->name() == "torch::nn::Conv2dImpl")
  {
    printf("init the conv2d parameters.\n");
    auto spConv2d = std::dynamic_pointer_cast<torch::nn::Conv2dImpl>(m);
    spConv2d->reset_parameters();
    torch::nn::init::xavier_normal_(spConv2d->weight);
    //torch::nn::init::kaiming_normal_(spConv2d->weight, 0.0, torch::kFanOut, torch::kReLU);
    //torch::nn::init::constant_(spConv2d->weight, 1);
    if (spConv2d->options.bias())
      torch::nn::init::constant_(spConv2d->bias, 0);
  }
  else if (m->name() == "torch::nn::BatchNorm2dImpl")
  {
    printf("init the batchnorm2d parameters.\n");
    auto spBatchNorm2d = std::dynamic_pointer_cast<torch::nn::BatchNorm2dImpl>(m);
    torch::nn::init::constant_(spBatchNorm2d->weight, 1);
    torch::nn::init::constant_(spBatchNorm2d->bias, 0);
  }
  //else if (m->name() == "torch::nn::LinearImpl")
  //{
  //  auto spLinear = std::dynamic_pointer_cast<torch::nn::LinearImpl>(m);
  //  torch::nn::init::constant_(spLinear->weight, 1);
  //  torch::nn::init::constant_(spLinear->bias, 0);
  //}
}
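To show where such a loop would live, here is a minimal sketch of a containing network (the name MyNetImpl and its two layers are made up for illustration): the loop runs in the constructor, after every submodule has been registered.

struct MyNetImpl : torch::nn::Module
{
  torch::nn::Conv2d C1{nullptr};
  torch::nn::BatchNorm2d BN1{nullptr};

  MyNetImpl()
  {
    C1  = register_module("C1", torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 16, 3)));
    BN1 = register_module("BN1", torch::nn::BatchNorm2d(16));

    // Re-initialize after all submodules are registered, as in the loop above.
    for (auto& m : modules(false))
    {
      if (m->name() == "torch::nn::Conv2dImpl")
      {
        auto spConv2d = std::dynamic_pointer_cast<torch::nn::Conv2dImpl>(m);
        torch::nn::init::xavier_normal_(spConv2d->weight);
        if (spConv2d->options.bias())
          torch::nn::init::constant_(spConv2d->bias, 0);
      }
    }
  }
};
TORCH_MODULE(MyNet);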
In short, the approach is to walk over all submodules and check whether each one is a module whose weights you want to change. If you would rather change parameters by submodule name, you can do it like this:
for (auto& m : named_modules(std::string(), false))
{
  if (m.key() == "C1")
  {
    printf("init the conv2d parameters.\n");
    auto spConv2d = std::dynamic_pointer_cast<torch::nn::Conv2dImpl>(m.value());
    spConv2d->reset_parameters();
    torch::nn::init::xavier_normal_(spConv2d->weight);
    //torch::nn::init::kaiming_normal_(spConv2d->weight, 0.0, torch::kFanOut, torch::kReLU);
    //torch::nn::init::constant_(spConv2d->weight, 1);
    if (spConv2d->options.bias())
      torch::nn::init::constant_(spConv2d->bias, 0);
  }
}
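One caveat with the by-name approach: named_modules() reports nested submodules under dotted keys (for example, a module sitting inside a Sequential registered as "features" shows up under a key like "features.0"), so a plain comparison against "C1" will not find a conv buried one level down. A small sketch that prints the keys shows what names your network actually exposes:

// Print every submodule key to see the names the network actually exposes.
for (auto& m : named_modules(std::string(), false))
  printf("%s : %s\n", m.key().c_str(), m.value()->name().c_str());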
If submodules are nested inside other submodules, you may therefore need a recursive function that keeps checking down the hierarchy before making changes. This matters especially for complex networks such as ResNet, where the network may also nest Sequential containers:
void VisitAllModules(std::shared_ptr<torch::nn::Module>& m)
{
  for (auto& sm : m->named_modules(std::string(), false))
  {
    if (sm.key() == "C1")
    {
      printf("init the conv2d parameters.\n");
      auto spConv2d = std::dynamic_pointer_cast<torch::nn::Conv2dImpl>(sm.value());
      spConv2d->reset_parameters();
      torch::nn::init::xavier_normal_(spConv2d->weight);
      //torch::nn::init::kaiming_normal_(spConv2d->weight, 0.0, torch::kFanOut, torch::kReLU);
      //torch::nn::init::constant_(spConv2d->weight, 1);
      if (spConv2d->options.bias())
        torch::nn::init::constant_(spConv2d->bias, 0);
    }
    VisitAllModules(sm.value());
  }
}
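A hypothetical usage sketch (MyNetImpl is the made-up network from the earlier sketch): call the recursive function on the root module so that a "C1" nested further down, for instance inside a Sequential, is also visited:

auto net = std::make_shared<MyNetImpl>();      // hypothetical network from above
std::shared_ptr<torch::nn::Module> root = net; // upcast to the base module pointer
VisitAllModules(root);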
For the complete code, refer to the PyTorch 1.6 source code.