我正在尝试编写一个程序来预测是否患有恶性肿瘤或良性肿瘤

数据集来自:https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+%28Prognostic%29

这是我的代码,我的准确率约为65%,这不比掷硬币好。任何帮助,将不胜感激

import tensorflow as tf
import pandas as pd
import numpy as np


df = pd.read_csv(r'D:\wholedesktop\logisticReal.txt')
df.drop(['id'], axis=1, inplace=True)

x_data = np.array(df.drop(['class'], axis=1))
x_data = x_data.astype(np.float64)
y = df['class']
y.replace(2, 0, inplace=True)
y.replace(4, 1, inplace=True)
y_data = np.array(y)
# y shape = 681,1
# x shape = 681,9

x = tf.placeholder(name='x', dtype=np.float32)
y = tf.placeholder(name='y', dtype=np.float32)

w = tf.Variable(dtype=np.float32, initial_value=np.random.random((9, 1)))
b = tf.Variable(dtype=np.float32, initial_value=np.random.random((1, 1)))

y_ = (tf.add(tf.matmul(x, w), b))
error = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y_, labels=y))
goal = tf.train.GradientDescentOptimizer(0.05).minimize(error)

prediction = tf.round(tf.sigmoid(y_))
correct = tf.cast(tf.equal(prediction, y), dtype=np.float64)
accuracy = tf.reduce_mean(correct)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(2000):
        sess.run(goal, feed_dict={x: x_data, y: y_data})
        print(i, sess.run(accuracy, feed_dict={x: x_data, y: y_data}))

    weight = sess.run(w)
    bias = sess.run(b)
    print(weight)
    print(bias)

最佳答案

您的神经网络只有一个图层,因此它能做的最好就是将一条直线拟合到您的数据中,以分隔不同的类。对于一般(高维)数据集,这是远远不够的。 (深度)神经网络的功能在于神经元的许多层之间的连通性。在您的示例中,您可以通过将matmul的输出传递到具有不同权重和偏差的新matmul来手动添加更多层,或者可以使用contrib.layers集合使其更加简洁:

x = tf.placeholder(name='x', dtype=np.float32)
fc1 = tf.contrib.layers.fully_connected(inputs=x, num_outputs=16, activation_fn=tf.nn.relu)
fc2 = tf.contrib.layers.fully_connected(inputs=fc1, num_outputs=32, activation_fn=tf.nn.relu)
fc3 = tf.contrib.layers.fully_connected(inputs=fc2, num_outputs=64, activation_fn=tf.nn.relu)


诀窍是将输出从一层作为输入传递到下一层。随着添加的图层越来越多,精度会提高(可能是由于过度拟合,请使用dropout进行纠正)。

关于python - tensorflow 逻辑回归的准确性非常差,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/50386210/

10-12 21:55