UDAF:用户自定义聚合函数

Scala 2.10.7,spark 2.0.0

package UDF_UDAF

import java.util

import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, RowFactory, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType} class UDAF extends UserDefinedAggregateFunction {
/**
* 指定输入字段的字段及类型
*/
override def inputSchema: StructType =
DataTypes.createStructType(Array(DataTypes.createStructField("namexxx",DataTypes.StringType,true))) /**
* 在进行聚合操作的时候所要处理的数据的结果的类型
* */
override def bufferSchema: StructType =
DataTypes.createStructType(Array(DataTypes.createStructField("buffer",DataTypes.IntegerType,true))) /**
* 指定UDAF计算后返回的结果类型
* @return
*/
override def dataType: DataType = DataTypes.IntegerType /**
* 确保一致性 一般用true,用以标记针对给定的一组输入,UDAF是否总是生成相同的结果。
*/
override def deterministic: Boolean = true /**
* 初始化一个内部的自己定义的值,在Aggregate之前每组数据的初始化结果
*/
override def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0,0) /**
* 更新 可以认为一个一个地将组内的字段值传递进来 实现拼接的逻辑
* buffer.getInt(0)获取的是上一次聚合后的值
* 相当于map端的combiner,combiner就是对每一个map task的处理结果进行一次小聚合
* 大聚和发生在reduce端.
* 这里即是:在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = buffer.update(0, buffer.getInt(0)+1) /**
* 合并 update操作,可能是针对一个分组内的部分数据,在某个节点上发生的 但是可能一个分组内的数据,会分布在多个节点上处理
* 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来
* buffer1.getInt(0) : 大聚合的时候 上一次聚合后的值
* buffer2.getInt(0) : 这次计算传入进来的update的结果
* 这里即是:最后在分布式节点完成后需要进行全局级别的Merge操作
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = buffer1.update(0, buffer1.getInt(0)+buffer2.getInt(0)) /**
* 最后返回一个和dataType方法的类型要一致的类型,返回UDAF最后的计算结果
*/
override def evaluate(buffer: Row): Any = buffer.getInt(0)
} object UDAF{
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setMaster("local").setAppName("udaf")
val sparkSession = SparkSession.builder().config(conf).config("spark.sql.warehouse.dir","/test/warehouse").getOrCreate()
val sc = sparkSession.sparkContext val parallelize = sc.parallelize(Array("zhangsan","lisi","wanger","zhaosi","zhangsan","lisi"))
val rowRDD = parallelize.map(s=>RowFactory.create(s)) val fields = new util.ArrayList[StructField]()
fields.add(DataTypes.createStructField("name",DataTypes.StringType,true))
val schema = DataTypes.createStructType(fields) val df = sparkSession.createDataFrame(rowRDD, schema)
df.createOrReplaceTempView("user") sparkSession.udf.register("StringCount",new UDAF()) sparkSession.sql("select name, StringCount(name) as StrCount from user group by name").show() sparkSession.stop() }
}

Spark SQL UDAF示例-LMLPHP

04-19 14:18