我正在尝试扩展或代理 org.apache.spark.ml.clustering.KMeans 类,以便授权 K=1。
class K1Means extends Estimator{
final val kmeans = new KMeans()
val k = 1
override def setK(value:Int) {
if(value >1){
this.kmeans.setK(value)
}
}
override def fit(dataset: DataFrame): KMeansModel = {
if(this.k == 1){
/* super specific to my case */
val avg_sample = Vectors.zeros(
dataset
.select("scaledFeatures")
.take(1)(0)(0) // first row
.asInstanceOf[DenseVector] // was of type Any
.size
) // with the scaling the average value of each column is 0
var centers_local = Array(avg_sample)
return new KMeansModel(centers_local)
}
else{
return this.kmeans.fit(dataset)
}
}
// every method then calls this.kmeans.method()
}
我试过这个,但是
new KMeansModel(centers_local)
没有被授权,因为 KMeansModel 有一个私有(private)构造函数。这是错误消息:
constructor KMeansModel in class KMeansModel cannot be accessed in class K1Means
我还尝试扩展 KMeansModel,因此我可以创建自己的并返回它:
class K1MeansModel(centers: Array[DenseVector]) extends KMeansModel{}
但它也失败了:
constructor KMeansModel in class KMeansModel cannot be accessed in class K1MeansModel
最佳答案
这里有几个问题,首先是 KMeansModel 是私有(private)的:
https://github.com/apache/spark/blob/4f83ca1059a3b580fca3f006974ff5ac4d5212a1/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala#L102
为什么这是个问题?您完全可以按照您建议的方式编写自己的代理,但是为了覆盖“fit”方法,该函数返回的数据类型需要是 KMeansModel 或兼容的(假设为“K1MeansModel”),如下所示:
class K1MeansModel extends KMeansModel{
// ...
}
class K1Means extends KMeans{
final val kmeans = new KMeans()
// ...
override def fit(dataset: DataFrame): KMeansModel = {
if(this.k == 1){
// ...
return new K1MeansModel(centers_local)
}
else{
return this.kmeans.fit(dataset)
}
}
}
但是是的,因为 KMeansModel 是私有(private)的,所以这是不可能的。所以你可能会想“为什么不重新实现它呢?”。实际上,您可以从 GitHub 复制并粘贴 KMeansModel 的整个代码。
KMeansModel 的定义如下所示:
class KMeansModel (
override val uid: String,
private val parentModel: MLlibKMeansModel)
extends Model[KMeansModel] with KMeansParams { }
但是是的,因为 KMeansParams 是私有(private)的,所以这是不可能的。所以你可能会想“为什么不重新实现它呢?”。实际上,您可以从 GitHub 复制并粘贴 KMeansParams 的整个代码。
KMeansParams 的定义如下所示:
trait K1MeansParams
extends Params
with HasMaxIter
with HasFeaturesCol
with HasSeed
with HasPredictionCol
with HasTol { }
但是是的,因为 HasMaxIter、HasFeaturesCol、HasSeed、HasPredictionCol、HasTol 都是私有(private)的,这是不可能的。 ......你明白了。
TL;DR 是的,你可以去重新实现(复制和粘贴)大量的 Spark 类到你的项目中,只是为了覆盖 KMeans。我数了至少 7 个需要复制和粘贴的类。对我来说这感觉很糟糕。 相反,我建议将代码直接添加到 Apache Spark。 Fork Spark GitHub repo ,将 K=1 的代码直接添加到 ml.KMeans 类中并完成它。
关于scala - 如何使用私有(private)构造函数扩展(或代理)Scala 类,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/37941017/