Module API
2021-03-02 00:28
标签:red The https merge 情况 更新 lin nbsp logs module或简写为mod,提供一个用于执行Symbol算的中高级接口,可理解为module是执行Symbol定义好的程序的机器。 module.Module接受Symbol作为输入: 关于module的一套训练流程见这里。本节的目的是选择一些常用module API,包括一些重要的属性和方法做个分析。 module包提供了以下几个module:最主要的还是第二个。
显然BaseModule是其他所有module的基类, 基类提供以下方法: 1. 初始化空间:
2. 参数操作:
3. 训练预测
4. 前反向传播
5. 参数更新
6. 输入输出
7. 其他
以上这些方法是所有类共有的,而Module类自己还有下面的内置方法:
这么多方法选择一些常用重要的方法进行介绍。 module表示计算组件。可把module看作是一台计算机器。模块可以执行向前和向后传递并更新模型中的参数。 before binding:为了使module产生交互,它必须能够在初始状态(bind之前)已知以下信息: after binding:绑定后,module应能提供以下更丰富的信息: binded:bool,指示是否已分配计算所需的内存缓冲区。 for_training:模块是否绑定进行训练。 params_initialized:bool,指示此模块的参数是否已初始化。 optimizer_initialized:bool,指示是否定义并初始化了优化器。 inputs_need_grad:bool,指示是否需要相输入数据的梯度。在实现模块组合时可能很有用。 data_shapes:(name、shape)的列表。理论上,由于内存是分配的,可以直接提供数据数组。但在数据并行的情况下,数据数组的形状可能与从外部世界看的不同。 get_params():返回一个(arg_params,aux_params)的元组。每一个都是一个name到NDArray的映射。由于NDArray总是使用CPU,用于计算的实际参数可能存在于其他设备(GPU)上,此函数将检索最新参数。 bind():为计算准备环境。 forward(data_batch):前向操作。 当这些中间层API被正确实现时,以下高级API将自动可用于模块: 1. 这个方法用于预测eval_data的结果,并根据eval_metric提供的指标进行评估 一个例子: 2. 这方法主要是得到eval_data的测试结果。 3. 这个是最重要的方法,用于训练网络。 4. get_params() 返回一对字典。类型分别是arg_params和aux_params。每个字典都是从参数映射到NDArray: 5. init_params和set_params 前者初始化参数和辅助参数的状态,后者给参数和辅助参数赋值。 6. load_params和save_params 载入和保存参数 7. forawrd(data_batch, is_train=None)和backwar(out_grads=None) data_batch是DataBatch类型。 上面在前向和反向传播时分别用到了get_outputs(merge_multi_context=True)和get_input_grads(merge_multi_context=True)两个函数: 前者得到前向计算后的输出,后者得到反向传播后关于输入的梯度。 8. 指定了这个还需指定fit里面那个optimizer吗??? 9. update()和update_metric(eval_metric, labels,pre_sliced=False) update根据已配置的优化器和上一个前向-反向批处理中计算的梯度更新参数。 update_metric对上次前向计算的输出求值并累积评估metric。 10. 非常重要,不用fit这个高级api的话,就需要bind来搞起训练。 Module API 标签:red The https merge 情况 更新 lin nbsp logs 原文地址:https://www.cnblogs.com/king-lps/p/13066148.htmldata = mx.sym.Variable(‘data‘)
fc1 = mx.sym.FullyConnected(data, name=‘fc1‘, num_hidden=128)
act1 = mx.sym.Activation(fc1, name=‘relu1‘, act_type="relu")
fc2 = mx.sym.FullyConnected(act1, name=‘fc2‘, num_hidden=10)
out = mx.sym.SoftmaxOutput(fc2, name = ‘softmax‘)
mod = mx.mod.Module(out) # create a module by given a Symbol 根据symbol建立module
一个module有几个状态:
label_shapes:(name、shape)的列表。如果模块不需要标签(如顶部不包含loss函数),或者模块未绑定以进行训练,则此值可能为[]。
outpu_shapes:(name、shape)的列表。
set_params(arg_params,aux_params):为执行计算的设备分配参数。
init_params(…):一个更灵活的接口来分配或初始化参数。
init_optimizer():安装用于参数更新的优化器。
prepare():根据当前数据批准备模块。
backward(out_grads=None):反向操作。
update():根据安装的优化器更新参数。
get_outputs():获取上一个前向操作的输出。
get_input_grads():获取与上一个后向操作中计算的输入的梯度。
update_metric(metric,labels,pre_sliced=False):更新之前前向传播结果的性能度量,就是metric。
score
(eval_data, eval_metric, num_batch=None, batch_end_callback=None, score_end_callback=None, reset=True, epoch=0, sparse_row_id_fn=None)[source]
# An example of using score for prediction.
# Evaluate accuracy on val_dataiter
metric = mx.metric.Accuracy()
mod.score(val_dataiter, metric)
mod.score(val_dataiter, [‘mse‘, ‘acc‘])
predict
(eval_data, num_batch=None, merge_batches=True, reset=True, always_output_list=False, sparse_row_id_fn=None)# An example of using `predict` for prediction.
# Predict on the first 10 batches of val_dataiter
mod.predict(eval_data=val_dataiter, num_batch=10)
fit
(train_data, eval_data=None, eval_metric=‘acc‘, epoch_end_callback=None, batch_end_callback=None, kvstore=‘local‘, optimizer=‘sgd‘, optimizer_params=((‘learning_rate‘, 0.01), ), eval_end_callback=None, eval_batch_end_callback=None, initializer=, arg_params=None, aux_params=None, allow_missing=False, force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None, validation_metric=None, monitor=None, sparse_row_id_fn=None)[source]
# An example of using fit for training.
# Assume training dataIter and validation dataIter are ready
# Assume loading a previously checkpointed model
sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
mod.fit(train_data=train_dataiter, eval_data=val_dataiter, optimizer=‘sgd‘,
optimizer_params={‘learning_rate‘:0.01, ‘momentum‘: 0.9},
arg_params=arg_params, aux_params=aux_params,
eval_metric=‘acc‘, num_epoch=10, begin_epoch=3)
# An example of getting module parameters.
print mod.get_params()
init_params
(initializer=, arg_params=None, aux_params=None, allow_missing=False, force_init=False, allow_extra=False)set_params
(arg_params, aux_params, allow_missing=False, force_init=True, allow_extra=False)[source]# An example of initializing module parameters.
mod.init_params()
# An example of setting module parameters.
sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, n_epoch_load)
mod.set_params(arg_params=arg_params, aux_params=aux_params)
# An example of saving module parameters.
mod.save_params(‘myfile‘)
# An example of loading module parameters.
mod.load_params(‘myfile‘)
import mxnet as mx
from collections import namedtuple
Batch = namedtuple(‘Batch‘, [‘data‘])
data = mx.sym.Variable(‘data‘)
out = data * 2
mod = mx.mod.Module(symbol=out, label_names=None)
mod.bind(data_shapes=[(‘data‘, (1, 10))])
mod.init_params()
data1 = [mx.nd.ones((1, 10))]
mod.forward(Batch(data1))
print mod.get_outputs()[0].asnumpy()
# Forward with data batch of different shape
data2 = [mx.nd.ones((3, 5))]
mod.forward(Batch(data2))
print mod.get_outputs()[0].asnumpy()
# An example of backward computation.
mod.backward()
print mod.get_input_grads()[0].asnumpy()
]
init_optimizer
(kvstore=‘local‘, optimizer=‘sgd‘, optimizer_params=((‘learning_rate‘, 0.01), ), force_init=False)# An example of initializing optimizer.
mod.init_optimizer(optimizer=‘sgd‘, optimizer_params=((‘learning_rate‘, 0.005),))
# An example of updating module parameters.
mod.init_optimizer(kvstore=‘local‘, optimizer=‘sgd‘,
optimizer_params=((‘learning_rate‘, 0.01), ))
mod.backward()
mod.update()
print mod.get_params()[0][‘fc3_weight‘].asnumpy()
]
# An example of updating evaluation metric.
mod.forward(data_batch)
mod.update_metric(metric, data_batch.label)
bind
(data_shapes, label_shapes=None, for_training=True, inputs_need_grad=False, force_rebind=False, shared_module=None, grad_req=‘write‘)# An example of binding symbols.
mod.bind(data_shapes=[(‘data‘, (1, 10, 10))])
# Assume train_iter is already created.
mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
上一篇:使用高德API-初级应用
下一篇:网页中如何调用WIN本地程序