Deeplearning4j canova示例不起作用,我将eval.stats的输出显示为NaN(准确性)。
import org.slf4j.LoggerFactory;
public class ImageClassifierExample {
public static void main(String[] args) throws IOException, InterruptedException {
// Path to the labeled images
String labeledPath = System.getProperty("user.home")+"/lfw";
List<String> labels = new ArrayList<>();
for(File f : new File(labeledPath).listFiles()) {
labels.add(f.getName());
}
// Instantiating a RecordReader pointing to the data path with the specified
// height and width for each image.
RecordReader recordReader = new ImageRecordReader(28, 28, true,labels);
recordReader.initialize(new FileSplit(new File(labeledPath)));
// Canova to Dl4j
DataSetIterator iter = new RecordReaderDataSetIterator(recordReader, 784,labels.size());
// Creating configuration for the neural net.
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
.constrainGradientToUnitNorm(true)
.weightInit(WeightInit.DISTRIBUTION)
.dist(new NormalDistribution(1,1e-5))
.iterations(100).learningRate(1e-3)
.nIn(784).nOut(labels.size())
.visibleUnit(org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit.GAUSSIAN)
.hiddenUnit(org.deeplearning4j.nn.conf.layers.RBM.HiddenUnit.RECTIFIED)
.layer(new org.deeplearning4j.nn.conf.layers.RBM())
.list(4).hiddenLayerSizes(600, 250, 100).override(3, new ConfOverride() {
@Override
public void overrideLayer(int i, NeuralNetConfiguration.Builder builder) {
if (i == 3) {
builder.layer(new org.deeplearning4j.nn.conf.layers.OutputLayer());
builder.activationFunction("softmax");
builder.lossFunction(LossFunctions.LossFunction.MCXENT);
}
}
}).build();
MultiLayerNetwork network = new MultiLayerNetwork(conf);
network.setListeners(Arrays.<IterationListener>asList(new ScoreIterationListener(10)));
// Training
while(iter.hasNext()){
DataSet next = iter.next();
network.fit(next);
}
// Testing -- We're not doing split test and train
// Using the same training data as test.
iter.reset();
Evaluation eval = new Evaluation();
while(iter.hasNext()){
DataSet next = iter.next();
INDArray predict2 = network.output(next.getFeatureMatrix());
eval.eval(next.getLabels(), predict2);
}
System.out.println(eval.stats());
}
}
最佳答案
您的NN配置看起来像是基于真正的dl4j版本。最新发行版本是:
DL4j:0.4-rc3.8
ND4j:0.4-rc3.8
卡诺瓦:0.0.0.14
请尝试使用最新版本