我正在玩golearn examples文件夹中的knnclassifier_iris.go示例。我用自己的数据集替换了虹膜数据集,只要我对读入的数据中的一定百分比训练数据,所有功能都可以正常工作,并且可以获得一些输出。但是,当我明确提到训练和测试数据集,然后在拟合训练数据集之后在测试数据集上运行预测时,当我尝试打印预测时,结果为零。我不知道为什么我得到一个零值,所以我将非常感谢您的帮助。

我的代码:

package main

import (
    "fmt"
    "github.com/sjwhitworth/golearn/base"
    "github.com/sjwhitworth/golearn/evaluation"
    "github.com/sjwhitworth/golearn/knn"
)

func main() {
    trainData, err := base.ParseCSVToInstances("~/Desktop/churn_train.csv", true)
    if err != nil {
        panic(err)
    }
    fmt.Println(trainData)
    testData, err := base.ParseCSVToInstances("~/Desktop/churn_test.csv", false)
    if err != nil {
        panic(err)
    }
    fmt.Println(trainData)
    fmt.Println(testData)

    //Initialises a new KNN classifier
    cls := knn.NewKnnClassifier("euclidean", 2)
    cls.Fit(trainData)

//Calculates the Euclidean distance and returns the most popular label
    predictions := cls.Predict(testData)
    fmt.Println(predictions) //GETTING <NIL> AS OUTPUT

    // Prints precision/recall metrics
    confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
    if err != nil {
        panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error())) //ERROR CAUSED HERE DUE TO GETTING <NIL>
    }
    fmt.Println(evaluation.GetSummary(confusionMat))

}

最佳答案

(以防万一有人在Google上偶然发现了这一点)。当第二个ParseCSVToInstances生成的实例与第一个ojit_略有不同时,就会出现此问题。为了确保这不是问题,请使用 ParseCSVToTemplatedInstances ,因此

testData, err := base.ParseCSVToInstances("~/Desktop/churn_test.csv", false)

变成
 testData, err := base.ParseCSVToTemplatedInstances("~/Desktop/churn_test.csv", false, trainData)

09-29 22:33