问题描述
我将对图像执行基于像素的分类.这是我用于训练NN的代码
I am going to perform pixel-based classification on an image. Here is the code I used for training the NN
net = input_data(shape=[None, 1,4])
net = tflearn.lstm(net, 128, return_seq=True)
net = tflearn.lstm(net, 128)
net = tflearn.fully_connected(net, 1, activation='softmax')
net = tflearn.regression(net, optimizer='adam',
loss='categorical_crossentropy')
model = tflearn.DNN(net, tensorboard_verbose=2, checkpoint_path='model.tfl.ckpt')
X_train = np.expand_dims(X_train, axis=1)
model.fit(X_train, y_train, n_epoch=1, validation_set=0.1, show_metric=True,snapshot_step=100)
问题是训练模型后,p.array(model.predict(x_test))的结果仅为1,尽管我希望这是2或3.在一个示例中,我有4类对象,我希望该命令的结果是2到5之间的标签(注意:y_train的int值在2到5之间),但是预测函数的输出还是1.难道是训练阶段的问题吗? /p>
The problem is that after training the model, the result of p.array(model.predict(x_test)) is 1 only although I expected this to be either 2 or 3. In one example where I have had 4 classes of objects and I expected the result of that command to be a label between 2 and 5 (note:y_train has int values between 2 and 5) but again the output of the prediction function is 1. Could that be a problem of training phase?
推荐答案
None
参数用于表示不同培训示例.在您的情况下,由于您使用的是自定义四通道数据集,每个图像总共具有28*28*4
个参数.
The None
parameter is used to denote different training examples. In your case, each image has a total of 28*28*4
parameters, due to the custom four channel dataset you are using.
要使此LSTM正常工作,您应该尝试执行以下操作-
To make this LSTM work, you should try to do the following -
X = np.reshape(X, (-1, 28, 28, 4))
testX = np.reshape(testX, (-1, 28, 28, 4))
net = tflearn.input_data(shape=[None, 28, 28, 4])
当然,(这很重要),请确保reshape()
将对应于单个像素的四个不同通道放在numpy数组的最后一个维度中,而28, 28
对应于单个图像中的像素.
Of course, (this is very important), make sure that reshape()
puts the four different channels corresponding to a single pixel in the last dimension of the numpy array, and the 28, 28
correspond to pixels in a single image.
如果您的图片没有尺寸28*28
,请相应调整这些参数.
In case your images don't have dimension 28*28
, adjust those parameters accordingly.
这篇关于张量流中预测函数的输出错误的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!