# Module-level imports required by this listing:
import warnings
from typing import Any, List, Optional, Tuple

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import Shard

# _flatten_tensor, _unflatten_tensor, _chunk_tensor, and _pre_load_state_dict are
# module-private helpers defined/imported elsewhere in the same source file.


def enable_2d_with_fsdp() -> bool:
    """
    The API registers the extension which is needed for Tensor Parallelism (TP)
    to work with FullyShardedDataParallel (FSDP). We first parallelize parameters
    within one module or sub_modules based on a parallelize_plan and then let FSDP
    reshard the local tensor of the distributed parameter, which is essentially a
    DTensor.

    Return:
        A `bool` indicating whether extension registration succeeded or not.
    """
    torch._C._log_api_usage_once(
        "torch.distributed.tensor.parallel.enable_2d_with_fsdp"
    )
    try:
        from torch.distributed.fsdp._fsdp_extensions import (
            _set_fsdp_extensions,
            FSDPExtensions,
        )

        class DTensorExtensions(FSDPExtensions):
            # Flatten a DTensor parameter into its local tensor plus the metadata
            # needed to reconstruct it later.
            def pre_flatten_transform(
                self,
                tensor: torch.Tensor,
            ) -> Tuple[torch.Tensor, Optional[Any]]:
                return _flatten_tensor(tensor)

            # Rebuild the DTensor from a local tensor and the saved metadata.
            def post_unflatten_transform(
                self, tensor: torch.Tensor, param_extension: Any
            ) -> torch.Tensor:
                return _unflatten_tensor(tensor, param_extension)

            # Chunk the tensor across ranks for FSDP's sharded state dict.
            def chunk_tensor(
                self,
                tensor: torch.Tensor,
                rank: int,
                world_size: int,
                num_devices_per_node: int,
                pg: dist.ProcessGroup,
            ) -> torch.Tensor:
                return _chunk_tensor(tensor, rank, world_size, num_devices_per_node, pg)

            # Convert a tensor coming from a state dict into local shards before
            # FSDP loads it.
            def pre_load_state_dict_transform(
                self,
                tensor: torch.Tensor,
            ) -> Tuple[torch.Tensor, List[Shard]]:
                return _pre_load_state_dict(tensor)

        _set_fsdp_extensions(DTensorExtensions())
        return True
    except BaseException as e:
        warnings.warn(
            "PyTorch doesn't have the TensorFlattener extension point available. "
            "2D parallelism won't work with FSDP. "
            f"Exception: {e}"
        )
        return False
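For orientation, a minimal usage sketch follows; it is not part of the source listing above. It assumes a job launched with torchrun, an NCCL backend, and an illustrative toy model; the mesh shape, tp_size, and variable names are assumptions. The ordering matters: tensor-parallelize on the TP mesh dimension first, call enable_2d_with_fsdp() to register the extension, then wrap the module with FSDP on the data-parallel process group.

# Sketch only: the model, mesh shape, and plan are illustrative assumptions.
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    PairwiseParallel,
    enable_2d_with_fsdp,
    parallelize_module,
)


class ToyModel(torch.nn.Module):  # hypothetical two-layer MLP
    def __init__(self):
        super().__init__()
        self.net1 = torch.nn.Linear(16, 16)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

# 2D device mesh: dim 0 for data parallelism (FSDP), dim 1 for tensor parallelism.
tp_size = 2  # assumed TP degree
mesh = DeviceMesh(
    "cuda", torch.arange(world_size).reshape(world_size // tp_size, tp_size)
)

model = ToyModel().cuda()

# Tensor-parallelize along mesh dim 1, then register the DTensor extension so
# FSDP (sharding along mesh dim 0) can reshard the DTensor local tensors.
model = parallelize_module(model, mesh, PairwiseParallel(), tp_mesh_dim=1)
if enable_2d_with_fsdp():
    dp_pg = mesh.get_dim_groups()[0]  # data-parallel process group
    model = FSDP(model, process_group=dp_pg)

If registration fails (the function returns False), the warning above is emitted and the module should not be wrapped in FSDP for 2D parallelism.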