摘自《TensorflowGoogle实战》

一、网络示意

二、代码

1、创建网络

 1 import tensorflow as tf
 2 from numpy.random import RandomState
 3 from sklearn.model_selection import train_test_split
 4
 5 batch_size = 8
 6
 7 # 定义网络
 8 w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
 9 b1 = tf.Variable(tf.zeros([1, 3], name="bias1"))
10 b2 = tf.Variable(tf.zeros([1], name="bias2"))
11 w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))
12
13 x = tf.placeholder(tf.float32, shape=(None, 2), name='x-input')
14 y_ = tf.placeholder(tf.float32, shape=(None, 1), name='y-input')
15 a = tf.matmul(x, w1) + b1
16
17 # 这里一定不要写成 tf.sigmoid
18 y_last = tf.nn.sigmoid(tf.matmul(a, w2) + b2)

2、定义损失函数

1 # 损失与准确率
2 loss = tf.losses.sigmoid_cross_entropy(y_, y_last)
3 train_step = tf.train.AdamOptimizer(0.07).minimize(loss)
4
5 correct_prediction = tf.equal(tf.round(y_last), y_)
6 acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

3、造数据

1 # 造数据,X为二维数据例如:[[0.1, 0.8]], 当X的第一项和第二项相加 < 1 时 Y为1, 当X的第一项和第二项相加 >= 1时为 0
2 rdm = RandomState(1)
3 data_set_size = 128000
4 X = rdm.rand(data_set_size, 2)
5 Y = [[int(x1 + x2 < 1)] for (x1, x2) in X]
6
7 X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.3)

4、训练与测试

 1 with tf.Session() as sess:
 2     init_op = tf.global_variables_initializer()
 3     sess.run(init_op)
 4     STEPS = 7000
 5     for i in range(0, 8900 - batch_size, batch_size):
 6         start = i
 7         sess.run(train_step, feed_dict={x: X_train[start: start + batch_size], y_: y_train[start: start + batch_size]})
 8
 9         if i % 800 == 0:
10             # 计算所有数据的交叉熵
11             total_cross_entropy = sess.run(loss, feed_dict={x: X, y_: Y})
12             # 输出交叉熵之和
13             # print("After %d training step(s),cross entropy on all data is %g" % (i, total_cross_entropy))
14
15     acc_ = sess.run(acc, feed_dict={x: X_test, y_: y_test})
16     print("accuracy on test data is ", acc_)
17
18     test = sess.run(y_last, feed_dict={x: X_test, y_: y_test})

5、结果(个人认为tf有很多bug,有几次结果居然是0.5左右,不知道为啥)

accuracy on test data is  0.96091145
02-10 02:06