我正在尝试扩展或代理 org.apache.spark.ml.clustering.KMeans 类,以便授权 K=1。

class K1Means extends Estimator{

    final val kmeans = new KMeans()
    val k = 1

    override def setK(value:Int) {
        if(value >1){
            this.kmeans.setK(value)
        }
    }

    override def fit(dataset: DataFrame): KMeansModel = {
        if(this.k == 1){
            /* super specific to my case */
            val avg_sample = Vectors.zeros(
                dataset
                .select("scaledFeatures")
                .take(1)(0)(0)  // first row
                .asInstanceOf[DenseVector]  // was of type Any
                .size
            ) // with the scaling the average value of each column is 0
            var centers_local = Array(avg_sample)
            return new KMeansModel(centers_local)
        }
        else{
            return this.kmeans.fit(dataset)
        }
    }
// every method then calls this.kmeans.method()
}

我试过这个,但是 new KMeansModel(centers_local) 没有被授权,因为 KMeansModel 有一个私有(private)构造函数。
这是错误消息:
constructor KMeansModel in class KMeansModel cannot be accessed in class K1Means
我还尝试扩展 KMeansModel,因此我可以创建自己的并返回它:
class K1MeansModel(centers: Array[DenseVector]) extends KMeansModel{}

但它也失败了:constructor KMeansModel in class KMeansModel cannot be accessed in class K1MeansModel

最佳答案

这里有几个问题,首先是 KMeansModel 是私有(private)的:
https://github.com/apache/spark/blob/4f83ca1059a3b580fca3f006974ff5ac4d5212a1/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala#L102

为什么这是个问题?您完全可以按照您建议的方式编写自己的代理,但是为了覆盖“fit”方法,该函数返回的数据类型需要是 KMeansModel 或兼容的(假设为“K1MeansModel”),如下所示:

class K1MeansModel extends KMeansModel{
    // ...
}

class K1Means extends KMeans{

    final val kmeans = new KMeans()
    // ...

    override def fit(dataset: DataFrame): KMeansModel = {
        if(this.k == 1){
            // ...
            return new K1MeansModel(centers_local)
        }
        else{
            return this.kmeans.fit(dataset)
        }
    }
}

但是是的,因为 KMeansModel 是私有(private)的,所以这是不可能的。所以你可能会想“为什么不重新实现它呢?”。实际上,您可以从 GitHub 复制并粘贴 KMeansModel 的整个代码。

KMeansModel 的定义如下所示:
class KMeansModel (
        override val uid: String,
        private val parentModel: MLlibKMeansModel)
    extends Model[KMeansModel] with KMeansParams { }

但是是的,因为 KMeansParams 是私有(private)的,所以这是不可能的。所以你可能会想“为什么不重新实现它呢?”。实际上,您可以从 GitHub 复制并粘贴 KMeansParams 的整个代码。

KMeansParams 的定义如下所示:
trait K1MeansParams
    extends Params
        with HasMaxIter
        with HasFeaturesCol
        with HasSeed
        with HasPredictionCol
        with HasTol { }

但是是的,因为 HasMaxIter、HasFeaturesCol、HasSeed、HasPredictionCol、HasTol 都是私有(private)的,这是不可能的。 ......你明白了。

TL;DR 是的,你可以去重新实现(复制和粘贴)大量的 Spark 类到你的项目中,只是为了覆盖 KMeans。我数了至少 7 个需要复制和粘贴的类。对我来说这感觉很糟糕。 相反,我建议将代码直接添加到 Apache Spark。 Fork Spark GitHub repo ,将 K=1 的代码直接添加到 ml.KMeans 类中并完成它。

关于scala - 如何使用私有(private)构造函数扩展(或代理)Scala 类,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/37941017/

10-12 22:57