TensorFlow API 4 (优化器配置)

2021-07-15 21:12

阅读:472

def _configure_optimizer(learning_rate):
    """Configures the optimizer used for training.

    The optimizer type is chosen by ``FLAGS.optimizer``. Its hyperparameters
    are read from the corresponding ``FLAGS.*`` fields (for example
    ``FLAGS.adam_beta1`` for ``'adam'``).

    Args:
        learning_rate: A scalar or `Tensor` learning rate.

    Returns:
        An instance of an optimizer.

    Raises:
        ValueError: if FLAGS.optimizer is not recognized.
    """
    # NOTE(review): `FLAGS` and `tf` are module-level names defined outside
    # this snippet (presumably a parsed flags object and
    # `import tensorflow as tf`) - confirm against the full file.
    # The tf.train.* optimizer classes are TensorFlow 1.x APIs; under
    # TensorFlow 2.x they live in tf.compat.v1.train.
    if FLAGS.optimizer == 'adadelta':
        optimizer = tf.train.AdadeltaOptimizer(
            learning_rate,
            rho=FLAGS.adadelta_rho,
            epsilon=FLAGS.opt_epsilon)
    elif FLAGS.optimizer == 'adagrad':
        optimizer = tf.train.AdagradOptimizer(
            learning_rate,
            initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value)
    elif FLAGS.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(
            learning_rate,
            beta1=FLAGS.adam_beta1,
            beta2=FLAGS.adam_beta2,
            epsilon=FLAGS.opt_epsilon)
    elif FLAGS.optimizer == 'ftrl':
        optimizer = tf.train.FtrlOptimizer(
            learning_rate,
            learning_rate_power=FLAGS.ftrl_learning_rate_power,
            initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value,
            l1_regularization_strength=FLAGS.ftrl_l1,
            l2_regularization_strength=FLAGS.ftrl_l2)
    elif FLAGS.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(
            learning_rate,
            momentum=FLAGS.momentum,
            name='Momentum')
    elif FLAGS.optimizer == 'rmsprop':
        optimizer = tf.train.RMSPropOptimizer(
            learning_rate,
            decay=FLAGS.rmsprop_decay,
            momentum=FLAGS.rmsprop_momentum,
            epsilon=FLAGS.opt_epsilon)
    elif FLAGS.optimizer == 'sgd':
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    else:
        # Interpolate the name into the message. Passing it as a second
        # positional argument to ValueError would leave a literal "%s" in
        # the error text.
        raise ValueError(
            'Optimizer [%s] was not recognized' % FLAGS.optimizer)
    return optimizer


评论


亲,登录后才可以留言!