def enable_2d_with_fsdp() -> bool:
    """
    This API registers the extension that is needed for Tensor Parallelism (TP)
    to work with FullyShardedDataParallel (FSDP). Parameters are first parallelized
    within a module or its sub_modules based on a parallelize_plan; FSDP then
    reshards the local tensor of each distributed parameter, which is essentially
    a DTensor.

    Return:
        A `bool` indicating whether extension registration succeeded.
    """
    try:
        from torch.distributed.fsdp._fsdp_extensions import (
            _set_fsdp_extensions,
            FSDPExtensions,
        )

        class DTensorExtensions(FSDPExtensions):
            def pre_flatten_transform(
                self,
                tensor: torch.Tensor,
            ) -> Tuple[torch.Tensor, Optional[_STShardingInfo]]:
                return _flatten_tensor(tensor)

            def post_unflatten_transform(
                self, tensor: torch.Tensor, param_extension: _STShardingInfo
            ) -> torch.Tensor:
                return _unflatten_tensor(tensor, param_extension)

            def chunk_tensor(
                self,
                tensor: torch.Tensor,
                rank: int,
                world_size: int,
                num_devices_per_node: int,
                pg: dist.ProcessGroup,
            ) -> torch.Tensor:
                return _chunk_tensor(tensor, rank, world_size, num_devices_per_node, pg)

            def pre_load_state_dict_transform(
                self,
                tensor: torch.Tensor,
            ) -> Tuple[torch.Tensor, List[Shard]]:
                return _pre_load_state_dict(tensor)

        _set_fsdp_extensions(DTensorExtensions())
        return True
    except BaseException as e:
        warnings.warn(
            "PyTorch doesn't have the TensorFlattener extension point available. "
            "2D parallelism won't work with FSDP. "
            f"Exception: {e}"
        )
        return False
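For context, a minimal caller-side sketch follows. It is an illustration, not part of this module: the model, mesh layout, and the use of PairwiseParallel as the parallelize_plan are assumptions about the caller's setup; the key point is that enable_2d_with_fsdp() must be called before the TP-parallelized module is wrapped with FSDP.

# Illustrative caller-side sketch (assumed setup, not part of this module):
# register the extension, apply TP along one mesh dimension, then wrap with FSDP.
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module
from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp


def build_2d_parallel_model(model: torch.nn.Module, tp_size: int) -> torch.nn.Module:
    # Register DTensorExtensions; without the FSDP extension point, 2D parallelism
    # cannot work, so fail loudly here.
    if not enable_2d_with_fsdp():
        raise RuntimeError("FSDP TensorFlattener extension point is unavailable")

    world_size = dist.get_world_size()
    # 2D device mesh: dim 0 = data parallel (FSDP), dim 1 = tensor parallel (TP).
    mesh = DeviceMesh("cuda", torch.arange(world_size).reshape(-1, tp_size))

    # Shard parameters across the TP dimension according to the parallelize_plan;
    # FSDP then reshards each DTensor's local tensor across the DP dimension.
    model = parallelize_module(model, mesh, PairwiseParallel(), tp_mesh_dim=1)
    return FSDP(model, process_group=mesh.get_dim_groups()[0])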
class _STShardingInfo(NamedTuple):
    """:class:`ShardedTensor` sharding information."""

    sharding_spec: Optional[shard_spec.ShardingSpec]
    global_size: Optional[torch.Size]
    process_group: Optional[c10d.ProcessGroup]
    device_mesh: Optional[DeviceMesh]
    placements: Optional[List[Placement]]


def _get_box(tensor: DistributedTensor) -> Tuple[torch.Size, torch.Size]:
    device_mesh = tensor.device_mesh
    assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

    placement = tensor.placements[0]
    offsets = [0] * len(tensor.size())
    num_chunks = device_mesh.size(dim=0)

    if tensor.placements[0].is_shard():
        shard_dim = cast(DShard, placement).dim
        chunk_size = tensor.size(shard_dim) // num_chunks
        offsets[shard_dim] = chunk_size

    return (torch.Size(offsets), tensor._local_tensor.size())


def _get_box_for(tensor: DistributedTensor, idx: int) -> Tuple[torch.Size, torch.Size]:
    offsets, size = _get_box(tensor)
    return (torch.Size([val * idx for val in offsets]), size)


def _get_local_box(tensor: DistributedTensor) -> Tuple[torch.Size, torch.Size]:
    device_mesh = tensor.device_mesh
    dim_0_coord = device_mesh.get_coordinate_on_dim(0)
    assert dim_0_coord is not None
    return _get_box_for(tensor, dim_0_coord)


def _create_shard_md_from_dt(dt: DistributedTensor, current_rank: int) -> ShardMetadata:
    mesh = dt.device_mesh
    assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

    offsets, sizes = _get_local_box(dt)
    return ShardMetadata(
        shard_offsets=list(offsets),
        shard_sizes=list(sizes),
        placement=f"rank:{current_rank}/{dt._local_tensor.device}",
    )


def _create_sharded_tensor_md_from_dt(
    dt: DistributedTensor, dt_pg: c10d.ProcessGroup
) -> ShardedTensorMetadata:
    # This is where it gets tricky: we have to produce a ShardedTensor that has full coverage
    # and yet has only one valid shard for the current rank.

    shards_md = []
    my_rank = dist.get_rank(dt_pg)
    scapegoat_rank = 0 if my_rank > 0 else 1

    if dt.placements[0].is_shard():
        shard_count = dt_pg.size()
    else:
        shard_count = 1

    for i in range(shard_count):
        offsets, sizes = _get_box_for(dt, i)
        shards_md.append(
            ShardMetadata(
                shard_offsets=list(offsets),
                shard_sizes=list(sizes),
                placement=(
                    f"rank:{scapegoat_rank if i > 0 else my_rank}/{dt._local_tensor.device}"
                ),
            )
        )

    return ShardedTensorMetadata(
        shards_metadata=shards_md,
        size=dt.size(),
        tensor_properties=TensorProperties(
            dtype=dt.dtype,
            layout=dt.layout,
            requires_grad=dt.requires_grad,
            # ignore memory_format and pin_memory as those are not supported by DT
        ),
    )


def _get_dt_pg(dt: DistributedTensor) -> c10d.ProcessGroup:
    mesh = dt.device_mesh
    assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
    return mesh.get_dim_groups()[0]


def _rewrite_spec_if_needed(
    spec: shard_spec.ShardingSpec, tensor: torch.Tensor, rank: int
) -> shard_spec.ShardingSpec:
    """
    Rewrite ``spec`` to match the device of ``tensor``.

    FSDP.sharded_optim_state_dict sneakily moves optimizer state to CPU, so if the
    original ShardingSpec produces CUDA metadata, ShardedTensor construction fails.
"""ifnotisinstance(spec,ChunkShardingSpec):returnspec# let's see if we needrewrite=Falseforpinspec.placements:p=cast(_remote_device,p)ifp.rank()==rankandp.device()!=tensor.device:rewrite=Truebreakifrewrite:spec=copy.deepcopy(spec)fori,placementinenumerate(spec.placements):placement=cast(_remote_device,placement)ifplacement.rank()==rankandplacement.device()!=tensor.device:spec.placements[i]=_remote_device(f"rank:{rank}/{tensor.device}")returnspecdef_flatten_tensor(tensor:torch.Tensor,)->Tuple[torch.Tensor,Optional[_STShardingInfo]]:iftype(tensor)isShardedTensor:returntensor.local_tensor(),_STShardingInfo(tensor.sharding_spec(),tensor.size(),tensor._process_group,None,None,)eliftype(tensor)isDistributedTensor:tensor._local_tensor.requires_grad_()returntensor._local_tensor,_STShardingInfo(None,None,None,tensor.device_mesh,list(tensor.placements),)returntensor,Nonedef_unflatten_tensor(tensor:torch.Tensor,sharding_info:_STShardingInfo)->torch.Tensor:result:torch.Tensorifsharding_info.sharding_specisnotNone:assertsharding_info.global_sizeisnotNoneresult=ShardedTensor._init_from_local_tensor(tensor,_rewrite_spec_if_needed(sharding_info.sharding_spec,tensor,dist.get_rank(sharding_info.process_group),),sharding_info.global_size,process_group=cast(dist.ProcessGroup,sharding_info.process_group),)else:result=DistributedTensor.from_local(tensor,device_mesh=sharding_info.device_mesh,placements=sharding_info.placements,run_check=False,)_set_fsdp_flattened(result)returnresultdef_chunk_tensor(tensor:torch.Tensor,rank:int,world_size:int,num_devices_per_node:int,pg:dist.ProcessGroup,)->torch.Tensor:iftype(tensor)isShardedTensor:assertlen(tensor.local_shards())==1inner_param=tensor.local_tensor()inner_st=_create_chunk_sharded_tensor(inner_param,rank,world_size,num_devices_per_node,pg,)outer_local_shard=tensor.local_shards()[0]shards:List[Shard]=[Shard(inner_st,copy.deepcopy(outer_local_shard.metadata))]st_meta=copy.deepcopy(tensor.metadata())st_meta.tensor_properties.requires_grad=Falsest_outer=ShardedTensor._init_from_local_shards_and_global_metadata(shards,sharded_tensor_metadata=st_meta,process_group=tensor._process_group,init_rrefs=False,)returnst_outereliftype(tensor)isDistributedTensor:device_mesh=tensor.device_meshassertdevice_mesh.ndim==1,"Only 1D DeviceMeshes currently handled"inner_param=tensor._local_tensorinner_st=_create_chunk_sharded_tensor(inner_param,rank,world_size,torch.cuda.device_count(),pg,)dt_pg=_get_dt_pg(tensor)# We do this differently here, we create a ST with no local shards then patch itshards=[Shard(inner_st,_create_shard_md_from_dt(tensor,dist.get_rank(dt_pg)))]st_meta=_create_sharded_tensor_md_from_dt(tensor,dt_pg)st_meta.tensor_properties.requires_grad=Falsest_outer=ShardedTensor._init_from_local_shards_and_global_metadata(shards,sharded_tensor_metadata=st_meta,process_group=dt_pg,init_rrefs=False,)returnst_outerelse:return_create_chunk_sharded_tensor(tensor,rank,world_size,num_devices_per_node,pg,)def_pre_load_state_dict(tensor:torch.Tensor,)->Tuple[torch.Tensor,List[Shard]]:shards=cast(ShardedTensor,tensor).local_shards()iflen(shards)==1andtype(shards[0].tensor)isShardedTensor:inner_tensor=shards[0].tensorshards=inner_tensor.local_shards()# pyre-ignore[16]tensor=inner_tensorreturn(tensor,shardsiflen(shards)>0else[])