Callback API

时间:2020-06-07 16:41:28   收藏:0   阅读:102

Callback API

用于跟踪epoch期间各种状态的回调函数。主要有6个类:

1. mxnet.callback.module_checkpoint(modprefixperiod=1save_optimizer_states=False)

[source]

参数:

返回:

 

 

2. mxnet.callback.do_checkpoint(prefixperiod=1)

这个callback函数用于每隔几个epoch来保存以下模型checkpoint,每个checkpoint由几个binary files组成:一个模型描述文件和一个参数(权重和偏置)文件。模型描述文件名字为prefix-symbol.json,参数文件名字为prefix-epoch_number.params

参数:

返回:

>>> module.fit(iterator, num_epoch=n_epoch,
... epoch_end_callback  = mx.callback.do_checkpoint("mymodel", 1))
Start training with [cpu(0)] Epoch[0] Resetting Data Iterator Epoch[0] Time cost
=0.100 Saved checkpoint to "mymodel-0001.params" Epoch[1] Resetting Data Iterator Epoch[1] Time cost=0.060 Saved checkpoint to "mymodel-0002.params"

 

 

3. mxnet.callback.log_train_metric(periodauto_reset=False)

callback函数用于每隔几个周期记录训练打印结果

参数:

返回:

 

 

4. class mxnet.callback.Speedometer(batch_sizefrequent=50auto_reset=True)

周期性的打印训练速度和评价指标

参数:

例子:

>>> # Print training speed and evaluation metrics every ten batches. Batch size is one.
>>> module.fit(iterator, num_epoch=n_epoch,
... batch_end_callback=mx.callback.Speedometer(1, 10))
Epoch[0] Batch [
10] Speed: 1910.41 samples/sec Train-accuracy=0.200000 Epoch[0] Batch [20] Speed: 1764.83 samples/sec Train-accuracy=0.400000 Epoch[0] Batch [30] Speed: 1740.59 samples/sec Train-accuracy=0.500000

 

 

 

5. class mxnet.callback.ProgressBar(totallength=80)

[source]

呈现一个进度条,表明每个epoch内批量的进度。

参数:

例子:

>>> progress_bar = mx.callback.ProgressBar(total=2)
>>> mod.fit(data, num_epoch=5, batch_end_callback=progress_bar)
[========--------] 50.0%
[================] 100.0%

 

 

 

6. class mxnet.callback.LogValidationMetricsCallback

打印出一个epoch之后的评估结果

 

 

整体的一个例子:train_mnist.py:用到了第2个和第4个类:

    model.fit(train,
              begin_epoch=args.load_epoch if args.load_epoch else 0,
              num_epoch=args.num_epochs,
              eval_data=val,
              eval_metric=eval_metrics,
              kvstore=kv,
              optimizer=args.optimizer,
              optimizer_params=optimizer_params,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              batch_end_callback=[mx.callback.Speedometer(args.batch_size, args.disp_batches)],        # 每过多少个batch打印一下
              epoch_end_callback=mx.callback.do_checkpoint(args.model_prefix , period=args.save_period),      # 每过多少period保存模型
              allow_missing=True,
              monitor=monitor)
评论(0
© 2014 mamicode.com 版权所有 京ICP备13008772号-2  联系我们:gaon5@hotmail.com
迷上了代码!