def fit(self, X, y=None, eval_data=None, eval_metric='acc',
        epoch_end_callback=None, batch_end_callback=None, kvstore='local',
        logger=None, work_load_list=None, monitor=None,
        eval_end_callback=LogValidationMetricsCallback(),
        eval_batch_end_callback=None):
    """Fit the model.

    Parameters
    ----------
    X : DataIter or numpy.ndarray/NDArray
        Training data. If `X` is a `DataIter`, the name or (if a name is not
        available) the position of its outputs should match the corresponding
        variable names defined in the symbolic graph.
    y : numpy.ndarray/NDArray, optional
        Training set labels. If `X` is a `numpy.ndarray` or `NDArray`, `y` is
        required to be set. While `y` can be 1D or 2D (with the 2nd dimension
        being 1), its first dimension must be the same as that of `X`, i.e.
        the number of data points and labels must be equal.
    eval_data : DataIter or numpy.ndarray/list/NDArray pair
        Validation set. If `eval_data` is a numpy.ndarray/list/NDArray pair,
        it should be ``(valid_data, valid_label)``.
    eval_metric : metric.EvalMetric or str or callable
        The evaluation metric. This can be the name of an evaluation metric
        or a custom evaluation function that returns statistics based on a
        minibatch.
    epoch_end_callback : callable(epoch, symbol, arg_params, aux_states)
        A callback invoked at the end of each epoch. This can be used to
        checkpoint the model once per epoch.
    batch_end_callback : callable(epoch)
        A callback invoked at the end of each batch, for purposes such as
        printing progress.
    kvstore : KVStore or str, optional
        Where the parameters are stored: a KVStore instance or a string
        kvstore type ('local', 'dist_sync', 'dist_async'). Defaults to
        'local'; this rarely needs changing on a single machine.

        KVStore behavior:

        - 'local': multiple devices on a single machine; the best store type
          is chosen automatically.
        - 'dist_sync': multiple machines communicating via BSP (bulk
          synchronous parallel).
        - 'dist_async': multiple machines with asynchronous communication.
    logger : logging logger, optional
        Logging configuration. When not specified, the default logger is used.
    work_load_list : list of float or int, optional
        The workload for each device, in the same order as `ctx`.
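
    Examples
    --------
    A minimal usage sketch (illustrative only; `net`, `train_iter`,
    `val_iter`, and `batch_size` are assumed to be defined by the caller):

    >>> model = mx.model.FeedForward(symbol=net, num_epoch=10,
    ...                              learning_rate=0.1)
    >>> model.fit(X=train_iter, eval_data=val_iter, eval_metric='acc',
    ...           batch_end_callback=mx.callback.Speedometer(batch_size, 50),
    ...           epoch_end_callback=mx.callback.do_checkpoint('mymodel'))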
"""
    data = self._init_iter(X, y, is_train=True)
    eval_data = self._init_eval_iter(eval_data)

    if self.sym_gen:
        # bucketing model: generate the symbol for the default bucket
        self.symbol = self.sym_gen(data.default_bucket_key)  # pylint: disable=no-member
        self._check_arguments()
    self.kwargs["sym"] = self.symbol

    arg_names, param_names, aux_names = \
        self._init_params(data.provide_data + data.provide_label)
    # setup metric
    if not isinstance(eval_metric, metric.EvalMetric):
        eval_metric = metric.create(eval_metric)
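    # Note: `eval_metric` may also be a custom callable; `metric.create`
    # wraps callables in `metric.CustomMetric`. An illustrative sketch
    # (names assumed, not part of this file):
    #     def mae(label, pred):          # per-minibatch statistic
    #         return np.abs(label - pred).mean()
    #     model.fit(X=train_iter, eval_metric=mae)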
    # create kvstore
    (kvstore, update_on_kvstore) = _create_kvstore(
        kvstore, len(self.ctx), self.arg_params)
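    # `update_on_kvstore` (returned by the helper above) indicates whether
    # weight updates run on the kvstore itself rather than locally on each
    # device; it determines how parameter indices are laid out just below.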
    param_idx2name = {}
    if update_on_kvstore:
        param_idx2name.update(enumerate(param_names))
    else:
        for i, n in enumerate(param_names):
            for k in range(len(self.ctx)):
                param_idx2name[i*len(self.ctx)+k] = n
    self.kwargs["param_idx2name"] = param_idx2name
    # init optimizer
    if isinstance(self.optimizer, str):
        batch_size = data.batch_size
        if kvstore and 'dist' in kvstore.type and '_async' not in kvstore.type:
            batch_size *= kvstore.num_workers
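        # For synchronous distributed kvstores ('dist_sync'), the effective
        # global batch size is batch_size * num_workers, so rescale_grad
        # below normalizes gradients by the global batch rather than the
        # per-worker batch.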
        optimizer = opt.create(self.optimizer,
                               rescale_grad=(1.0/batch_size),
                               **(self.kwargs))
    elif isinstance(self.optimizer, opt.Optimizer):
        optimizer = self.optimizer
        if not optimizer.idx2name:
            optimizer.idx2name = param_idx2name.copy()
    # do training
    _train_multi_device(self.symbol, self.ctx, arg_names, param_names,
                        aux_names, self.arg_params, self.aux_params,
                        begin_epoch=self.begin_epoch, end_epoch=self.num_epoch,
                        epoch_size=self.epoch_size,
                        optimizer=optimizer,
                        train_data=data, eval_data=eval_data,
                        eval_metric=eval_metric,
                        epoch_end_callback=epoch_end_callback,
                        batch_end_callback=batch_end_callback,
                        kvstore=kvstore, update_on_kvstore=update_on_kvstore,
                        logger=logger, work_load_list=work_load_list,
                        monitor=monitor,
                        eval_end_callback=eval_end_callback,
                        eval_batch_end_callback=eval_batch_end_callback,
                        sym_gen=self.sym_gen)