问题描述
我想用网格搜索和 spark 交叉验证来调整我的模型.在 spark 中,它必须将基础模型放入管道中,管道的office demo 使用LogistictRegression
作为基础模型,它可以是新的对象.但是,RandomForest
模型不能被客户端代码new,因此它似乎无法在管道 api 中使用 RandomForest
.我不想重新创建一个轮子,所以有人可以给一些建议吗?谢谢
I want to tunning my model with grid search and cross validation with spark. In the spark, it must put the base model in a pipeline, the office demo of pipeline use the LogistictRegression
as an base model, which can be new as an object. However, the RandomForest
model cannot be new by client code, so it seems not be able to use RandomForest
in the pipeline api. I don't want to recreate an wheel, so can anybody give some advice?Thanks
推荐答案
嗯,这是真的,但你只是试图使用错误的类.您应该使用 ml.classification.RandomForestClassifier
而不是 mllib.tree.RandomForest
.这是一个基于 来自 MLlib 文档的示例.
Well, that is true but you simply trying to use a wrong class. Instead of mllib.tree.RandomForest
you should use ml.classification.RandomForestClassifier
. Here is an example based on the one from MLlib docs.
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLUtils
import sqlContext.implicits._
case class Record(category: String, features: Vector)
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainData, testData) = (splits(0), splits(1))
val trainDF = trainData.map(lp => Record(lp.label.toString, lp.features)).toDF
val testDF = testData.map(lp => Record(lp.label.toString, lp.features)).toDF
val indexer = new StringIndexer()
.setInputCol("category")
.setOutputCol("label")
val rf = new RandomForestClassifier()
.setNumTrees(3)
.setFeatureSubsetStrategy("auto")
.setImpurity("gini")
.setMaxDepth(4)
.setMaxBins(32)
val pipeline = new Pipeline()
.setStages(Array(indexer, rf))
val model = pipeline.fit(trainDF)
model.transform(testDF)
有一件事我在这里想不通.据我所知应该可以直接使用从 LabeledPoints
中提取的标签,但由于某种原因它不起作用并且 pipeline.fit
引发 IllegalArgumentExcetion代码>:
There is one thing I couldn't figure out here. As far as I can tell it should be possible to use labels extracted from LabeledPoints
directly, but for some reason it doesn't work and pipeline.fit
raises IllegalArgumentExcetion
:
RandomForestClassifier 的输入带有无效的标签列标签,但没有指定类的数量.
这就是 StringIndexer
的丑陋技巧.应用后我们得到必需的属性 ({"vals":["1.0","0.0"],"type":"nominal","name":"label"}
) 但有些类在ml
没有它似乎也能正常工作.
Hence the ugly trick with StringIndexer
. After applying we get required attributes ({"vals":["1.0","0.0"],"type":"nominal","name":"label"}
) but some classes in ml
seem to work just fine without it.
这篇关于如何在 Spark Pipeline 中使用 RandomForest的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!