// 左外连接(leftOuterJoin) Spark实现

package com.kangaroo.studio.algorithms.join;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import scala.Tuple2; import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Left outer join of transactions with users, implemented with plain RDD operations
 * (union + groupByKey) instead of {@code JavaPairRDD.leftOuterJoin}.
 *
 * <p>For every product it prints the product id followed by (set of distinct user
 * locations that bought it, size of that set). Transactions whose user has no record in
 * the users file are attributed to the location {@code "UNKNOWN"}.
 *
 * <p>All Spark functions are private static nested classes on purpose. Anonymous inner
 * classes declared inside the instance method {@code run()} would carry a hidden reference
 * to this non-serializable object, and through it to its {@link JavaSparkContext}. Spark
 * then fails with "Task not serializable" when it ships those functions to executors.
 */
public class LeftOuterJoinSpark {

    private final JavaSparkContext jsc;
    private final String usersInputFile;
    private final String transactionsInputFile;

    /**
     * Creates the job and its Spark context.
     *
     * @param usersInputFile path of the users file (tab-separated: userId, location)
     * @param transactionsInputFile path of the transactions file
     *     (tab-separated: transactionId, productId, userId, price)
     */
    public LeftOuterJoinSpark(String usersInputFile, String transactionsInputFile) {
        this.jsc = new JavaSparkContext();
        this.usersInputFile = usersInputFile;
        this.transactionsInputFile = transactionsInputFile;
    }

    /**
     * Runs the join and prints, for each product, its id on one line and
     * (distinct locations, number of distinct locations) on the next.
     */
    public void run() {
        /*
         * Read the users file: two tab-separated columns, userId and location, e.g.:
         * u1	UT
         * u2	GA
         * u3	GA
         */
        JavaRDD<String> users = jsc.textFile(usersInputFile, 1);

        /*
         * Turn each users line into a key/value pair (userId, ("L", location)).
         * "L" tags the value as a location; transactions below use "P" for a product.
         * ("u1", ("L", "UT"))
         * ("u2", ("L", "GA"))
         * ("u3", ("L", "GA"))
         */
        JavaPairRDD<String, Tuple2<String, String>> usersRDD = users.mapToPair(new UserToLocation());

        /*
         * Read the transactions file: four tab-separated columns,
         * transactionId / productId / userId / price, e.g.:
         * t1	p3	u1	300
         * t2	p1	u2	400
         * t3	p1	u3	200
         */
        JavaRDD<String> transactions = jsc.textFile(transactionsInputFile, 1);

        /*
         * Turn each transaction line into a key/value pair (userId, ("P", productId)):
         * ("u1", ("P", "p3"))
         * ("u2", ("P", "p1"))
         * ("u3", ("P", "p1"))
         */
        JavaPairRDD<String, Tuple2<String, String>> transactionsRDD =
                transactions.mapToPair(new TransactionToProduct());

        /*
         * Union of users and transactions, both keyed by userId:
         * (userId, ("L", location))
         * (userId, ("P", product))
         */
        JavaPairRDD<String, Tuple2<String, String>> allRDD = transactionsRDD.union(usersRDD);

        /*
         * Group by userId:
         * (userId, [("L", location), ("P", p1), ("P", p2), ...])
         */
        JavaPairRDD<String, Iterable<Tuple2<String, String>>> groupedRDD = allRDD.groupByKey();

        /*
         * Drop the userId and pair each of the user's products with the user's location:
         * (product1, location1)
         * (product1, location2)
         * (product2, location1)
         */
        JavaPairRDD<String, String> productLocationRDD =
                groupedRDD.flatMapToPair(new ProductLocationPairs());

        /*
         * Group by productId:
         * (product1, [location1, location2, ...])
         */
        JavaPairRDD<String, Iterable<String>> productByLocations = productLocationRDD.groupByKey();

        /*
         * De-duplicate the locations of each product.
         * Input:  (product1, [location1, location2, location2, ...])
         * Output: (product1, ({location1, location2, ...}, count))
         */
        JavaPairRDD<String, Tuple2<Set<String>, Integer>> productByUniqueLocations =
                productByLocations.mapValues(new UniqueLocations());

        // Print the results.
        List<Tuple2<String, Tuple2<Set<String>, Integer>>> result = productByUniqueLocations.collect();
        for (Tuple2<String, Tuple2<Set<String>, Integer>> t : result) {
            // productId
            System.out.println(t._1);
            // (location set, set size)
            System.out.println(t._2);
        }
    }

    /**
     * Entry point.
     *
     * @param args {@code args[0]} = users file, {@code args[1]} = transactions file
     * @throws IllegalArgumentException if fewer than two arguments are given
     */
    public static void main(String[] args) {
        // Validate before the constructor creates a Spark context.
        if (args.length < 2) {
            throw new IllegalArgumentException(
                    "Usage: LeftOuterJoinSpark <usersInputFile> <transactionsInputFile>");
        }
        String usersInputFile = args[0];
        String transactionsInputFile = args[1];
        LeftOuterJoinSpark leftOuterJoinSpark = new LeftOuterJoinSpark(usersInputFile, transactionsInputFile);
        try {
            leftOuterJoinSpark.run();
        } finally {
            // Stop the context created in the constructor, even if the job throws.
            leftOuterJoinSpark.jsc.stop();
        }
    }

    /** Parses a users line "userId\tlocation" into (userId, ("L", location)). */
    private static final class UserToLocation
            implements PairFunction<String, String, Tuple2<String, String>> {
        private static final long serialVersionUID = 1L;

        @Override
        public Tuple2<String, Tuple2<String, String>> call(String s) throws Exception {
            String[] userRecord = s.split("\t");
            String userId = userRecord[0];
            Tuple2<String, String> location = new Tuple2<String, String>("L", userRecord[1]);
            return new Tuple2<String, Tuple2<String, String>>(userId, location);
        }
    }

    /** Parses a transactions line into (userId, ("P", productId)); columns 2 and 1 respectively. */
    private static final class TransactionToProduct
            implements PairFunction<String, String, Tuple2<String, String>> {
        private static final long serialVersionUID = 1L;

        @Override
        public Tuple2<String, Tuple2<String, String>> call(String s) throws Exception {
            String[] transactionRecord = s.split("\t");
            String userId = transactionRecord[2];
            Tuple2<String, String> product = new Tuple2<String, String>("P", transactionRecord[1]);
            return new Tuple2<String, Tuple2<String, String>>(userId, product);
        }
    }

    /** Emits one (productId, location) pair for every product tagged "P" in a user's group. */
    private static final class ProductLocationPairs implements
            PairFlatMapFunction<Tuple2<String, Iterable<Tuple2<String, String>>>, String, String> {
        private static final long serialVersionUID = 1L;

        @Override
        public Iterable<Tuple2<String, String>> call(
                Tuple2<String, Iterable<Tuple2<String, String>>> s) throws Exception {
            // Used when the group has no "L" record, i.e. the transaction's user is missing
            // from the users file; this keeps such transactions (left outer join semantics).
            String location = "UNKNOWN";
            List<String> products = new ArrayList<String>();
            for (Tuple2<String, String> t2 : s._2) {
                if (t2._1.equals("L")) {
                    // A later "L" value in the group overwrites an earlier one.
                    location = t2._2;
                } else if (t2._1.equals("P")) {
                    products.add(t2._2);
                }
            }
            List<Tuple2<String, String>> kvList = new ArrayList<Tuple2<String, String>>(products.size());
            for (String product : products) {
                kvList.add(new Tuple2<String, String>(product, location));
            }
            return kvList;
        }
    }

    /** Collapses a product's locations into (distinct locations, number of distinct locations). */
    private static final class UniqueLocations
            implements Function<Iterable<String>, Tuple2<Set<String>, Integer>> {
        private static final long serialVersionUID = 1L;

        @Override
        public Tuple2<Set<String>, Integer> call(Iterable<String> strings) throws Exception {
            Set<String> uniqueLocations = new HashSet<String>();
            for (String location : strings) {
                uniqueLocations.add(location);
            }
            return new Tuple2<Set<String>, Integer>(uniqueLocations, uniqueLocations.size());
        }
    }
}
// 05-08 15:31