As before, we use the MNIST handwritten-digit dataset and the LeNet-5 network as the example.
As shown in the figure above, the left column lists the 9 layers, and the right column gives the dimensions of each layer's top blob output (feature map).
So the question is: which variables store the trainable parameters w_ij and b_i?
1) Every layer whose parameters are trained by backpropagation (BP) has a member variable Layer::blobs_: blobs_[0] stores w_ij and Δw_ij, and blobs_[1] stores b_i and Δb_i.
2) The Net member variable Net::blobs_ stores the data passed between layers together with the δ (error) values.
In LeNet-5, only the two convolution layers (conv1, conv2) and the two fully-connected layers (ip1, ip2) contain trainable parameters, so the network needs 8 blobs in total (one weight blob and one bias blob per layer); the Net member Net::learnable_params_ shares memory with these 8 blobs. Each layer's Layer::blobs_ is allocated and initialized in LayerSetUp() (invoked from Layer::SetUp()); Δw_ij and Δb_i are computed in Layer::Backward_cpu(), which is the function you must override when implementing your own layer. Finally, Net::Update() is called to update w_ij.
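A quick way to verify this is to load a trained LeNet-5 model and walk over Net::learnable_params(), which shares memory with the layers' blobs_. The following is only a rough sketch; the prototxt/caffemodel file names are assumed:

#include <caffe/caffe.hpp>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical file names -- replace with your own LeNet-5 definition and weights.
  caffe::Net<float> net("lenet.prototxt", caffe::TEST);
  net.CopyTrainedLayersFrom("lenet_iter_10000.caffemodel");
  // learnable_params() shares memory with each layer's blobs_; for LeNet-5 it
  // should hold 8 blobs: the w and b of conv1, conv2, ip1 and ip2.
  const std::vector<caffe::Blob<float>*>& params = net.learnable_params();
  for (size_t i = 0; i < params.size(); ++i) {
    std::cout << "learnable param " << i << ": "
              << params[i]->shape_string() << std::endl;
  }
  return 0;
}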
Take the fully-connected (InnerProduct) layer as an example:
/// InnerProduct (ip) layer: allocate and initialize blobs_
template <typename Dtype>
void InnerProductLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  const int num_output = this->layer_param_.inner_product_param().num_output();
  bias_term_ = this->layer_param_.inner_product_param().bias_term();
  transpose_ = this->layer_param_.inner_product_param().transpose();
  N_ = num_output;
  const int axis = bottom[0]->CanonicalAxisIndex(
      this->layer_param_.inner_product_param().axis());
  // Dimensions starting from "axis" are "flattened" into a single
  // length K_ vector. For example, if bottom[0]'s shape is (N, C, H, W),
  // and axis == 1, N inner products with dimension CHW are performed.
  K_ = bottom[0]->count(axis);
  // Check if we need to set up the weights
  if (this->blobs_.size() > 0) {
    LOG(INFO) << "Skipping parameter initialization";
  } else {
    if (bias_term_) {
      this->blobs_.resize(2);
    } else {
      this->blobs_.resize(1);
    }
    // Initialize the weights
    vector<int> weight_shape(2);
    if (transpose_) {
      weight_shape[0] = K_;
      weight_shape[1] = N_;
    } else {
      weight_shape[0] = N_;
      weight_shape[1] = K_;
    }
    this->blobs_[0].reset(new Blob<Dtype>(weight_shape));
    // fill the weights
    shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(  // initialize the weights w
        this->layer_param_.inner_product_param().weight_filler()));
    weight_filler->Fill(this->blobs_[0].get());
    // If necessary, initialize and fill the bias term
    if (bias_term_) {
      vector<int> bias_shape(1, N_);
      this->blobs_[1].reset(new Blob<Dtype>(bias_shape));  // initialize the bias
      shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>(
          this->layer_param_.inner_product_param().bias_filler()));
      bias_filler->Fill(this->blobs_[1].get());
    }
  }  // parameter initialization
  this->param_propagate_down_.resize(this->blobs_.size(), true);
}
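To make these shapes concrete (the numbers below assume the stock lenet_train_test.prototxt): ip1 receives the pool2 output of shape (64, 50, 4, 4) as bottom[0] during training, so with axis = 1 we get K_ = 50×4×4 = 800; num_output is 500, so N_ = 500, and blobs_[0] is allocated with shape (500, 800) while blobs_[1] gets shape (500). For ip2, K_ = 500 and N_ = 10.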
The parameter gradients ΔW and Δb are computed in Backward_cpu():
template <typename Dtype>
void InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down,
    const vector<Blob<Dtype>*>& bottom) {
  if (this->param_propagate_down_[0]) {
    const Dtype* top_diff = top[0]->cpu_diff();
    const Dtype* bottom_data = bottom[0]->cpu_data();
    // Gradient with respect to weight
    if (transpose_) {
      caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans,
          K_, N_, M_,
          (Dtype)1., bottom_data, top_diff,
          (Dtype)1., this->blobs_[0]->mutable_cpu_diff());
    } else {
      caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans,
          N_, K_, M_,
          (Dtype)1., top_diff, bottom_data,
          (Dtype)1., this->blobs_[0]->mutable_cpu_diff());  // this layer's weight gradient ΔW
    }
  }
  if (bias_term_ && this->param_propagate_down_[1]) {
    const Dtype* top_diff = top[0]->cpu_diff();
    // Gradient with respect to bias
    caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
        bias_multiplier_.cpu_data(), (Dtype)1.,
        this->blobs_[1]->mutable_cpu_diff());  // this layer's bias gradient Δb
  }
  if (propagate_down[0]) {
    const Dtype* top_diff = top[0]->cpu_diff();
    // Gradient with respect to bottom data
    if (transpose_) {
      caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans,
          M_, K_, N_,
          (Dtype)1., top_diff, this->blobs_[0]->cpu_data(),
          (Dtype)0., bottom[0]->mutable_cpu_diff());
    } else {
      caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,  // δ passed down to the layer below
          M_, K_, N_,
          (Dtype)1., top_diff, this->blobs_[0]->cpu_data(),
          (Dtype)0., bottom[0]->mutable_cpu_diff());
    }
  }
}
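The matrix shapes in these three calls follow directly from the forward pass. In the non-transposed case the forward computation is top = bottom · Wᵀ, with bottom_data of shape M_×K_ (M_ is the batch size), the weights W of shape N_×K_, and top of shape M_×N_. Differentiating gives:
ΔW = top_diffᵀ · bottom_data, an N_×K_ matrix (the first gemm);
Δb = top_diffᵀ · 1, a length-N_ vector (the gemv, with bias_multiplier_ holding the all-ones vector);
bottom_diff = top_diff · W, an M_×K_ matrix, i.e. the δ handed down to the layer below.
Note that the parameter gradients are accumulated into the diff of blobs_[0] and blobs_[1] (beta = 1.0), which is why the solver clears them with ClearParamDiffs() at the start of every iteration, while bottom_diff is simply overwritten (beta = 0.0).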
The overall training loop is implemented in Solver::Step():
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
  const int start_iter = iter_;
  const int stop_iter = iter_ + iters;
  int average_loss = this->param_.average_loss();
  losses_.clear();
  smoothed_loss_ = 0;
  while (iter_ < stop_iter) {  // each iteration processes one batch
    // zero-init the params
    net_->ClearParamDiffs();  // reset the diff of every learnable parameter blob in the Net to 0
    if (param_.test_interval() && iter_ % param_.test_interval() == 0
        && (iter_ > 0 || param_.test_initialization())
        && Caffe::root_solver()) {
      TestAll();
      if (requested_early_exit_) {
        // Break out of the while loop because stop was requested while testing.
        break;
      }
    }
    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_start();
    }
    const bool display = param_.display() && iter_ % param_.display() == 0;
    net_->set_debug_info(display && param_.debug_info());
    // accumulate the loss and gradient
    Dtype loss = 0;
    for (int i = 0; i < param_.iter_size(); ++i) {
      loss += net_->ForwardBackward();  // the key call: forward pass, then backward pass
    }
    loss /= param_.iter_size();
    // average the loss across iterations for smoothed reporting
    UpdateSmoothedLoss(loss, start_iter, average_loss);
    if (display) {
      LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_  // the "Iteration ..., loss = ..." log line
          << ", loss = " << smoothed_loss_;
      const vector<Blob<Dtype>*>& result = net_->output_blobs();
      int score_index = 0;
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        const string& output_name =
            net_->blob_names()[net_->output_blob_indices()[j]];
        const Dtype loss_weight =
            net_->blob_loss_weights()[net_->output_blob_indices()[j]];
        for (int k = 0; k < result[j]->count(); ++k) {  // the "Train net output ..." log lines
          ostringstream loss_msg_stream;
          if (loss_weight) {
            loss_msg_stream << " (* " << loss_weight
                            << " = " << loss_weight * result_vec[k] << " loss)";
          }
          LOG_IF(INFO, Caffe::root_solver()) << "    Train net output #"
              << score_index++ << ": " << output_name << " = "
              << result_vec[k] << loss_msg_stream.str();
        }
      }
    }  // if (display)
    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_gradients_ready();
    }
    ApplyUpdate();  // backward pass finished for this batch: update the weights w
    // Increment the internal iter_ counter -- its value should always indicate
    // the number of times the weights have been updated.
    ++iter_;
    SolverAction::Enum request = GetRequestedAction();
    // Save a snapshot if needed.
    if ((param_.snapshot()
         && iter_ % param_.snapshot() == 0
         && Caffe::root_solver()) ||
        (request == SolverAction::SNAPSHOT)) {
      Snapshot();
    }
    if (SolverAction::STOP == request) {
      requested_early_exit_ = true;
      // Break out of training loop.
      break;
    }  // end if
  }  // end while
}  // end Step()
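For completeness, here is a minimal sketch of how Step() is normally reached from user code (the solver prototxt name is assumed): a SolverParameter is parsed from the solver file, the registered solver type ("SGD" by default) is created through the solver factory, and Solve() drives Step() for max_iter iterations; ApplyUpdate() inside that loop then performs the w ← w − lr·Δw update (plus momentum and weight decay).

#include <caffe/caffe.hpp>
#include <caffe/util/upgrade_proto.hpp>

int main() {
  // Hypothetical solver file, e.g. examples/mnist/lenet_solver.prototxt.
  caffe::SolverParameter solver_param;
  caffe::ReadSolverParamsFromTextFileOrDie("lenet_solver.prototxt", &solver_param);

  // Create the solver named in the prototxt ("SGD" unless specified otherwise).
  caffe::Solver<float>* solver =
      caffe::SolverRegistry<float>::CreateSolver(solver_param);

  // Solve() calls Step(max_iter - iter_), i.e. the loop shown above.
  solver->Solve();

  delete solver;
  return 0;
}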