【TensorFlow入门完全指南】模型篇·线性回归模型-LMLPHP

首先呢,进行import,对于日常写代码来说,第二行经常写成:import numpy as np,这样会更加简洁。第三行import用于绘图。

【TensorFlow入门完全指南】模型篇·线性回归模型-LMLPHP

定义了学习率、迭代数epoch,以及展示的学习步骤,三个参数。

同时给出了训练用的原始数据,n_samples用来记录一共有多少数据。

【TensorFlow入门完全指南】模型篇·线性回归模型-LMLPHP

【TensorFlow入门完全指南】模型篇·线性回归模型-LMLPHP

这里指明了计算图的输入,W和b是模型的权重矩阵和偏差,目的是要学习一个

\[y=\mathbf{W}x+\mathbf{b}\]

函数。

【TensorFlow入门完全指南】模型篇·线性回归模型-LMLPHP

这里就定义了上述函数。

【TensorFlow入门完全指南】模型篇·线性回归模型-LMLPHP

这里定义了损失函数cost,使用了平方损失。

optimizer是优化器,用来定义训练方法,这里使用了梯度下降。

【TensorFlow入门完全指南】模型篇·线性回归模型-LMLPHP

最后初始化所有的变量。当然我认为最好的初始化还是高斯分布。

【TensorFlow入门完全指南】模型篇·线性回归模型-LMLPHP

这次是在指定的迭代次数里进行循环,每一次迭代,都输入一次zip(x,y)即x与y的元素绑定,数据被完整地喂了num_epoch次。每过几次,就展示一下log。

上面代码的核心代码就是这行:【TensorFlow入门完全指南】模型篇·线性回归模型-LMLPHP,所有的代码都是在不断地运行这行优化代码,请记住sess.run()的这种用法。

最后四行代码用来绘图,效果如下:

【TensorFlow入门完全指南】模型篇·线性回归模型-LMLPHP

下图展示了学习log。

【TensorFlow入门完全指南】模型篇·线性回归模型-LMLPHP

05-11 18:15
查看更多