Problem description
I'm trying to write a UDF which returns a complex type:
private val toPrice = UDF1<String, Map<String, String>> { s ->
    val elements = s.split(" ")
    mapOf("value" to elements[0], "currency" to elements[1])
}

val type = DataTypes.createStructType(listOf(
    DataTypes.createStructField("value", DataTypes.StringType, false),
    DataTypes.createStructField("currency", DataTypes.StringType, false)))

df.sqlContext().udf().register("toPrice", toPrice, type)
But whenever I use it:
df = df.withColumn("price", callUDF("toPrice", col("price")))
I get a mysterious error:
Caused by: org.apache.spark.SparkException: Failed to execute user defined function($anonfun$28: (string) => struct<value:string,currency:string>)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
at org.apache.spark.scheduler.Task.run(Task.scala:109)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Caused by: scala.MatchError: {value=138.0, currency=USD} (of class java.util.LinkedHashMap)
at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:236)
at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:231)
at org.apache.spark.sql.catalyst.CatalystTypeConverters$CatalystTypeConverter.toCatalyst(CatalystTypeConverters.scala:103)
at org.apache.spark.sql.catalyst.CatalystTypeConverters$$anonfun$createToCatalystConverter$2.apply(CatalystTypeConverters.scala:379)
... 19 more
I tried to use a custom data type:
class Price(val value: Double, val currency: String) : Serializable
with a UDF which returns that type:
private val toPrice = UDF1<String, Price> { s ->
    val elements = s.split(" ")
    Price(elements[0].toDouble(), elements[1])
}
but then I get another MatchError, which complains about the Price type.
How do I properly write a UDF that can return a complex type?
Recommended answer
The variant used here, where the UDF (a Java org.apache.spark.sql.api.java.UDF1) is registered together with an explicitly constructed DataType, or any mixed combination of these APIs, is provided primarily to ensure Java interoperability. It does not rely on Scala reflection to derive the schema, so the function itself has to return an org.apache.spark.sql.Row matching the declared struct type, not a Map or an arbitrary class.
In this case (equivalent to the one in the question) the definition should be similar to the following:
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.Row

val schema = StructType(Seq(
  StructField("value", DoubleType, false),
  StructField("currency", StringType, false)
))

// The function returns a Row matching the schema; malformed input falls back to null.
val toPrice = udf((s: String) => scala.util.Try {
  s.split(" ") match {
    case Array(price, currency) => Row(price.toDouble, currency)
  }
}.getOrElse(null), schema)
df.select(toPrice($"price")).show
// +----------+
// |UDF(price)|
// +----------+
// |[1.0, USD]|
// | null|
// +----------+
Excluding all the nuances of exception handling (in general, UDFs should control for null input and, by convention, handle malformed data gracefully), the Java equivalent should look more or less like this:
import static org.apache.spark.sql.functions.udf;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;

// As in the Scala version, the function returns a Row built to match the declared struct type.
UserDefinedFunction price = udf((String s) -> {
    String[] split = s.split(" ");
    return RowFactory.create(Double.parseDouble(split[0]), split[1]);
}, DataTypes.createStructType(new StructField[]{
    DataTypes.createStructField("value", DataTypes.DoubleType, true),
    DataTypes.createStructField("currency", DataTypes.StringType, true)
}));
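Since the question itself is in Kotlin, here is a minimal Kotlin sketch of the same fix (my own adaptation, not part of the original answer; it assumes Spark 2.x, the DataFrame df from the question, price strings of the form "138.0 USD", and the names priceType and withPrice are mine). The key change is that the UDF1 returns a Row built with RowFactory instead of a Map or a custom class:

import org.apache.spark.sql.Row
import org.apache.spark.sql.RowFactory
import org.apache.spark.sql.api.java.UDF1
import org.apache.spark.sql.functions.callUDF
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DataTypes

// Declare the struct type the UDF returns.
val priceType = DataTypes.createStructType(listOf(
    DataTypes.createStructField("value", DataTypes.DoubleType, false),
    DataTypes.createStructField("currency", DataTypes.StringType, false)))

// Return a Row matching priceType, not a Map or a Price class.
val toPrice = UDF1<String, Row> { s ->
    val (value, currency) = s.split(" ")
    RowFactory.create(value.toDouble(), currency)
}

df.sqlContext().udf().register("toPrice", toPrice, priceType)
val withPrice = df.withColumn("price", callUDF("toPrice", col("price")))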
Context:

To give you some context, this distinction is reflected in other parts of the API as well. For example, you can create a DataFrame from a schema and a sequence of Rows:
def createDataFrame(rows: List[Row], schema: StructType): DataFrame
or from a sequence of Products:
def createDataFrame[A <: Product](data: Seq[A])(implicit arg0: TypeTag[A]): DataFrame
but no mixed variants are supported.
In other words, you should provide input that can be encoded using a RowEncoder.
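As a purely illustrative sketch of what "encodable using a RowEncoder" means (again not part of the original answer; it assumes Spark 2.x, where RowEncoder.apply(schema) yields an Encoder<Row>, and an existing SparkSession named spark): plain Rows together with an explicit schema can be turned into a Dataset, whereas a java.util.Map or an arbitrary class cannot.

import org.apache.spark.sql.RowFactory
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.DataTypes

val priceSchema = DataTypes.createStructType(listOf(
    DataTypes.createStructField("value", DataTypes.DoubleType, false),
    DataTypes.createStructField("currency", DataTypes.StringType, false)))

// RowEncoder turns the schema into an Encoder<Row>; no such encoder exists for a
// java.util.Map or an arbitrary user class, which is what the MatchError complains about.
val ds = spark.createDataset(
    listOf(RowFactory.create(138.0, "USD")),
    RowEncoder.apply(priceSchema))
ds.show()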
Of course, you normally wouldn't use a udf for a task like this one:
import org.apache.spark.sql.functions._

df.withColumn("price", struct(
  split($"price", " ")(0).cast("double").alias("price"),
  split($"price", " ")(1).alias("currency")
))
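From Kotlin, roughly the same built-in-function approach could look like the sketch below (my own adaptation, not from the answer; the field aliases mirror the Scala snippet above):

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.split
import org.apache.spark.sql.functions.struct

// Split "138.0 USD" and rebuild it as a struct column, no UDF involved.
val parts = split(col("price"), " ")
df.withColumn("price", struct(
    parts.getItem(0).cast("double").alias("price"),
    parts.getItem(1).alias("currency")))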