陈天奇:兔子和分布式机器学习

作者:陈天奇

来源:http://weibo.com/p/1001603801281637563132?pids=Pl_Official_CardMixFeed__5&feed_filter=2

小编:作者从实际项目需求出发,参考MPI抽离设计出兔子(Rabit:可容错的Allreduce),并在Xgboost中成功应用。本文以故事的方式介绍了Rabit设计的背景和思路,能看出每个阶段“需求”下对应的“方案”,再进一步到“实现”。总之:一级棒!!!

上个学习的时候,我除了TA机器学习以外,另外一半的时间就是上了System课程。因为上课的缘故,需要做一个课程项目,于是我决定做一些和分布式机器学习相关的事情。

来到UW之后每一个和engineering的课程设计都会想做一些高效实用的东西。这一次一开始想到的目标之一,就是写一个自己可以写出的最高效的分布式boostedtree(也就是大家常说的GBDT)工具。想这么做也并非没有基础,之前我们完成的xgboost,已经是很快的单机多线程版本了,可以轻松处理百万到千万级别的数据。

不过一个显然的问题是,分布式机器学习程序并非systems而仅仅算是系统的应用而已。并不足以作为一个系统课程的课程设计来做。因此我到了直到学期结束前三周依然没有明确题目,只是在空闲的时间尝试去实现分布式的tree。

目标:

在一开始设计这个东西的时候,我有一些心中的目标。一开始的目标基本总结下来就是,速度快,可移植,少写代码。 速度快是自然的目标,体现这个目标的一个的想法是在写的似乎机器内部依然沿用单机多线程的优化来减少通信,只在机器之间来进行通信。

可移植性是一个更大的困难,要做分布式机器学习必须有分布式的通信框架。而每个分布式系统本身的抽象各不相同,hadoop/spark做的是MapReduceabstraction,graphlab做graph parallel,MPI提供的是 Allreduce/Broadcast,PS提供的是异步的更新。要想要让分布式程序可以运行在不同的环境下比起让一个单机机器学习程序从linux移植到windows困难许多。但是其实本质上,这个东西和把linux移植到windowss又没有差别。之所以一个程序可以从linux移植到windows,是因为程序仅仅依赖一些系统调用的接口,而我们只需要在每个平台下面提供一套这类接口就可以了。同样的,我也在之前就表示过,我不想让自己的机器学习代码往一个特殊的平台上面去靠,而是希望根据算法本身的需求,抽象出合理的接口,通过通用的库让平台往接口需求上面去走。原因也很简单,比起很多平台一开始支持的比较好的数据处理问题,机器学习往往需要消耗更多的计算资源和时间,根据机器学习的需求去设计通信库也是很自然的事情。

因此我在设计的时候定的目标是想一想需要什么,大不了自己写通信,不要依赖平台。。经过这样思考之后我考虑了对于boostedtree可能的分布式算法。发现Allreduce的交互比较自然,于是决定使用Allreduce作为算法依赖的通信。

为什么用Allreduce:

决定使用Allreduce并非偶然。之前我在weibo上面提分布式机器学习需要依赖的基本通信抽象的时候,老师木就说过他觉得应该是MPI。大部分的分布式机器学习算法的结构都是分布数据,在每个子集上面算出一些局部的统计量,然后整合出全局的统计量,并且在分配给各个计算节点去进行下一轮的迭代。这样一个过程就是Allreduce。分布式机器学习和传统的数据处理有一些区别。其中一个大的区别是分布式机器学习算法往往会用比较多的资源,包括临时空间,线程,结果缓存等。这样机器学习程序往往显得更加“重量级”。因此为了优化这一特点,我们往往需要让一个程序在必要的时候占领一台机器,并且在所有迭代的时候一直跑到底,来防止重新分配资源的开销。一开始为数据处理而设计的MapReduce采用的是多stage执行的方式,没有具备这一特点。而 Allreduce和基于异步通信的Parameter Server框架则具备了这一特点,因此更加有利于高效的机器学习算法。我个人认为同步的Allreduce和异步的PS抽象会是高效机器学习算法最常用的两个抽象,越来越多地出现在以机器学习为中心的分布式平台中。

实现一个 Allreduce:

我实现的分布式xgboost第一个版本,是基于Allreduce接口的。而并非直接依赖于MPI。原因也很简单, 我希望算法依赖仅仅需要的最少的接口,来方便后面最大的可移植性。在第一个版本中Allreduce接口的实现本身直接采用了MPI的实现。为了实现更大的可移植性,我又到后来自己手码了一个Allreduce的库,使得程序运行不再依赖于MPIcluster。实现到这里,有一个比较有趣的地方是我发现可移植和少写代码这两个目标其实出奇地一致。因为采用了简单的接口,分布式的代码对于多线程代码的不过仅仅是加入了一些同步函数而已,而这样同步函数也可以直接关掉,让分布式程序和单机程序共享一份代码。

Rabit: 可容错的Allreduce

当我搞完第一个版本的分布式的时候,我发现离学期末只有两周时间了。。。当时一下子拿不定主意想要做什么,因为分布式机器学习程序有了,但是不够交作业。但是想要改进boostedtree的愿望促使我开始想另外一个常见需求:容错。

Allreduce是MPI提供的一个主要功能,但是MPI一般不是特别受到广泛欢迎,原因之一就是它本身不容错。 经过考虑,我们发现其实分布式机器学习程序里面只需要Allreduce,这一点JohnLangford也在他的博客里面经常提到。而更加重要的一点是,如果砍掉MPI多余的接口,就保留Allreduce和Broadcast,支持容错会变得容易许多。原因是Allreduce有一个很好的性质,每一个节点最后拿到的是一样的结果,这意味着我们可以让一些节点记住结果。当有节点挂掉重启的时候,可以直接向还活着的节点索要结果就可以了。

基于这个基本的想法,和机器学习的需求,我们设计了一个可以容错的Allreduce库。叫做Rabit(ReliableAllreduce and Broadcast Interface)。然后这个库了就成为我最后系统作业交差的内容。目前的Rabit支持python和C++,并且可以运行在包括MPI和Hadoop 等各种平台上面。基于rabit写的程序也可以自然地移植到各个平台上。值得一提的是因为通信是Allreduce,那个原来一直被当作分布式机器学习龟速baseline的Hadoop也可以跑高效的分布式机器学习程序了。

而Rabit的设计理念也和兔子的特性差不多:到处跑(可移植),可以挖坑不怕死(容错),跑得快。并且轻量级。不做解决一切问题的框架,但是做可移植,精确解决一部分机器学习问题的高效容错通信库。

最后是代码链接,欢迎,使用拍砖以及贡献代码:)

Rabit库 https://github.com/tqchen/rabit

Rabit的教程:https://github.com/tqchen/rabit/tree/master/guide

分布式的boosted tree(GBDT): https://github.com/tqchen/xgboost/tree/master/multi-node

Tags :

1 Comment

留下你的评论