并行训练-PyTorch

当我们的设备具有多个GPUs时,为了训练加速,我们通常会选用多卡并行训练,常见的并行训练方式有数据并行和模型并行。而PyTorch中也给我们提供了数据并行的接口DataParalle。本文将对该并行过程做一个简单的总结。

1. 数据并行示意图

img

img

img

(a) single GPU (b) parallel model (c) parallel citerion

2. DataParallel过程分析

本文将通过分析PyTorch中DataParallel源码的方式对并行过程展开讨论。Note:本文中只展示核心代码

1
2
3
4
5
6
7
8
9
10
11
12
13
class DataParallel(Module):
def __init__(self, module, device_ids=None, output_device=None, dim=0):
super(DataParallel, self).__init__()

def forward(self, *inputs, **kwargs):
# 1.将输入batch data分解
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
# 2.将model在每个GPU上各复制一份
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
# 3.进行并行前向计算
outputs = self.parallel_apply(replicas, inputs, kwargs)
# 4.将分解后的计算结果重新聚集,合成为原始的batch大小
return self.gather(outputs, self.output_device)

数据并行的主要过程就是包括上面4个步骤,结合图1(b)来看更加直观。

3. DataParallelModel和DataParallelCriterion拓展

在PyTorch官方提供的DataParallel接口中,是从输入x到输出y的过程,但是,在我们通常的训练过程中,我们还需要计算输入y与label之间的损失Loss,所以,为了加速到底,最好是将loss的计算也在过个GPU上进行。DataParallelCriterion的过程主要参考了PyTorch-Encoding中的代码,代码质量很高,强烈推荐一波:+1:

在DataParallelCriterion之前,我们首先需要对原始的DataParallel进行修改,即返回结果时不让其执行gather聚集过程,因为我们接下来还要在每个分解batch上继续计算loss。

1
2
3
4
class DataParallelModel(DataParallel):
# 修改gather过程,使其直接返回
def gather(self, outputs, output_device):
return outputs

为了并行计算loss,主要的不同是需要将batch label同样划分为多个子batch。如下:

1
2
3
4
5
6
7
8
9
class DataParallelCriterion(DataParallel):
def forward(self, inputs, *targets, **kwargs):
# 对label进行分解
targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
# 将model在每个GPU上各复制一份
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
# 执行前向计算,具体函数在这里不做展示,详细可见链接内容
outputs = criterion_parallel_apply(replicas, inputs, targets, kwargs)
return Reduce.apply(*outputs) / len(outputs)

整个并行训练过程如上所述,可结合图1(c)理解。

3. 并行训练中的注意事项

  • 所有的划分都是在batch维度进行
  • batch size必须大于GPUs的数量,最好是保证是其数量的整数倍

4. References