监督学习
0. 线性回归(加 L1、L2 正则化,即弹性网络 Elastic Net)
# Example script: fit a linear regression model with elastic-net
# regularization using Spark ML, then print the fitted parameters and the
# training summary.
from __future__ import print_function

from pyspark.ml.regression import LinearRegression
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("LinearRegressionWithElasticNet")
    .getOrCreate()
)

# try/finally guarantees the session is stopped even if loading or fitting
# raises, so a failed run does not leave the Spark session open.
try:
    # Load the training data (LIBSVM-formatted file).
    training = spark.read.format("libsvm").load(
        "data/mllib/sample_linear_regression_data.txt"
    )

    # regParam is the overall regularization strength; elasticNetParam is the
    # L1/L2 mixing ratio (per the Spark ML docs: 0 -> pure L2, 1 -> pure L1).
    lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)

    # Fit the model.
    lrModel = lr.fit(training)

    # Print the coefficients and the intercept.
    print("Coefficients: %s" % str(lrModel.coefficients))
    print("Intercept: %s" % str(lrModel.intercept))

    # Print a summary of the model over the training set.
    trainingSummary = lrModel.summary
    print("numIterations: %d" % trainingSummary.totalIterations)
    print("objectiveHistory: %s" % str(trainingSummary.objectiveHistory))
    trainingSummary.residuals.show()
    print("RMSE: %f" % trainingSummary.rootMeanSquaredError)
    print("r2: %f" % trainingSummary.r2)
finally:
    spark.stop()
结果:
Coefficients: [0.0,0.322925166774,-0.343854803456,1.91560170235,0.0528805868039,0.76596272046,0.0,-0.151053926692,-0.215879303609,0.220253691888]
Intercept: 0.159893684424
numIterations: 7
objectiveHistory: [0.49999999999999994, 0.4967620357443381, 0.4936361664340463, 0.4936351537897608, 0.4936351214177871, 0.49363512062528014, 0.4936351206216114]
+--------------------+
|           residuals|
+--------------------+
|  -9.889232683103197|
|  0.5533794340053554|
|  -5.204019455758823|
| -20.566686715507508|
|    -9.4497405180564|
|  -6.909112502719486|
|  -10.00431602969873|
|   2.062397807050484|
|  3.1117508432954772|
| -15.893608229419382|
|  -5.036284254673026|
|   6.483215876994333|
|  12.429497299109002|
|  -20.32003219007654|
| -2.0049838218725005|
| -17.867901734183793|
|   7.646455887420495|
| -2.2653482182417406|
|-0.10308920436195645|
|  -1.380034070385301|
+--------------------+
only showing top 20 rows
RMSE: 10.189077
r2: 0.022861
1.广义线性模型
# Example script: fit a generalized linear model (Gaussian family, identity
# link) using Spark ML, then print the fitted parameters and the training
# summary statistics.
from __future__ import print_function

from pyspark.ml.regression import GeneralizedLinearRegression
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("GeneralizedLinearRegressionExample")
    .getOrCreate()
)

# try/finally guarantees the session is stopped even if loading or fitting
# raises, so a failed run does not leave the Spark session open.
try:
    # Load the training data (LIBSVM-formatted file).
    dataset = spark.read.format("libsvm").load(
        "data/mllib/sample_linear_regression_data.txt"
    )

    glr = GeneralizedLinearRegression(
        family="gaussian", link="identity", maxIter=10, regParam=0.3
    )

    # Fit the model.
    model = glr.fit(dataset)

    # Print the coefficients and the intercept.
    print("Coefficients: " + str(model.coefficients))
    print("Intercept: " + str(model.intercept))

    # Print a summary of the model over the training set.
    summary = model.summary
    print("Coefficient Standard Errors: " + str(summary.coefficientStandardErrors))
    print("T Values: " + str(summary.tValues))
    print("P Values: " + str(summary.pValues))
    print("Dispersion: " + str(summary.dispersion))
    print("Null Deviance: " + str(summary.nullDeviance))
    print("Residual Degree Of Freedom Null: " + str(summary.residualDegreeOfFreedomNull))
    print("Deviance: " + str(summary.deviance))
    print("Residual Degree Of Freedom: " + str(summary.residualDegreeOfFreedom))
    print("AIC: " + str(summary.aic))
    print("Deviance Residuals: ")
    # residuals() is a method on the GLM summary (unlike the attribute on the
    # linear regression summary); called without arguments here.
    summary.residuals().show()
finally:
    spark.stop()
结果:
Coefficients: [0.0105418280813,0.800325310056,-0.784516554142,2.36798871714,0.501000208986,1.12223511598,-0.292682439862,-0.498371743232,-0.603579718068,0.672555006719]
Intercept: 0.145921761452
Coefficient Standard Errors: [0.7950428434287478, 0.8049713176546897, 0.7975916824772489, 0.8312649247659919, 0.7945436200517938, 0.8118992572197593, 0.7919506385542777, 0.7973378214726764, 0.8300714999626418, 0.7771333489686802, 0.463930109648428]
T Values: [0.013259446542269243, 0.9942283563442594, -0.9836067393599172, 2.848657084633759, 0.6305509179635714, 1.382234441029355, -0.3695715687490668, -0.6250446546128238, -0.7271418403049983, 0.8654306337661122, 0.31453393176593286]
P Values: [0.989426199114056, 0.32060241580811044, 0.3257943227369877, 0.004575078538306521, 0.5286281628105467, 0.16752945248679119, 0.7118614002322872, 0.5322327097421431, 0.467486325282384, 0.3872259825794293, 0.753249430501097]
Dispersion: 105.609883568
Null Deviance: 53229.3654339
Residual Degree Of Freedom Null: 500
Deviance: 51748.8429484
Residual Degree Of Freedom: 490
AIC: 3769.18958718
Deviance Residuals: 
+-------------------+
|  devianceResiduals|
+-------------------+
|-10.974359174246889|
| 0.8872320138420559|
| -4.596541837478908|
|-20.411667435019638|
|-10.270419345342642|
|-6.0156058956799905|
|-10.663939415849267|
| 2.1153960525024713|
| 3.9807132379137675|
|-17.225218272069533|
| -4.611647633532147|
| 6.4176669407698546|
| 11.407137945300537|
| -20.70176540467664|
| -2.683748540510967|
|-16.755494794232536|
|  8.154668342638725|
|-1.4355057987358848|
|-0.6435058688185704|
|  -1.13802589316832|
+-------------------+
only showing top 20 rows