Model Averaging in Distributed Machine Learning: A Case Study with Apache Spark
- Yunyan Guo,
- Zhipeng Zhang,
- Jiawei Jiang,
- Wentao Wu,
- Ce Zhang,
- Bin Cui,
- Jianzhong Li
VLDB Journal
The increasing popularity of Apache Spark has attracted many users to put their data into
its ecosystem. At the same time, it has been widely reported in the literature that Spark is slow
when it comes to distributed machine learning (ML). One remedy is to switch to specialized systems
such as parameter servers, which are claimed to offer better performance, but users then have to
undergo the painful procedure of moving data into and out of Spark. In this paper, we investigate
the performance bottlenecks of MLlib (Spark's official ML package) in detail, focusing on its
implementation of Stochastic Gradient Descent (SGD), the workhorse behind the training of many ML
models. We show that Spark's inferior performance is caused by implementation issues rather than
fundamental flaws in the Bulk Synchronous Parallel (BSP) model that governs Spark's execution: we
can significantly improve Spark's performance by leveraging the well-known "model averaging" (MA)
technique from distributed ML. Moreover, MA is not limited to SGD, and we further showcase an
application of MA to training latent Dirichlet allocation (LDA) models within Spark. Our
implementation is not intrusive and requires only light development effort. Experimental results
reveal that the MA-based versions of SGD and LDA can be orders of magnitude faster than their
counterparts that do not use MA.
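To make the MA idea concrete, the sketch below shows how model averaging can be layered on top of Spark's RDD API for a linear model trained with SGD under squared loss. It is a minimal illustration under stated assumptions, not the paper's actual implementation: the names (`MaSgdSketch`, `localSgd`, `Example`), the input path, and the hyperparameters (20 rounds, learning rate 0.01) are all hypothetical. Each worker runs sequential SGD over its own partition starting from the broadcast global model, and the driver averages the resulting local models once per round.

```scala
// A minimal sketch of MA-based distributed SGD on Spark, assuming a linear
// model with squared loss. All names, paths, and hyperparameters here are
// illustrative; they are not taken from the paper or from MLlib.
import org.apache.spark.sql.SparkSession

case class Example(label: Double, features: Array[Double])

object MaSgdSketch {
  // One sequential SGD pass over the examples held by a single partition.
  def localSgd(points: Iterator[Example],
               init: Array[Double], lr: Double): Array[Double] = {
    val w = init.clone()
    points.foreach { p =>
      var pred = 0.0
      var i = 0
      while (i < w.length) { pred += w(i) * p.features(i); i += 1 }
      val err = pred - p.label          // gradient scale for squared loss
      i = 0
      while (i < w.length) { w(i) -= lr * err * p.features(i); i += 1 }
    }
    w
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("ma-sgd-sketch").getOrCreate()
    val sc = spark.sparkContext

    // Expected input: one "label f1 f2 ... fd" line per example.
    val data = sc.textFile("hdfs:///path/to/training/data").map { line =>
      val parts = line.split(' ').map(_.toDouble)
      Example(parts.head, parts.tail)
    }.cache()

    val dim = data.first().features.length
    var model = new Array[Double](dim)

    for (_ <- 1 to 20) {                // communication rounds
      val bcast = sc.broadcast(model)
      // Each partition refines its own copy of the current model locally...
      val (sum, cnt) = data
        .mapPartitions(it => Iterator((localSgd(it, bcast.value, 0.01), 1L)))
        // ...and the driver averages the per-partition models (the MA step).
        .treeReduce { case ((w1, n1), (w2, n2)) =>
          (w1.zip(w2).map { case (a, b) => a + b }, n1 + n2)
        }
      model = sum.map(_ / cnt)
      bcast.unpersist()
    }
    spark.stop()
  }
}
```

The design point to note is that workers communicate whole models once per round rather than gradients at every step, which is the standard intuition for why MA reduces synchronization overhead under BSP.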