I am trying to run a decision tree classifier. The label column is of type double, and its values range from -20 to +20.

    import org.apache.spark.ml.classification.DecisionTreeClassifier
    import org.apache.spark.ml.classification.DecisionTreeClassificationModel
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import java.io.File

    val dtModelPath = s"file:///home/parv/spark/examples/src/main/scala/org/apache/spark/examples/ml/dtModel"

    val dtModel = {
      // Grid search over impurity measure and maximum tree depth.
      // idf, trainData and testData are defined earlier (not shown).
      val dtGridSearch = for (
        dtImpurity <- Array("entropy", "gini");
        dtDepth <- Array(3, 5))
      yield {
        println(s"Training decision tree: impurity $dtImpurity, depth: $dtDepth")
        val dtModel = new DecisionTreeClassifier()
          .setFeaturesCol(idf.getOutputCol)
          .setLabelCol("value") // label column: doubles from -20 to +20
          .setImpurity(dtImpurity)
          .setMaxDepth(dtDepth)
          .setMaxBins(10)
          .setSeed(42)
          .setCacheNodeIds(true)
          .fit(trainData)
        val dtPrediction = dtModel.transform(testData)
        val dtAUC = new BinaryClassificationEvaluator().setLabelCol("value").evaluate(dtPrediction)
        println(s"DT AUC on test data: $dtAUC")
        ((dtImpurity, dtDepth), dtModel, dtAUC)
      }
      // Print the top results and save the model with the highest AUC
      println(dtGridSearch.sortBy(-_._3).take(5).mkString("\n"))
      val bestModel = dtGridSearch.sortBy(-_._3).head._2
      bestModel.write.overwrite.save(dtModelPath)
      bestModel
    }


I get the following error:


    Training decision tree: impurity entropy, depth: 3
    18/03/30 01:06:18 WARN Executor: 1 block locks were not released by TID = 63510:
    [rdd_62747_0]
    18/03/30 01:06:18 ERROR Executor: Exception in task 7.0 in stage 31353.0 (TID 63518)
    java.lang.IllegalArgumentException: requirement failed: Classifier was given dataset with invalid label -6.0. Labels must be integers in range [0, 1, ..., 44), where numClasses=44.
        at scala.Predef$.require(Predef.scala:224)

Best answer

It looks like you are giving the classifier invalid labels.
The error says: Classifier was given dataset with invalid label -6.0. Labels must be integers in range [0, 1, ..., 44)
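To make the rule in that message concrete, here is a plain-Scala sketch (no Spark needed) of the check as I understand it: when the label column carries no class metadata, Spark ML infers `numClasses` as the largest label plus one, and then requires every label to be a whole number in `[0, numClasses)`. `LabelCheck` and its methods are hypothetical helpers for illustration, not Spark APIs.

```scala
object LabelCheck {
  // Hypothetical helper: infers the class count the way Spark does when the
  // label column has no metadata (largest label + 1).
  def inferNumClasses(labels: Seq[Double]): Int = labels.max.toInt + 1

  // Hypothetical helper: a label is valid only if it is a whole number
  // in the range [0, numClasses).
  def isValidLabel(label: Double, numClasses: Int): Boolean =
    label == math.floor(label) && label >= 0 && label < numClasses

  // Returns the distinct labels that would trigger the "invalid label" error.
  def invalidLabels(labels: Seq[Double]): Seq[Double] = {
    val numClasses = inferNumClasses(labels)
    labels.filterNot(isValidLabel(_, numClasses)).distinct
  }
}
```

For example, `LabelCheck.inferNumClasses(Seq(-6.0, 0.0, 43.0))` gives 44 and `LabelCheck.invalidLabels(Seq(-6.0, 0.0, 43.0))` gives `Seq(-6.0)`: any negative label fails, no matter how many classes there are.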

I would check the labels, for example:

    import spark.implicits._ // enables the $"colName" syntax; assumes a SparkSession named spark
    df.select($"value").distinct.show(100)
    df.filter($"value" < 0).show()
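If the `value` labels really are whole numbers between -20 and +20, one way to satisfy the requirement is to shift them by +20 so they fall in [0, 40], train on the shifted column, and subtract 20 from the predictions afterwards. In a DataFrame this could be `df.withColumn("value", $"value" + 20)`. Spark's `StringIndexer` (in `org.apache.spark.ml.feature`) is another route that assigns each distinct label an index. The plain-Scala sketch below shows both ideas without Spark. `LabelRemap` and its methods are hypothetical, and the ±20 range is taken from the question rather than checked against the data.

```scala
object LabelRemap {
  // Assumption (from the question): labels are whole numbers in [-20, 20],
  // so adding 20 maps them onto valid class indices 0..40.
  val offset: Double = 20.0

  def toClassIndex(label: Double): Double = label + offset
  def fromClassIndex(index: Double): Double = index - offset

  // Alternative for arbitrary distinct values: give each distinct label a
  // dense index 0, 1, 2, ... in ascending order, plus the reverse lookup.
  def denseIndex(labels: Seq[Double]): (Map[Double, Double], Map[Double, Double]) = {
    val sorted = labels.distinct.sorted
    val forward = sorted.zipWithIndex.map { case (l, i) => l -> i.toDouble }.toMap
    val backward = forward.map(_.swap)
    (forward, backward)
  }
}
```

The shift keeps the numeric meaning of each class and is trivially reversible. The dense index avoids empty classes when only some of the values in the range actually occur.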

08-28 11:10