我正在尝试创建一个LogisticRegression模型(LogisticRegressionWithSGD),但是它得到一个错误
org.apache.spark.SparkException: Input validation failed.
如果我给它二进制输入(0,1而不是0,1,2),它将成功。
输入示例:
parsed_data = [LabeledPoint(0.0, [4.6,3.6,1.0,0.2]),
LabeledPoint(0.0, [5.7,4.4,1.5,0.4]),
LabeledPoint(1.0, [6.7,3.1,4.4,1.4]),
LabeledPoint(0.0, [4.8,3.4,1.6,0.2]),
LabeledPoint(2.0, [4.4,3.2,1.3,0.2])]
码:
model = LogisticRegressionWithSGD.train(parsed_data)
Spark中的Logistic回归模型是否应该仅用于二进制分类?
最佳答案
尽管从文档中还不清楚(您必须深入研究source code才能实现它),但LogisticRegressionWithSGD
仅适用于二进制数据。对于多项式回归,应使用LogisticRegressionWithLBFGS
:
from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel, LogisticRegressionWithSGD
from pyspark.mllib.regression import LabeledPoint
parsed_data = [LabeledPoint(0.0, [4.6,3.6,1.0,0.2]),
LabeledPoint(0.0, [5.7,4.4,1.5,0.4]),
LabeledPoint(1.0, [6.7,3.1,4.4,1.4]),
LabeledPoint(0.0, [4.8,3.4,1.6,0.2]),
LabeledPoint(2.0, [4.4,3.2,1.3,0.2])]
model = LogisticRegressionWithSGD.train(sc.parallelize(parsed_data)) # gives error:
# org.apache.spark.SparkException: Input validation failed.
model = LogisticRegressionWithLBFGS.train(sc.parallelize(parsed_data), numClasses=3) # works OK