Kaldi nnet3的导数单元测试理论依据
2021-03-29 16:24
标签:tab ons out com mode ima math mod war 对参数进行扰动
在Kaldi nnet3的以下单元测试代码中
nnet3/attention-test.cc
kaldi::nnet3::attention::TestAttentionForwardBackward
nnet3/convolution-test.cc
kaldi::nnet3::time_height_convolution::TestDataBackprop
kaldi::nnet3::time_height_convolution::TestParamsBackprop
nnet3/nnet-derivative-test.cc
kaldi::nnet3::UnitTestNnetModelDerivatives
BaseFloat objf_baseline = TraceMatMat(output_deriv, output, kTrans);
in2.SetRandn();
BaseFloat predicted_delta_objf = TraceMatMat(in_deriv, in2, kTrans);
in2.AddMat(1.0, in);
Forward(in2, &output);
BaseFloat objf2 = TraceMatMat(output_deriv, output, kTrans),
observed_delta_objf = objf2 - objf_baseline;
KALDI_ASSERT(observed_delta_objf.ApproxEqual(predicted_delta_objf, 0.1));
有以下假设:
若有:
则反向传播代码是正确的。
以下证明线性变换满足此假设:
假设
?
?
根据微分与导数的关系:
因此,只要上述公式成立,则说明反向传播函数相对正向传播函数是正确的。
?
? 对模型进行扰动
在Kaldi nnet3的以下单元测试代码中
nnet3/nnet-derivative-test.cc:UnitTestNnetModelDerivatives()
有以下假设:
若有:
??是模型参数
将模型参数视为函数参数
则模型的前向传播和反向传播代码是正确的。
以下证明线性变换满足此假设:
假设
?
?
?
? ?
? Kaldi nnet3的导数单元测试理论依据 标签:tab ons out com mode ima math mod war 原文地址:https://www.cnblogs.com/JarvanWang/p/12606950.html