我正在使用Sirajalogy的以下代码。 https://github.com/llSourcell/How_to_use_Tensorflow_for_classification-LIVE/blob/master/demo.ipynbIt
已对其进行了修改,以接受我自己的.csv文件,其尺寸与示例中使用的尺寸不同。

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf          # Fire from the gods
dataframe = pd.read_csv("jfkspxs.csv")
dataframe = dataframe.drop(["Field6", "Field9", "rowid"], axis=1)

inputX = dataframe.loc[:, ['Field2', 'Field3', 'Field4', 'Field5', 'Field7', 'Field8', 'Field10']].as_matrix()
inputY = dataframe.loc[:, ["y1"]].as_matrix()

learning_rate = 0.001
training_epochs = 2000
display_step = 50
n_samples = inputY.size

x = tf.placeholder(tf.float32, [None, 7])
W = tf.Variable(tf.zeros([7, 1]))
b = tf.Variable(tf.zeros([1]))

y_values = tf.add(tf.matmul(x, W), b)
y = tf.nn.softmax(y_values)
y_ = tf.placeholder(tf.float32, [None,1])

cost = tf.reduce_sum(tf.pow(y_ - y, 2))/(2*n_samples)
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

for i in range(training_epochs):
    sess.run(optimizer, feed_dict={x: inputX, y_: inputY}))
if (i) % display_step == 0:
        cc = sess.run(cost, feed_dict={x: inputX, y_:inputY})
        print ("Training step:", '%04d' % (i), "cost=", "{:.9f}".format(cc))


该代码正在运行,但是会产生以下成本更新。

Training step: 0000 cost= 0.271760166
Training step: 0050 cost= 0.271760166
Training step: 0100 cost= 0.271760166
Training step: 0150 cost= 0.271760166
Training step: 0200 cost= 0.271760166
Training step: 0250 cost= 0.271760166
Training step: 0300 cost= 0.271760166
Training step: 0350 cost= 0.271760166
etc.


问题:为什么成本在每个培训步骤中都没有更新?
谢谢!

最佳答案

问题:您的渐变为零,因此权重不变。您向softmax提供单一尺寸(batch_size,1)。这使softmax的输出为常数(1)。这使其梯度为零。

解决方案:

如果您要进行逻辑回归,请使用tf.nn.sigmoid_cross_entropy_with_logits(y_values, y_)

如果您要进行线性回归,请使用(即不要使用softmax):
cost = tf.reduce_sum(tf.pow(y_ - y_values, 2))/(2*n_samples)

如果您坚持混合使用softmax和MSE,请使用以下内容代替softmax:
y = tf.reciprocal(1 + tf.exp(-y_values))

关于python - 成本未在 tensorflow 中更新,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/41886136/

10-12 22:21