LogisticRegressionWithSGD


I am trying to create a logistic regression model (LogisticRegressionWithSGD), but training fails with this error:

org.apache.spark.SparkException: Input validation failed.


If I give it binary labels (0 and 1 instead of 0, 1 and 2), training succeeds.
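A quick way to confirm that the labels are what triggers the failure is to list the labels that fall outside {0, 1} before training. This is only a minimal sketch in plain Python with no Spark dependency; `non_binary_labels` is a hypothetical helper name, and it assumes the validation rejects any label other than 0 or 1, which matches the behaviour described above.

```python
def non_binary_labels(labels):
    """Return the sorted distinct labels that are neither 0 nor 1."""
    # Normalise to float so that 1 and 1.0 are treated as the same label.
    return sorted({float(y) for y in labels} - {0.0, 1.0})


# Labels taken from the sample input in this question.
labels = [0.0, 0.0, 1.0, 0.0, 2.0]
print(non_binary_labels(labels))
```

An empty result means the data is binary; anything else names the labels that a binary-only trainer would presumably reject.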

Sample input:

parsed_data = [LabeledPoint(0.0, [4.6,3.6,1.0,0.2]),
LabeledPoint(0.0, [5.7,4.4,1.5,0.4]),
LabeledPoint(1.0, [6.7,3.1,4.4,1.4]),
LabeledPoint(0.0, [4.8,3.4,1.6,0.2]),
LabeledPoint(2.0, [4.4,3.2,1.3,0.2])]


Code:
    model = LogisticRegressionWithSGD.train(parsed_data)

Is the logistic regression model in Spark supposed to be used only for binary classification?

Best answer

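Although the documentation does not make this clear (you have to dig into the source code to find out), LogisticRegressionWithSGD works only with binary labels. For multinomial logistic regression, use LogisticRegressionWithLBFGS and pass the number of classes: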

    from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel, LogisticRegressionWithSGD
    from pyspark.mllib.regression import LabeledPoint

    # sc is assumed to be an existing SparkContext (for example, the one the pyspark shell provides)
    parsed_data = [LabeledPoint(0.0, [4.6,3.6,1.0,0.2]),
                   LabeledPoint(0.0, [5.7,4.4,1.5,0.4]),
                   LabeledPoint(1.0, [6.7,3.1,4.4,1.4]),
                   LabeledPoint(0.0, [4.8,3.4,1.6,0.2]),
                   LabeledPoint(2.0, [4.4,3.2,1.3,0.2])]

    model = LogisticRegressionWithSGD.train(sc.parallelize(parsed_data)) # gives error:
    # org.apache.spark.SparkException: Input validation failed.

    model = LogisticRegressionWithLBFGS.train(sc.parallelize(parsed_data), numClasses=3)  # works OK
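
The `numClasses=3` above is hard-coded for this sample. If the label set is not known in advance, it could be derived from the data first. The following is a rough sketch in plain Python, not part of the MLlib API: `infer_num_classes` is a hypothetical helper, and it assumes labels are encoded as 0, 1, ..., k-1, which is the encoding MLlib documents for multiclass problems.

```python
def infer_num_classes(labels):
    """Return k for labels encoded as 0, 1, ..., k-1 (hypothetical helper)."""
    distinct = sorted({float(y) for y in labels})
    if not distinct:
        raise ValueError("no labels given")
    for y in distinct:
        # Class labels are expected to be non-negative whole numbers.
        if y < 0 or y != int(y):
            raise ValueError("invalid class label: %r" % y)
    # The largest label is k-1, so the class count is one more than that.
    return int(distinct[-1]) + 1


# Labels taken from the sample input above.
print(infer_num_classes([0.0, 0.0, 1.0, 0.0, 2.0]))
```

With an RDD of LabeledPoint, the distinct labels could be collected with something like `rdd.map(lambda p: p.label).distinct().collect()` and passed to this helper; the result would then be used as `numClasses`.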

08-24 14:17