If your train script works with torch.distributed.launch, it will continue working with torchrun, with these differences:
1. No need to manually pass RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT; torchrun exports them to every worker.
2. rdzv_backend and rdzv_endpoint can be provided. For most users rdzv_backend will be set to c10d (see rendezvous). The default rdzv_backend creates a non-elastic rendezvous where rdzv_endpoint holds the master address.
3. Make sure you have load_checkpoint(path) and save_checkpoint(path) logic in your script. When any number of workers fail, all the workers are restarted with the same program arguments, so you will lose progress up to the most recent checkpoint (see elastic launch).
4. The use_env flag has been removed. If you were getting the local rank by parsing the --local_rank option, you now need to read it from the LOCAL_RANK environment variable (e.g. int(os.environ["LOCAL_RANK"])).
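A minimal sketch of reading the rank information from the environment instead of from a --local_rank argument. The DistEnv class and read_dist_env function are hypothetical names chosen for illustration, and the fallback defaults (0, 0, 1) are an assumption that lets the same script also run as a plain single-process program without torchrun.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class DistEnv:
    """Rank information that torchrun exports to each worker process."""
    local_rank: int
    rank: int
    world_size: int


def read_dist_env() -> DistEnv:
    # torchrun sets LOCAL_RANK, RANK and WORLD_SIZE for every worker, so the
    # script no longer needs a --local_rank command-line option.
    # Assumption: default to a single process when the variables are absent.
    return DistEnv(
        local_rank=int(os.environ.get("LOCAL_RANK", "0")),
        rank=int(os.environ.get("RANK", "0")),
        world_size=int(os.environ.get("WORLD_SIZE", "1")),
    )
```

A GPU worker would then typically call torch.cuda.set_device(env.local_rank) before initializing the process group.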
Below is an expository example of a training script that checkpoints at the end of each epoch, so in the worst case a failure loses one full epoch's worth of training.
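```python
import sys

import torch.distributed

# parse_args, load_checkpoint, initialize, train and save_checkpoint are
# placeholders for your own training logic.


def main():
    args = parse_args(sys.argv[1:])
    state = load_checkpoint(args.checkpoint_path)
    initialize(state)

    # torch.distributed.run ensures that this will work
    # by exporting all the env vars needed to initialize the process group
    torch.distributed.init_process_group(backend=args.backend)

    # Resume from the epoch recorded in the checkpoint.
    for _ in range(state.epoch, state.total_num_epochs):
        for batch in state.dataset:
            train(batch, state.model)

        state.epoch += 1
        # Checkpoint once per epoch: at most one epoch of work is lost on failure.
        save_checkpoint(state, args.checkpoint_path)
```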
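The training script example leaves load_checkpoint and save_checkpoint undefined. One possible sketch follows; TrainState and both helpers are hypothetical, and storing the state as JSON on local disk is a stand-in chosen to keep the sketch dependency-free (a real script would more likely torch.save the model and optimizer state_dicts). The save writes to a temporary file and then swaps it in with os.replace, so a worker killed mid-write does not leave a truncated checkpoint for the restarted workers to load.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass, field
from typing import Any


@dataclass
class TrainState:
    # Hypothetical, JSON-serializable stand-in for the training state; a real
    # script would also hold the model, optimizer and dataset.
    epoch: int = 0
    total_num_epochs: int = 10
    model_state: dict[str, Any] = field(default_factory=dict)


def load_checkpoint(path: str, total_num_epochs: int = 10) -> TrainState:
    # First launch: no checkpoint exists yet, so start from epoch 0.
    # The default of 10 epochs is arbitrary for this sketch.
    if not os.path.exists(path):
        return TrainState(total_num_epochs=total_num_epochs)
    with open(path, "r", encoding="utf-8") as f:
        return TrainState(**json.load(f))


def save_checkpoint(state: TrainState, path: str) -> None:
    # Write to a temporary file in the target directory, then swap it in with
    # os.replace, so readers see either the old or the new checkpoint, never
    # a partially written one.
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(asdict(state), f)
        os.replace(tmp_path, path)
    except BaseException:
        os.remove(tmp_path)
        raise
```

In a multi-worker job you would typically call save_checkpoint from a single rank (for example RANK 0) and point the path at storage that survives a worker restart.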
For concrete examples of torchelastic-compliant train scripts, visit our examples page.