Reposted from https://medium.com/codex/a-comprehensive-tutorial-to-pytorch-distributeddataparallel-1f4b42bb1b51
The limited computation resources at school discouraged distributed training across multiple GPUs, so I learned it for the first time when I joined Microsoft as an intern. Wrapping the model with DDP (short for DistributedDataParallel) is basically an easy job. What frustrated me was that I could not properly adjust my workflow for multi-GPU training, including the DataLoader, Sampler, training, and evaluation. The tutorials and blogs on the Internet hardly cover all of this. After working through the many bugs I ran into, I've come up with the best practice I know of so far.
In this blog, I want to share my code and my insights with all beginners in DDP. I hope it helps them avoid horrible bugs and mistakes. I'm not going to include a detailed explanation of how DDP works; instead, I provide the minimum knowledge needed to make a model run on multiple GPUs. Note that I only introduce DDP on one machine with multiple GPUs, which is the most common case (otherwise, we should use model parallelism, as stated in the official blog). This blog is organized as:
BTW, I'm using torch==1.7.1, but I think it will work just fine with torch>=1.7.1.
First we must understand several terms used in distributed training:
PyTorch provides two settings for distributed training: torch.nn.DataParallel (DP) and torch.nn.parallel.DistributedDataParallel (DDP), where the latter is officially recommended. In short, DDP is faster and more flexible than DP. The fundamental thing DDP does is copy the model to multiple GPUs, gather the gradients from them, average the gradients to update the model, and then synchronize the model over all K processes. We can also gather/scatter tensors/objects other than gradients with torch.distributed.gather/scatter/reduce.
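To make the averaging step concrete, here is a minimal sketch (my own addition, not code from the original post) of averaging a tensor across processes with a collective op; it assumes the process group has already been initialized with the setup() function shown later, and the helper name is arbitrary:

import torch
import torch.distributed as dist

def average_across_processes(t, world_size):
    # sum the tensor across all processes in place, then divide by the number of processes
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    t /= world_size
    return t

This is conceptually what DDP does to the gradients during the backward pass.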
In case the model can fit on one GPU (i.e., it can be trained on one GPU with batch_size=1) and we want to train/test it on K GPUs, the best practice of DDP is to copy the model onto the K GPUs (the DDP class automatically does this for you) and split the dataloader into K non-overlapping groups that feed the K models respectively.
Now, things are clear to us. We have to do the following things:
Very easy, right? In fact it is. Let’s do it step by step.
Here it is, no extra steps.
import os
import torch.distributed as dist

def setup(rank, world_size):
    # address and port of the master node; any free port works on a single machine
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # initialize the process group with the NCCL backend
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
We can easily split our dataloader with torch.utils.data.distributed.DistributedSampler. The sampler returns an iterator over indices, which the dataloader uses to form batches. The DistributedSampler splits the total indices of the dataset into world_size parts and distributes them evenly to the dataloader in each process, without duplication.
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def prepare(rank, world_size, batch_size=32, pin_memory=False, num_workers=0):
    dataset = Your_Dataset()
    # each process sees a different, non-overlapping shard of the dataset
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
    dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory,
                            num_workers=num_workers, drop_last=False, shuffle=False, sampler=sampler)
    return dataloader
Suppose K=3 and the length of the dataset is 10. We must understand that DistributedSampler imposes an even partition of indices. If we set drop_last=False when defining the DistributedSampler, it will automatically pad. For example, it splits the indices [0,1,2,3,4,5,6,7,8,9] into [0,3,6,9] when rank=0, [1,4,7,0] when rank=1, and [2,5,8,1] when rank=2, padding the shortfall by reusing indices from the beginning of the list. As you can see, such padding may cause issues, because the padded indices correspond to real data records that get processed twice.

It is very simple to customize our Sampler. We only need to create a class and define its __iter__() and __len__() methods. Refer to the official documentation for more details; a rough sketch follows below.
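As a minimal illustration (my own sketch, not code from the original post), a custom sampler that hands process rank every world_size-th index without any padding could look like the following; the class name and the no-padding policy are my assumptions:

from torch.utils.data import Sampler

class StridedSampler(Sampler):
    # hypothetical example: process `rank` sees indices rank, rank + world_size, rank + 2*world_size, ...
    # unlike DistributedSampler with drop_last=False, it never pads, so shards may differ in length by one
    def __init__(self, dataset, rank, world_size):
        self.indices = list(range(rank, len(dataset), world_size))

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)

It can then be passed to the DataLoader through the sampler argument, just like DistributedSampler above.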
BTW, you'd better set num_workers=0 for distributed training, because creating extra threads in the child processes may be problematic. I also found that pin_memory=False avoids many horrible bugs; maybe such things are machine-specific, so please email me if you explore more details.
We should first move our model to a specific GPU (recall that one model replica resides on one GPU), then wrap it with the DDP class. The following function takes an argument rank, which we will introduce soon. For now, just keep in mind that rank equals the GPU id.
from torch.nn.parallel import DistributedDataParallel as DDP

def main(rank, world_size):
    # setup the process groups
    setup(rank, world_size)
    # prepare the dataloader
    dataloader = prepare(rank, world_size)

    # instantiate the model (it's your own model) and move it to the right device
    model = Model().to(rank)

    # wrap the model with DDP
    # device_ids tells DDP where your model is
    # output_device tells DDP where to output, in our case, it is rank
    # find_unused_parameters=True instructs DDP to find unused outputs of the forward() function of any module in the model
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)
There are a few tricky things here:
- To access the underlying model wrapped by DDP, use model.module. That is to say, our model instance is saved as the module attribute of the DDP model. If we assign some attributes xxx other than built-in properties or functions, we must access them by model.module.xxx.
- When saving the DDP model, its state_dict adds a "module." prefix to all parameter names. The following snippet strips that prefix in case we load a DDP model checkpoint into a non-DDP model:

import re
from collections import OrderedDict

# in case we load a DDP model checkpoint to a non-DDP model
# (state_dict is the checkpoint loaded via torch.load)
model_dict = OrderedDict()
pattern = re.compile('module.')
for k, v in state_dict.items():
    if re.search("module", k):
        # strip the leading 'module.' from the parameter name
        model_dict[re.sub(pattern, '', k)] = v
    else:
        # the checkpoint was not saved from a DDP model; use it as-is
        model_dict = state_dict
model.load_state_dict(model_dict)
This part is the key to implementing DDP. First, we need to know the basics of multiprocessing: all child processes, together with the parent process, run the same code.
In PyTorch, torch.multiprocessing provides convenient ways to create parallel processes. As the official documentation says, "The spawn function below addresses these concerns and takes care of error propagation, out of order termination, and will actively terminate processes upon detecting an error in one of them." So, using spawn is a good choice.
In our script, we should define a train/test function before spawning it to parallel processes:
def main(rank, world_size):
    # setup the process groups
    setup(rank, world_size)
    # prepare the dataloader
    dataloader = prepare(rank, world_size)

    # instantiate the model (it's your own model) and move it to the right device
    model = Your_Model().to(rank)

    # wrap the model with DDP
    # device_ids tells DDP where your model is
    # output_device tells DDP where to output, in our case, it is rank
    # find_unused_parameters=True instructs DDP to find unused outputs of the forward() function of any module in the model
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)
    #################### The above is defined previously

    optimizer = Your_Optimizer()
    loss_fn = Your_Loss()
    for epoch in range(epochs):
        # if we are using DistributedSampler, we have to tell it which epoch this is
        dataloader.sampler.set_epoch(epoch)
        for step, x in enumerate(dataloader):
            optimizer.zero_grad(set_to_none=True)
            pred = model(x)
            label = x['label']
            loss = loss_fn(pred, label)
            loss.backward()
            optimizer.step()
    cleanup()
This main function runs in every parallel process. We now need to launch it with the spawn method. In our .py script, we write:
import torch.multiprocessing as mp

if __name__ == '__main__':
    # suppose we have 3 gpus
    world_size = 3
    mp.spawn(
        main,
        args=(world_size,),  # note the trailing comma: args must be a tuple
        nprocs=world_size
    )
Remember that the first argument of main is rank? It is automatically passed to each process by mp.spawn, so we don't need to pass it explicitly. rank=0 is the master node by default, and the rank ranges from 0 to K-1 (2 in our case).
The last line of the main function calls the cleanup function, which is:
def cleanup():
    dist.destroy_process_group()
Bravo! We have completed the basic workflow of distributed training/testing!
Sometimes we need to collect some data from all processes, such as the testing result. We can easily gather tensors with dist.all_gather and objects with dist.all_gather_object.
Without loss of generality, I assume we want to collect Python objects. The only constraint on the object is that it must be serializable, which covers basically everything in Python. One should always call torch.cuda.set_device(rank) before using the all_gather_xxx functions. And if we want to store a tensor in the object, it must reside on the output_device.
def main(rank, world_size):
    # ... setup, dataloader and model preparation as before ...
    torch.cuda.set_device(rank)
    data = {
        'tensor': torch.ones(3, device=rank) + rank,
        'list': [1, 2, 3] + [rank],  # append the rank so each process contributes a distinct list
        'dict': {'rank': rank}
    }

    # we have to create enough room to store the collected objects
    outputs = [None for _ in range(world_size)]
    # the first argument is the collected list, the second argument is the data unique to each process
    dist.all_gather_object(outputs, data)

    # we only want to operate on the collected objects at the master node
    if rank == 0:
        print(outputs)
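Since the text above also mentions dist.all_gather for plain tensors, here is a minimal sketch of that variant (my own addition, not from the original post); it assumes every rank contributes a tensor of the same shape and that the process group is already initialized:

# inside main(rank, world_size), after torch.cuda.set_device(rank)
t = torch.ones(3, device=rank) * rank                         # each rank contributes a different tensor
gathered = [torch.zeros_like(t) for _ in range(world_size)]   # pre-allocate one slot per rank
dist.all_gather(gathered, t)                                  # every rank receives the full list
if rank == 0:
    print(gathered)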
The most confusing thing to me is when to use dist.barrier(). As the documentation says, it synchronizes processes. In other words, it blocks processes until all of them reach the same line of code: dist.barrier(). I summarize its usage as follows:
- the gradient synchronization during the backward pass (loss.backward()) is handled by DDP itself, so no extra barrier is needed there;
- no barrier is needed before collecting results either, because dist.all_gather_object does it for us.

In this post, we learned how to implement DDP in our models from scratch. Hopefully everyone who reads this can benefit. Thank you.