当我们的设备具有多个GPUs时,为了训练加速,我们通常会选用多卡并行训练,常见的并行训练方式有数据并行和模型并行。而PyTorch中也给我们提供了数据并行的接口DataParalle。本文将对该并行过程做一个简单的总结。
1. 数据并行示意图
(a) single GPU (b) parallel model (c) parallel citerion
2. DataParallel过程分析
本文将通过分析PyTorch中DataParallel源码的方式对并行过程展开讨论。Note:本文中只展示核心代码
1 | class DataParallel(Module): |
数据并行的主要过程就是包括上面4个步骤,结合图1(b)来看更加直观。
3. DataParallelModel和DataParallelCriterion拓展
在PyTorch官方提供的DataParallel接口中,是从输入x到输出y的过程,但是,在我们通常的训练过程中,我们还需要计算输入y与label之间的损失Loss,所以,为了加速到底,最好是将loss的计算也在过个GPU上进行。DataParallelCriterion的过程主要参考了PyTorch-Encoding中的代码,代码质量很高,强烈推荐一波:+1:
在DataParallelCriterion之前,我们首先需要对原始的DataParallel进行修改,即返回结果时不让其执行gather聚集过程,因为我们接下来还要在每个分解batch上继续计算loss。
1 | class DataParallelModel(DataParallel): |
为了并行计算loss,主要的不同是需要将batch label同样划分为多个子batch。如下:
1 | class DataParallelCriterion(DataParallel): |
整个并行训练过程如上所述,可结合图1(c)理解。
3. 并行训练中的注意事项
- 所有的划分都是在batch维度进行
- batch size必须大于GPUs的数量,最好是保证是其数量的整数倍