Distributed training¶
The core TNT framework makes no assumptions about distributed training or devices; users are expected to configure distributed training themselves. As a convenience, the framework offers the AutoUnit for users who prefer this to be handled automatically. The framework-provided checkpointing callbacks handle distributed model checkpointing and loading.
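For example, a minimal sketch of attaching a distributed-aware checkpointing callback, assuming the TorchSnapshotSaver callback with dirpath and save_every_n_epochs parameters (the path here is illustrative):

from torchtnt.framework.callbacks import TorchSnapshotSaver

# saves (and can restore) checkpoints consistently across ranks;
# pass the callback to the train/fit entry point via callbacks=[saver]
saver = TorchSnapshotSaver(dirpath="/tmp/checkpoints", save_every_n_epochs=1)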
If you are using the TrainUnit / EvalUnit / PredictUnit interface, you are expected to initialize the CUDA device, if applicable, along with the global process group from torch.distributed. We offer a convenience function init_from_env() that works with TorchElastic to handle these settings for you automatically; invoke it at the beginning of your script.
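For reference, a minimal sketch of the manual setup that init_from_env() takes care of, assuming the script is launched via TorchElastic (torchrun), which sets LOCAL_RANK and the other rendezvous environment variables:

import os
import torch
import torch.distributed as dist

local_rank = int(os.environ["LOCAL_RANK"])
if torch.cuda.is_available():
    # pin this process to its local GPU before creating the process group
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    dist.init_process_group(backend="nccl")
else:
    device = torch.device("cpu")
    dist.init_process_group(backend="gloo")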
Distributed Data Parallel¶
If you are using the TrainUnit / EvalUnit / PredictUnit interface, DDP can simply be wrapped around your model like so:
import torch
import torch.nn as nn
from torchtnt.utils import init_from_env

device = init_from_env()
module = nn.Linear(input_dim, 1)
# move module to device
module = module.to(device)
# wrap module in DDP (device_ids only applies when running on GPU)
device_ids = [device.index] if device.type == "cuda" else None
model = torch.nn.parallel.DistributedDataParallel(module, device_ids=device_ids)
We also offer prepare_ddp(), which can assist in wrapping the model for you.
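A minimal sketch of using it, assuming prepare_ddp() accepts the module and the target device:

import torch.nn as nn
from torchtnt.utils import init_from_env
from torchtnt.utils.prepare_module import prepare_ddp

device = init_from_env()
module = nn.Linear(input_dim, 1)
# prepare_ddp moves the module to the device and wraps it in DDP
model = prepare_ddp(module, device)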
The AutoUnit automatically wraps the module in DDP when either:

- The string "ddp" is passed as the strategy argument:

  module = nn.Linear(input_dim, 1)
  my_auto_unit = MyAutoUnit(module=module, strategy="ddp")

- The dataclass DDPStrategy is passed as the strategy argument, which is helpful when you want to customize the DDP settings:

  module = nn.Linear(input_dim, 1)
  ddp_strategy = DDPStrategy(broadcast_buffers=False, check_reduction=True)
  my_auto_unit = MyAutoUnit(module=module, strategy=ddp_strategy)
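MyAutoUnit in the snippets above stands in for your own AutoUnit subclass. A minimal sketch of one, assuming the AutoUnit abstract methods compute_loss and configure_optimizers_and_lr_scheduler (the batch format and optimizer choice here are illustrative):

from typing import Any, Tuple

import torch
from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.state import State


class MyAutoUnit(AutoUnit):
    def compute_loss(self, state: State, data: Any) -> Tuple[torch.Tensor, Any]:
        # unpack an (inputs, targets) batch and return (loss, outputs)
        inputs, targets = data
        outputs = self.module(inputs)
        loss = torch.nn.functional.mse_loss(outputs, targets)
        return loss, outputs

    def configure_optimizers_and_lr_scheduler(self, module):
        # plain SGD with no LR scheduler
        return torch.optim.SGD(module.parameters(), lr=0.01), None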
Fully Sharded Data Parallel¶
If you are using the TrainUnit / EvalUnit / PredictUnit interface, FSDP can simply be wrapped around your model like so:
import torch.distributed.fsdp
import torch.nn as nn
from torchtnt.utils import init_from_env

device = init_from_env()
module = nn.Linear(input_dim, 1)
# move module to device
module = module.to(device)
# wrap module in FSDP
model = torch.distributed.fsdp.FullyShardedDataParallel(module, device_id=device)
We also offer prepare_fsdp(), which can assist in wrapping the model for you.
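A minimal sketch of using it, assuming prepare_fsdp() mirrors prepare_ddp() in accepting the module, the target device, and an optional FSDPStrategy:

import torch.nn as nn
from torchtnt.utils import init_from_env
from torchtnt.utils.prepare_module import FSDPStrategy, prepare_fsdp

device = init_from_env()
module = nn.Linear(input_dim, 1)
# prepare_fsdp moves the module to the device and wraps it in FSDP
model = prepare_fsdp(module, device, strategy=FSDPStrategy(limit_all_gathers=True))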
The AutoUnit automatically wraps the module in FSDP when either:

- The string "fsdp" is passed as the strategy argument:

  module = nn.Linear(input_dim, 1)
  my_auto_unit = MyAutoUnit(module=module, strategy="fsdp")

- The dataclass FSDPStrategy is passed as the strategy argument, which is helpful when you want to customize the FSDP settings:

  module = nn.Linear(input_dim, 1)
  fsdp_strategy = FSDPStrategy(forward_prefetch=True, limit_all_gathers=True)
  my_auto_unit = MyAutoUnit(module=module, strategy=fsdp_strategy)
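Once constructed with a strategy, the AutoUnit is passed to the framework's loops like any other unit. A minimal sketch, assuming the torchtnt.framework.train entry point and a dataloader named train_dataloader:

from torchtnt.framework import train

# the AutoUnit wraps the module in FSDP (or DDP) internally based on `strategy`
train(my_auto_unit, train_dataloader, max_epochs=2)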