Databricks孟祥瑞:ALS 在 Spark MLlib 中的实现
深受用户喜爱的大数据处理平台 Apache Spark 1.3 于前不久发布,MLlib 作为 Spark 负责机器学习 (ML) 的核心组件在 1.3 中添加了不少机器学习及数据挖掘的算法:研究主题分布的 latent Dirichlet allocation (LDA)、估计点集分布的高斯混合模型 (GMM)、提取频繁项集的 FP-growth、生成图聚类的 power iteration clustering (PIC)等等。呃,这些我们暂放一边不谈。MLlib 还添加了 Python 的 ML 流水线接口、模型基于 Parquet 的存储、以及分布式分块矩阵模型。呃,这些我们暂放另一边,也不谈……
那我们谈些什么?我想借这个机会聊聊 ALS 算法和其在 MLlib 中的实现,特别是在 Spark 1.3 中的改进。希望可以起到抛砖引玉的作用,让更多的人关注在 Spark 上实现机器学习算法会遇到的算法重构和运行效率问题。
ALS 是什么?
ALS 是交替最小二乘 (alternating least squares)的简称。在机器学习的上下文中,ALS
特指使用交替最小二乘求解的一个协同推荐算法。它通过观察到的所有用户给产品的打分,来推断每个用户的喜好并向用户推荐适合的产品。举个例子,我们考虑下面这个包含用户打分的打分矩阵:
这个矩阵的每一行代表一个用户 (u1,u2,...,u9)、每一列代表一个产品 (v1,v2,…,v9)。用户的打分在 1-9
之间。我们只显示观察到的打分。那么问题来了:用户
u5 给产品 v4 的打分大概会是多少?粗略地观察一下……这不是数独么?是的,而且如果按照数独来做的话(比较耗时、不推荐),用户
u5 一定会给产品
v4 打 9
分。为什么看上去选择很多,答案却是唯一的?因为数独的规则很强,每添加一条规则,就让整个系统的自由度下降一个量级。当我们要满足所有的规则时,整个系统的自由度已然降为一了。现在请努力地把上面的数独题想成一个打分矩阵。如果我们不添加任何条件的话,打分之间是相互独立的,我们没有任何依据来推断
u5 给 v4
的打分。所以在这个打分矩阵的基础上,我们需要提出一个限制其自由度的合理假设,使得我们可以通过观察已有打分猜测未知打分。
ALS 的核心就是下面这个假设:打分矩阵是近似低秩的。换句话说,一个
的打分矩阵 A 可以用两个小矩阵和的乘积来近似:。这样我们就把整个系统的自由度从一下降到了。当然,我们也可以随便提一个假设把自由度直接降到一。我们接下来就聊聊为什么
ALS
的低秩假设是合理的。世上万千事物,人们的喜好各不相同。但描述一个人的喜好经常是在一个抽象的低维空间上进行的,并不需要把其喜欢的事物一一列出。举个例子,我喜欢看略带黑色幽默的警匪电影,那么大家根据这个描述就知道我大概会喜欢昆汀的《低俗小说》、《落水狗》和韦家辉的《一个字头的诞生》。这些电影都符合我对自己喜好的描述,也就是说他们在这个抽象的低维空间的投影和我的喜好相似。再抽象一些,把人们的喜好和电影的特征都投到这个低维空间,一个人的喜好映射到了一个低维向量,一个电影的特征变成了纬度相同的向量,那么这个人和这个电影的相似度就可以表述成这两个向量之间的内积。 我们把打分理解成相似度,那么打分矩阵A就可以由用户喜好矩阵和产品特征矩阵的乘积来近似了。
我们大致解释了
ALS 低秩假设的合理性,接下来的问题是怎么选这个抽象的低维空间。这个低维空间要能够有效的区分事物,如果我说我喜欢看 16:9
宽屏的彩色立体声电影,那一定是我真心不想透露我的喜好。但 ALS
是很难从实质上理解“黑色幽默”和“彩色”的区别是什么的,它需要一个更明确的可以量化的目标,这就是重构误差。既然我们的假设是打分矩阵A可以通过来近似,那么一个最直接的可以量化的目标就是通过U,V重构A所产生的误差。在 ALS 里,我们使用 Frobenius范数,,来量化重构误差,就是每个元素的重构误差的平方和。这里存在一个问题,我们只观察到部分打分,A 中的大量未知元正是我们想推断的,所以这个重构误差是包含未知数的。解决方案很简单很暴力:就只看对已知打分的重构误差吧。所以 ALS 的优化目标是:。这里 R 指观察到的 (用户,产品)集。
我们把一个协同推荐的问题通过低秩假设成功转变成了一个优化问题。下面要讨论的内容很显然:这个优化问题怎么解?其实答案已经在
ALS 的名字里给出——交替最小二乘。ALS
的目标函数不是凸的,而且变量互相耦合在一起,所以它并不算好解。但如果我们把用户特征矩阵U和产品特征矩阵V固定其一,这个问题立刻变成了一个凸的而且可拆分的问题。比如我们固定U,那么目标函数就可以写成。其中关于每个产品特征的部分是独立的,也就是说固定U求我们只需要最小化就好了,这个问题就是经典的最小二乘问题。所谓“交替”,就是指我们先随机生成然后固定它求解,再固定求解,这样交替进行下去。因为每步迭代都会降低重构误差,并且误差是有下界的,所以 ALS 一定会收敛。但由于问题是非凸的,ALS 并不保证会收敛到全局最优解。但在实际应用中,ALS 对初始点不是很敏感,是不是全局最优解造成的影响并不大。
ALS 在 MLlib 中的实现
ALS 的算法介绍完了,但我们距离一个好的分布式实现还有一段距离。因为 ALS 每步迭代中优化问题的目标函数可以拆分成互相独立的最小二乘子问题,所以从计算的角度来看 ALS 是适合分布式求解的。但通过观察一个子问题,我们会发现求解 vj是需要知道上一步得到的每个已知打分对应的的值。如果分布式求解,我们可能会需要从其它节点获取这些数据,从而产生通信费用。和很多机器学习算法的分布实现类似,ALS 的分布式实现主要关心的是计算复杂度和通信复杂度。
计算复杂度比较容易估算,所以我们先讲。求解一个的最小二乘问题的复杂度是。当固定U求V时,我们一共有n个最小二乘子问题,所以总的复杂度是,其中 nnz 指观察到的打分数量。再加上固定V求U的复杂度,一步完整的迭代需要的计算量就是。MLlib 中的 ALS 实现通过法方程 (normal equation) 求解最小二乘子问题,需要的空间复杂度是。最小二乘有很多种求解方法,这里为什么选法方程以及其求解精度我们就略去不谈了。
通信复杂度是分布式实现一个算法时一定要重点考虑的问题,稍有不慎就会导致十倍甚至百倍的效率损失。我们先看一下最坏的情况:假设求解时所需要的用户特征都需要从其它节点获取,并且子问题之间完全独立。例如图1所示,求解 v1 需要获取 u1 和 u2,求解 v2 需要获取 u1、u2 和 u3等等。这种假设下每步迭代需要交换的数据量是,比输入数据要高一个量级。虽然还是比每步迭代需要的计算量低一个量级,但由于k一般不大,而且做一个浮点运算比通过网络传输一个字节要快很多,所以在这种情况下通信时间会远远超出计算时间。
图1:通信复杂度示例图
为了在
Spark 上提供一个高效的 ALS 实现,我们需要合理的设计数据分区和 RDD 缓存来减少数据交换。从上面的图我们会观察到,如果计算 v1 和
v2 是在同一个分区上完成的,我们只需要把 u1 和 u2 一次发给这个分区,然后在计算 v1 和 v2 的时时候在本机内存直接读取 u1 和
u2 即可。 这样就省掉了不必要的数据传输。图2描述了如何在分区的情况下通过
U来求解V,注意节点之间的数据交换量减少了。使用这种分区结构,我们需要在原始打分数据的基础上额外保存一些信息。在 P1,我们要知道把 u1 发给
Q1 和 Q2,把 u2 发给 Q1。我们可以查看和 u1 相关联的所有产品来确定需要把 u1
发给谁,但每次迭代都扫一遍数据是很不划算的。所以在 MLlib 的实现中我们只计算一次这个信息,然后把结果通过 RDD
缓存起来重复使用。这部分数据我们在代码里称作 OutBlock。在 Q1,我们需要知道 v1
和哪些用户向量有关联及其对应的打分,从而构建最小二乘问题并求解。这部分数据不仅包含原始打分数据,还包含从每个用户分区收到的向量排序信息,我们在代码里称作
InBlock。所以从 U 求解 V,我们需要通过用户的 OutBlock 信息把用户向量发给产品分区,然后通过产品的 InBlock
信息构建最小二乘问题并求解。从 V 求解 U,我们需要产品的 OutBlock 信息和用户的 InBlock 信息。所有的 InBlock 和
OutBlock 信息在迭代过程中都通过 RDD 缓存。大家会发现原始的打分数据其实在用户的 InBlock 和产品的 InBlock
各存了一份,但分区方式不同,这么做可以避免在迭代过程中对原始数据的交换。
图2:数据分区设计后的通信复杂度
接下来我们讨论一下
InBlock 的数据结构。以 Q1 为例,我们要知道所有关于 v1 和 v2 的所有打分:(v1, u1, a11),(v2, u1,
a12), (v1, u2, a21), (v2, u2, a22), (v2, u3, a32)。但是把这些打分直接按照 Tuple
存的话会有几个问题。首先是空间的额外开销,每个 Tuple 实例都需要一个指针,而每个 Tuple 所存的数据不过是两个 ID
和一个打分,非常不划算。而且存储大量的 Tuple 实例会降低 Java 垃圾回收效率。所以我们使用三个原始数组来存 InBlock
信息:([v1, v2, v1, v2, v2], [u1, u1, u2, u2, u3], [a11, a12, a21, a22,
a32])。这样不仅大幅减少了实例数量,还有效地利用了连续内存。但还存在一个问题,当我们求解 v1 时,我们要通过所有和 v1 关联的用户向量
(u1, u2) 来构建最小二乘问题。这里有两个选择:a) 扫一遍 InBlock 信息,同时对所有的产品构建对应的最小二乘问题;b)
对于每一个产品,扫描 InBlock 信息,构建并求解其对应的最小二乘问题。之前提到过通过法方程求解一个最小二乘问题的空间复杂度是,所以方法 a 所需要的空间是,比存储产品向量所需空间高出一个量级。而方法
b 也不算理想,因为要对 InBlock 信息多次扫描。在Spark 1.3 里,我们首先将 InBlock 信息按照产品 ID 排序:
([v1, v1, v2, v2, v2],[u1, u2, u1, u2, u3], [a11, a21, a12, a22,
a32])。这样我们只需要顺序扫描一遍数据,就可以逐个创建最小二乘问题并求解,这样所需的空间降到了。在
Java 里将三个很大的原始数组根据某一个排序并不是件很容易的事情。我们使用 Spark 中的 TimSort 实现来排序,这也是在
Petabyte Sort 比赛中 Databricks
小组所使用的排序算法。排序后的另外一个好处是我们可以把数据进一步压缩。对于每一个产品,我们只需纪录它所对应的打分开始和结束的位置即可。InBlock
就变成了这样:([v1, v2], [0, 2, 5], [u1, u2, u1, u2, u3], [a11, a21, a12, a22,
a32])。其中 [0, 2] 指 v1 对应的打分的区间是 [0, 2),[2, 5] 指 v2 对应的打分的区间是 [2,
5)。通过一系列的调整,我们在内存使用、时间和空间复杂度上都达到了较好的效果。
在 Spark 1.3 中,我们还对 ALS
做了一些其它的改进。为了避免不必要的 map 查询和支持多种 ID 类型,我们在实现中并没有直接在 InBlock 中存储用户的原始
ID,而只记录了需要的用户向量应该是哪个分区发过来的第几个。比如在 Q1 分区 ,u2 就是从 P1 发过来的第二个,而 u2 原始的 ID
是多少并不影响问题的求解。我们把分区和索引信息编码到一个整型里,在高位存分区 ID,在低位存对应分区的索引,在空间上也尽量做到不浪费。此外,因为
ALS 对求解的精度要求不高,为了减少数据交换量,我们把Spark 1.2 中使用的 Double 改成了 Float
来存储用户和产品向量。还有一些优化我们就不一一提及了,有兴趣的读者可以参看 ALS 源码以及相关的 JIRA。
通过对实现的改进,新版的
ALS 在速度、资源和稳定性上都有大幅度提升。下图是我们在 Amazon Reviews 数据集上做的一些比较。测试使用 16 个
m3.2xlarge 节点的 Amazon EC2 集群。可以看到,ALS 在速度上对比 Spark 1.2 有 2-4x
的提升,而且表现出了更好的伸缩性。我们还在更大的集群上测试了一个大概有 500 亿打分的数据集,ALS 表示无压力。
小结
本文简单介绍了
ALS 算法和其在 MLlib 中的实现。希望通过分析 ALS
可以让大家直观的看到,同样的算法,在分布式系统上实现时,不同的选择会带来性能上巨大的差异。大家在 Spark
上实现机器学习算法时,不妨先分析一下空间、时间、和通信复杂度,然后合理的利用 Spark 的分区和缓存机制做到高效的实现。希望在 2015
年看到更多的人加入 MLlib 的开发和维护,让 MLlib 的算法更好更快更易用!
孟祥瑞,Databricks 软件工程师、Apache Spark PMC成员 ,Apache Spark Committer。