
Distributed training

The core TNT framework makes no assumptions about distributed training or devices, and expects the user to configure distributed training on their own. As a convenience, the framework offers the AutoUnit for users who prefer to have this handled automatically. The framework-provided checkpointing callbacks handle distributed model checkpointing and loading.

If you are using the TrainUnit/EvalUnit/PredictUnit interface, you are expected to initialize the CUDA device, if applicable, along with the global process group from torch.distributed. We offer a convenience function init_from_env() that works with TorchElastic to handle these settings for you automatically; you should invoke it at the beginning of your script.
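
For example, a minimal sketch of a script entry point (assuming init_from_env is importable from torchtnt.utils, as in the snippets below; check your torchtnt version for the exact import path):

from torchtnt.utils import init_from_env  # assumed import path

def main() -> None:
    # initializes the global process group (when launched via TorchElastic / torchrun)
    # and returns the device this rank should use
    device = init_from_env()
    print(f"rank initialized on {device}")

if __name__ == "__main__":
    main()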

Distributed Data Parallel

If you are using the TrainUnit/EvalUnit/PredictUnit interface, DDP can simply be wrapped around your model like so:

import torch
import torch.nn as nn
from torchtnt.utils import init_from_env

device = init_from_env()
input_dim = 32  # example input dimension
module = nn.Linear(input_dim, 1)
# move module to device
module = module.to(device)
# wrap module in DDP (device_ids only applies to CUDA devices)
device_ids = [device.index] if device.type == "cuda" else None
model = torch.nn.parallel.DistributedDataParallel(module, device_ids=device_ids)

We also offer prepare_ddp(), which can wrap the model in DDP for you.
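
A minimal sketch, assuming prepare_ddp and DDPStrategy are importable from torchtnt.utils.prepare_module and that prepare_ddp takes the module, the target device, and an optional DDPStrategy (check your torchtnt version for the exact signature):

import torch.nn as nn
from torchtnt.utils import init_from_env
from torchtnt.utils.prepare_module import DDPStrategy, prepare_ddp  # assumed import path

device = init_from_env()
module = nn.Linear(32, 1)
# prepare_ddp moves the module to the device and wraps it in DistributedDataParallel
model = prepare_ddp(module, device, strategy=DDPStrategy(broadcast_buffers=False))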

The AutoUnit automatically wraps the module in DDP when either of the following is used (a minimal sketch of the MyAutoUnit subclass used in these examples appears after this list):

  1. The string ddp is passed as the strategy argument

    module = nn.Linear(input_dim, 1)
    my_auto_unit = MyAutoUnit(module=module, strategy="ddp")
    
  2. The dataclass DDPStrategy is passed as the strategy argument. This is helpful when you want to customize the DDP settings

    module = nn.Linear(input_dim, 1)
    ddp_strategy = DDPStrategy(broadcast_buffers=False, check_reduction=True)
    my_auto_unit = MyAutoUnit(module=module, strategy=ddp_strategy)
    
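MyAutoUnit in the snippets above is a user-defined subclass of AutoUnit. A minimal sketch, assuming the abstract methods to override are compute_loss and configure_optimizers_and_lr_scheduler and that AutoUnit is importable from torchtnt.framework.auto_unit (check the AutoUnit documentation for the exact interface in your version):

from typing import Any, Optional, Tuple

import torch
from torchtnt.framework.auto_unit import AutoUnit  # assumed import path
from torchtnt.framework.state import State  # assumed import path

class MyAutoUnit(AutoUnit):
    # compute the loss (and any outputs) for a single batch
    def compute_loss(self, state: State, data: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, Any]:
        inputs, targets = data
        outputs = self.module(inputs)
        loss = torch.nn.functional.mse_loss(outputs, targets)
        return loss, outputs

    # return the optimizer and an optional LR scheduler for the wrapped module
    def configure_optimizers_and_lr_scheduler(
        self, module: torch.nn.Module
    ) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
        return torch.optim.SGD(module.parameters(), lr=1e-2), None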

Fully Sharded Data Parallel

If you are using one or more of TrainUnit, EvalUnit, or PredictUnit, FSDP can simply be wrapped around the model like so:

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torchtnt.utils import init_from_env

device = init_from_env()
input_dim = 32  # example input dimension
module = nn.Linear(input_dim, 1)
# move module to device
module = module.to(device)
# wrap module in FSDP
model = FSDP(module, device_id=device)

We also offer prepare_fsdp(), which can wrap the model in FSDP for you.
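
A minimal sketch, assuming prepare_fsdp and FSDPStrategy are importable from torchtnt.utils.prepare_module and that prepare_fsdp takes the module, the target device, and an optional FSDPStrategy (check your torchtnt version for the exact signature):

import torch.nn as nn
from torchtnt.utils import init_from_env
from torchtnt.utils.prepare_module import FSDPStrategy, prepare_fsdp  # assumed import path

device = init_from_env()
module = nn.Linear(32, 1)
# prepare_fsdp wraps the module in FullyShardedDataParallel on the given device
model = prepare_fsdp(module, device, strategy=FSDPStrategy(limit_all_gathers=True))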

The AutoUnit automatically wraps the module in FSDP when either of the following is used:

  1. The string fsdp is passed as the strategy argument

    module = nn.Linear(input_dim, 1)
    my_auto_unit = MyAutoUnit(module=module, strategy="fsdp")
    
  2. The dataclass FSDPStrategy is passed as the strategy argument. This is helpful when you want to customize the FSDP settings

    module = nn.Linear(input_dim, 1)
    fsdp_strategy = FSDPStrategy(forward_prefetch=True, limit_all_gathers=True)
    my_auto_unit = MyAutoUnit(module=module, strategy=fsdp_strategy)
    
