Callback API
2021-03-02 09:25
标签:重置 cos missing col monit scn iterator monitor space 用于跟踪epoch期间各种状态的回调函数。主要有6个类: 1. [source] 参数: 返回: 2. 这个callback函数用于每隔几个epoch来保存以下模型checkpoint,每个checkpoint由几个binary files组成:一个模型描述文件和一个参数(权重和偏置)文件。模型描述文件名字为prefix-symbol.json,参数文件名字为prefix-epoch_number.params 参数: 返回: 3. callback函数用于每隔几个周期记录训练打印结果 参数: 返回: 4. class 周期性的打印训练速度和评价指标 参数: 例子: 5. class [source] 呈现一个进度条,表明每个epoch内批量的进度。 参数: 例子: 6. class 打印出一个epoch之后的评估结果 整体的一个例子:train_mnist.py:用到了第2个和第4个类: Callback API 标签:重置 cos missing col monit scn iterator monitor space 原文地址:https://www.cnblogs.com/king-lps/p/13060915.htmlCallback API
mxnet.callback.
module_checkpoint
(mod, prefix, period=1, save_optimizer_states=False)
mxnet.callback.
do_checkpoint
(prefix, period=1)
>>> module.fit(iterator, num_epoch=n_epoch,
... epoch_end_callback = mx.callback.do_checkpoint("mymodel", 1))
Start training with [cpu(0)]
Epoch[0] Resetting Data Iterator
Epoch[0] Time cost=0.100
Saved checkpoint to "mymodel-0001.params"
Epoch[1] Resetting Data Iterator
Epoch[1] Time cost=0.060
Saved checkpoint to "mymodel-0002.params"mxnet.callback.
log_train_metric
(period, auto_reset=False)
mxnet.callback.
Speedometer
(batch_size, frequent=50, auto_reset=True)
>>> # Print training speed and evaluation metrics every ten batches. Batch size is one.
>>> module.fit(iterator, num_epoch=n_epoch,
... batch_end_callback=mx.callback.Speedometer(1, 10))
Epoch[0] Batch [10] Speed: 1910.41 samples/sec Train-accuracy=0.200000
Epoch[0] Batch [20] Speed: 1764.83 samples/sec Train-accuracy=0.400000
Epoch[0] Batch [30] Speed: 1740.59 samples/sec Train-accuracy=0.500000mxnet.callback.
ProgressBar
(total, length=80)
>>> progress_bar = mx.callback.ProgressBar(total=2)
>>> mod.fit(data, num_epoch=5, batch_end_callback=progress_bar)
[========--------] 50.0%
[================] 100.0%
mxnet.callback.
LogValidationMetricsCallback
model.fit(train,
begin_epoch=args.load_epoch if args.load_epoch else 0,
num_epoch=args.num_epochs,
eval_data=val,
eval_metric=eval_metrics,
kvstore=kv,
optimizer=args.optimizer,
optimizer_params=optimizer_params,
initializer=initializer,
arg_params=arg_params,
aux_params=aux_params,
batch_end_callback=[mx.callback.Speedometer(args.batch_size, args.disp_batches)], # 每过多少个batch打印一下
epoch_end_callback=mx.callback.do_checkpoint(args.model_prefix , period=args.save_period), # 每过多少period保存模型
allow_missing=True,
monitor=monitor)