经过一段时间的摸索,在Spark 2.3中,我终于能够保存一个纯python自定义转换器。但是在重新加载变压器时出现错误。
我检查了保存内容,并找到了保存在HDFS上文件中的所有相关变量。如果有人可以发现我在这个简单的转换器中缺少的工作,那就太好了。
from pyspark.ml import Transformer
from pyspark.ml.param.shared import Param,Params,TypeConverters
class AggregateTransformer(Transformer,DefaultParamsWritable,DefaultParamsReadable):
aggCols = Param(Params._dummy(), "aggCols", "",TypeConverters.toListString)
valCols = Param(Params._dummy(), "valCols", "",TypeConverters.toListString)
def __init__(self,aggCols,valCols):
super(AggregateTransformer, self).__init__()
self._setDefault(aggCols=[''])
self._set(aggCols = aggCols)
self._setDefault(valCols=[''])
self._set(valCols = valCols)
def getAggCols(self):
return self.getOrDefault(self.aggCols)
def setAggCols(self, aggCols):
self._set(aggCols=aggCols)
def getValCols(self):
return self.getOrDefault(self.valCols)
def setValCols(self, valCols):
self._set(valCols=valCols)
def _transform(self, dataset):
aggFuncs = []
for valCol in self.getValCols():
aggFuncs.append(F.sum(valCol).alias("sum_"+valCol))
aggFuncs.append(F.min(valCol).alias("min_"+valCol))
aggFuncs.append(F.max(valCol).alias("max_"+valCol))
aggFuncs.append(F.count(valCol).alias("cnt_"+valCol))
aggFuncs.append(F.avg(valCol).alias("avg_"+valCol))
aggFuncs.append(F.stddev(valCol).alias("stddev_"+valCol))
dataset = dataset.groupBy(self.getAggCols()).agg(*aggFuncs)
return dataset
保存它后,在加载此转换器的实例时收到此错误。
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-172-44e20f7e3842> in <module>()
----> 1 x = agg.load("/tmp/test")
/usr/hdp/current/spark2.3-client/python/pyspark/ml/util.py in load(cls, path)
309 def load(cls, path):
310 """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
--> 311 return cls.read().load(path)
312
313
/usr/hdp/current/spark2.3-client/python/pyspark/ml/util.py in load(self, path)
482 metadata = DefaultParamsReader.loadMetadata(path, self.sc)
483 py_type = DefaultParamsReader.__get_class(metadata['class'])
--> 484 instance = py_type()
485 instance._resetUid(metadata['uid'])
486 DefaultParamsReader.getAndSetParams(instance, metadata)
TypeError: __init__() missing 2 required positional arguments: 'aggCols' and 'valCols'
最佳答案
找出答案!
问题是读者正在初始化一个新的Transformer类,但是我的AggregateTransformer的init函数没有参数的默认值。
因此,更改以下代码行即可解决此问题!
def __init__(self,aggCols=[],valCols=[]):
因为我很难找到可以保存并在任何地方读回的纯Python转换器的工作示例,所以在这里将这个答案和问题留在这里是非常困难的!它可以帮助寻找此内容的人。
关于apache-spark - 阅读自定义的pyspark转换器,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/52443326/