深度学习进度03(变量、api、案例:实现线性回归)
2021-03-01 00:26
标签:unknown 初始化 ror ssi min ict tmp learning flag 高级: 深度学习进度03(变量、api、案例:实现线性回归) 标签:unknown 初始化 ror ssi min ict tmp learning flag 原文地址:https://www.cnblogs.com/dazhi151/p/14436752.html变量OP:
变量的特点:
创建变量:
修改变量的命名空间:
API:
实现线性回归:
案例:
案例代码:
def linear_regression():
"""
自实现一个线性回归
:return:
"""
with tf.compat.v1.variable_scope("prepare_data"):
# 1)准备数据
X = tf.compat.v1.random_normal(shape=[100, 1], name="feature")
y_true = tf.matmul(X, [[0.8]]) + 0.7
with tf.compat.v1.variable_scope("create_model"):
# 2)构造模型
# 定义模型参数 用 变量
weights = tf.Variable(initial_value=tf.compat.v1.random_normal(shape=[1, 1]), name="Weights")
bias = tf.Variable(initial_value=tf.compat.v1.random_normal(shape=[1, 1]), name="Bias")
y_predict = tf.matmul(X, weights) + bias
with tf.compat.v1.variable_scope("loss_function"):
# 3)构造损失函数
error = tf.reduce_mean(tf.square(y_predict - y_true))
with tf.compat.v1.variable_scope("optimizer"):
# 4)优化损失
optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.01).minimize(error)
# 2_收集变量
tf.summary.scalar("error", error)
tf.summary.histogram("weights", weights)
tf.summary.histogram("bias", bias)
# 3_合并变量
merged = tf.compat.v1.summary.merge_all()
# 创建Saver对象
saver = tf.compat.v1.train.Saver()
# 显式地初始化变量
init = tf.compat.v1.global_variables_initializer()
# 开启会话
with tf.compat.v1.Session() as sess:
# 初始化变量
sess.run(init)
# 1_创建事件文件
file_writer = tf.compat.v1.summary.FileWriter("./tmp/linear", graph=sess.graph)
# 查看初始化模型参数之后的值
print("训练前模型参数为:权重%f,偏置%f,损失为%f" % (weights.eval(), bias.eval(), error.eval()))
# 开始训练
# for i in range(100):
# sess.run(optimizer)
# print("第%d次训练后模型参数为:权重%f,偏置%f,损失为%f" % (i+1, weights.eval(), bias.eval(), error.eval()))
#
# # 运行合并变量操作
# summary = sess.run(merged)
# # 将每次迭代后的变量写入事件文件
# file_writer.add_summary(summary, i)
#
# # 保存模型
# if i % 10 ==0:
# saver.save(sess, "./tmp/model/my_linear.ckpt")
# 加载模型
if os.path.exists("./tmp/model/checkpoint"):
saver.restore(sess, "./tmp/model/my_linear.ckpt")
print("训练后模型参数为:权重%f,偏置%f,损失为%f" % (weights.eval(), bias.eval(), error.eval()))
return None
命令行参数:
# 1)定义命令行参数
tf.compat.v1.app.flags.DEFINE_integer("max_step", 100, "训练模型的步数")
tf.compat.v1.app.flags.DEFINE_string("model_dir", "Unknown", "模型保存的路径+模型名字")
# 2)简化变量名
FLAGS = tf.compat.v1.app.flags.FLAGS
def command_demo():
"""
命令行参数演示
:return:
"""
print("max_step:\n", FLAGS.max_step)
print("model_dir:\n", FLAGS.model_dir)
return None