Source code for torch.distributed.pipelining.microbatch
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch.fx.node import map_aggregate
from torch.utils._pytree import tree_flatten, tree_unflatten


__all__ = [
    "TensorChunkSpec",
    "split_args_kwargs_into_chunks",
    "merge_chunks",
]

logger = logging.getLogger(__name__)

"""
_debug_mask_minibatches specifies to send masked versions of the mini-batch
through instead of micro-batch slices--this can be used for more stable
numerical testing (see [A Note About Correctness Testing])
"""
_debug_mask_minibatches = False


class _CustomReducer:
    """
    Custom reducer class that can be used to specify a custom operation that
    reduces losses of multiple microbatches into one value.

    Example:
    >>> # xdoctest: +SKIP
    >>> sum_reducer = _CustomReducer(
    >>>     torch.tensor(0.0),
    >>>     lambda a, b: a + b
    >>> )
    """

    def __init__(self, init_value, reduce_fn):
        self.init_value = init_value
        self.reduce_fn = reduce_fn


class _LossReducer(_CustomReducer):
    pass


sum_reducer = _LossReducer(torch.tensor(0.0), lambda a, b: a + b)

# Default chunking dimension is 0. This is used for the case where the user did
# not specify a chunking dimension.
DEFAULT_CHUNK_DIM = 0
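For orientation, a minimal sketch of how a reducer such as `sum_reducer` folds per-microbatch loss values into a single total; the loss values below are made up for illustration:

# Illustrative only: fold three hypothetical microbatch losses with sum_reducer.
losses = [torch.tensor(0.5), torch.tensor(0.25), torch.tensor(0.125)]
total = sum_reducer.init_value
for loss in losses:
    total = sum_reducer.reduce_fn(total, loss)
# total is now torch.tensor(0.875)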
class TensorChunkSpec:
    """
    Class used to specify chunking of inputs
    """

    def __init__(self, split_dim):
        self.split_dim = split_dim

    split_dim: int

    def __repr__(self):
        return (
            f"{self.__class__.__module__}.{self.__class__.__name__}({self.split_dim})"
        )

    def __str__(self):
        return f"TensorChunkSpec({self.split_dim})"

    @staticmethod
    def from_tuple(
        chunk_dims: Tuple[int, ...],
    ):
        """
        A helper for creating a tuple of `TensorChunkSpec` from a tuple of
        chunk dimensions (int's).
        Example:
            >>> # xdoctest: +SKIP
            >>> # There are three positional arguments to the model, and
            >>> # we are chunking them along dimension 0, 0 and 1, respectively
            >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1))
        """
        args_chunk_spec = map_aggregate(
            chunk_dims,
            lambda dim: TensorChunkSpec(dim),  # type: ignore[arg-type,return-value]
        )
        return args_chunk_spec

    @staticmethod
    def from_dict(
        chunk_dims: Dict[str, int],
    ):
        """
        A helper for creating a dictionary of `TensorChunkSpec` from a
        dictionary of chunk dimensions (int's).
        Example:
            >>> # xdoctest: +SKIP
            >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
            >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1})
        """
        kwargs_chunk_spec = map_aggregate(
            chunk_dims,
            lambda dim: TensorChunkSpec(dim),  # type: ignore[arg-type,return-value]
        )
        return kwargs_chunk_spec
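As a usage sketch (the argument layout here is an assumption for illustration): chunk the first positional input along dim 0, the second along dim 1, and a keyword input named `mask` along dim 0.

# Hypothetical example: build chunk specs for two positional args and one kwarg.
args_chunk_spec = TensorChunkSpec.from_tuple((0, 1))
kwargs_chunk_spec = TensorChunkSpec.from_dict({"mask": 0})
# args_chunk_spec   -> (TensorChunkSpec(0), TensorChunkSpec(1))
# kwargs_chunk_spec -> {"mask": TensorChunkSpec(0)}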
# Class used to specify replication of inputs
class _Replicate:
    pass


def _shard_dict_of_args(
    args_dict,
    args_chunk_spec,
    num_chunks,
):
    """
    Given a dictionary of args, and a dictionary of chunking specs, shard the
    args according to the chunking specs.

    Args:
        args_dict: Dictionary of args
        args_chunk_spec: Dictionary of chunking specs
        num_chunks: Number of chunks to shard the args into

    Returns:
        args_split: List of sharded args
    """
    # Stage 1+2: flatten and shard/replicate
    # args_sharded_replicated : [num args, num flat values, num chunks]
    args_sharded_replicated = {}
    arg_specs = []

    real_num_chunks = num_chunks
    first_tensor = True

    assert len(args_dict) == len(
        args_chunk_spec
    ), f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"

    for arg_key, arg in args_dict.items():
        flat, spec = tree_flatten(arg)
        arg_specs.append(spec)

        chunk_spec = args_chunk_spec[arg_key]
        assert chunk_spec is not None  # Should have been set by caller
        chunk_spec_flat, _ = tree_flatten(chunk_spec)
        if len(flat) != len(chunk_spec_flat):
            raise ValueError(
                f"Argument value {arg} did not have the same number of "
                f"values as chunk spec {chunk_spec}"
            )

        sharded_arg_flat = []

        for v, chunk_v in zip(flat, chunk_spec_flat):
            if chunk_v is _Replicate or not isinstance(v, torch.Tensor):
                sharded_arg_flat.append([v] * real_num_chunks)
            elif isinstance(chunk_v, TensorChunkSpec):
                # TODO: check type of v. If it's a tensor, use chunk (or debug mask).
                # If it's a collection type, split it as you would expect. Otherwise,
                # throw an error
                assert isinstance(v, torch.Tensor), f"{v} is not a tensor"

                v_split_dim_size = v.size(chunk_v.split_dim)
                if v_split_dim_size < real_num_chunks:
                    if first_tensor:
                        # We can only adjust number of chunks when we hit this
                        # issue at the first tensor encountered
                        logger.warning(
                            f"Tensor size on chunking dimension is {v_split_dim_size}, "  # noqa: G004
                            f"downsizing the number of chunks from {num_chunks} to {v_split_dim_size}."
                        )
                        real_num_chunks = v_split_dim_size
                    else:
                        raise RuntimeError(
                            f"Arg {arg_key} on chunking dimension has a size of {v_split_dim_size}, "
                            f"smaller than the number of chunks {num_chunks}. "
                            "PiPPy cannot reduce the number of chunks because "
                            "other arguments have bigger chunk-dimension sizes. "
                            "Please adjust your num_chunks setting."
                        )

                chunk_tensors = torch.tensor_split(
                    v, real_num_chunks, chunk_v.split_dim
                )

                if _debug_mask_minibatches:
                    expanded_chunks = []

                    split_dim_idx = 0
                    for chunk_tensor in chunk_tensors:
                        new_val = torch.zeros_like(v)
                        upper_idx = split_dim_idx + chunk_tensor.size(chunk_v.split_dim)

                        slice_indices = [slice(None, None, None)] * new_val.ndim
                        slice_indices[chunk_v.split_dim] = slice(
                            split_dim_idx, upper_idx
                        )
                        new_val[slice_indices] = chunk_tensor

                        expanded_chunks.append(new_val)

                        split_dim_idx += chunk_tensor.size(chunk_v.split_dim)

                    sharded_arg_flat.append(expanded_chunks)
                else:
                    sharded_arg_flat.append(chunk_tensors)  # type: ignore[arg-type]

                first_tensor = False
            else:
                raise TypeError(f"Unrecognized chunk spec: {chunk_v}")

        args_sharded_replicated[arg_key] = sharded_arg_flat

    # chunks_flat : [num chunks, num args, num flat values]
    chunks_flat = []
    for chunk_idx in range(real_num_chunks):
        chunk_args = {}
        for key, arg in args_sharded_replicated.items():
            arg_single_chunk = [v_flat[chunk_idx] for v_flat in arg]
            chunk_args[key] = arg_single_chunk
        chunks_flat.append(chunk_args)

    # args_split : [num chunks, num args]
    args_split = []
    for chunk in chunks_flat:
        per_chunk_args = {}
        assert len(arg_specs) == len(chunk)
        for (key, arg), arg_spec in zip(chunk.items(), arg_specs):
            per_chunk_args[key] = tree_unflatten(arg, arg_spec)
        args_split.append(per_chunk_args)

    return args_split
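To make this helper's output layout concrete, a small sketch with made-up values (this is an internal API, shown only to illustrate the `[num chunks, num args]` result shape):

# Illustrative only: one tensor arg split along dim 0, one non-tensor arg replicated.
args_dict = {0: torch.arange(8).reshape(4, 2), 1: "replicated_value"}
args_chunk_spec = {0: TensorChunkSpec(0), 1: TensorChunkSpec(0)}
shards = _shard_dict_of_args(args_dict, args_chunk_spec, num_chunks=2)
# shards is a list of per-chunk dicts:
#   shards[0][0].shape == (2, 2)        # first half of the tensor along dim 0
#   shards[0][1] == "replicated_value"  # non-tensors are replicated per chunk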
def split_args_kwargs_into_chunks(
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]],
    chunks: int,
    args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
    kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
) -> Tuple[List[Tuple], List[Dict]]:
    """
    Given a sequence of args and kwargs, split them into a number of chunks
    according to their respective chunking specs.

    Args:
        args: Tuple of args
        kwargs: Dict of kwargs
        chunks: Number of chunks to split the args and kwargs into
        args_chunk_spec: chunking specs for args, in same shape as args
        kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs

    Returns:
        args_split: List of sharded args
        kwargs_split: List of sharded kwargs
    """
    # Given `args` and `kwargs`, we want to yield a set of `chunks` args and kwargs such that
    # the constituent Tensor values have been sharded/replicated according to the `args_chunk_spec`
    # and `kwargs_chunk_spec` specifications. The steps are as follows:
    #
    # 1. Use pytree.tree_flatten to flatten each arg and its spec into a 1d array of values.
    #    To use a running example: suppose our inputs look like
    #
    #       args = ([A, [B, C]], D) args_spec = ([None, [None, TensorChunkSpec]], None)
    #       (kwargs not shown but it's a similar process)
    #
    #    Then for this step we would end up with
    #
    #       args = ([A, B, C], D) args_spec = ([None, None, TensorChunkSpec], None)
    #
    # 2. Shard or replicate the arguments subject to the policy in the spec. Suppose chunks = 2
    #
    #       args = ([[A, A], [B, B], [C_1, C_2]], [D, D])
    #
    # 3. Rotate the nesting order such that chunks are the outer dimension
    #
    #       args_chunks = [
    #           ([A, B, C_1], D),
    #           ([A, B, C_2], D),
    #       ]
    #
    # 4. Unflatten each chunk according to the spec
    #
    #       args_chunks = [
    #           ([A, [B, C_1]], D),
    #           ([A, [B, C_2]], D),
    #       ]

    # TODO: _debug_mask_minibatches
    # Handle the case where kwargs is None
    if kwargs is None:
        kwargs = {}

    # If user did not provide args_chunk_spec or kwargs_chunk_spec, we extend
    # their format and use default chunking along dim 0
    if args_chunk_spec is None:
        args_chunk_spec = (TensorChunkSpec(DEFAULT_CHUNK_DIM),) * len(args)

    if kwargs_chunk_spec is None:
        kwargs_chunk_spec = dict.fromkeys(kwargs, TensorChunkSpec(DEFAULT_CHUNK_DIM))

    args_split_dict = _shard_dict_of_args(
        dict(enumerate(args)),
        dict(enumerate(args_chunk_spec)),
        chunks,
    )
    real_num_chunks = len(args_split_dict)

    kwargs_split = _shard_dict_of_args(
        kwargs,
        kwargs_chunk_spec,
        real_num_chunks,
    )

    if len(kwargs_split) < real_num_chunks:
        # In case kwargs are sharded into fewer chunks
        # e.g. when `args` has no tensor, just values
        real_num_chunks = len(kwargs_split)
        # Re-shard args
        args_split_dict = _shard_dict_of_args(
            dict(enumerate(args)),
            dict(enumerate(args_chunk_spec)),
            real_num_chunks,
        )

    if len(args_split_dict) != len(kwargs_split):
        raise RuntimeError(
            "args and kwargs are split into different number of chunks: "
            f"{len(args_split_dict)}, {len(kwargs_split)}"
        )

    args_split = [
        tuple(chunk_args[i] for i in range(len(chunk_args)))
        for chunk_args in args_split_dict
    ]

    return args_split, kwargs_split
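A minimal usage sketch with made-up tensors, splitting a batch of 4 into 2 microbatches along the default dim 0:

# Hypothetical example: split one positional arg and one kwarg into 2 chunks.
x = torch.randn(4, 3)
y = torch.randn(4, 3)
args_split, kwargs_split = split_args_kwargs_into_chunks((x,), {"target": y}, chunks=2)
assert len(args_split) == 2 and len(kwargs_split) == 2
assert args_split[0][0].shape == (2, 3)           # first microbatch of x
assert kwargs_split[1]["target"].shape == (2, 3)  # second microbatch of y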
def merge_chunks(
    chunks: List[Any],
    chunk_spec,
):
    """
    Given a list of chunks, merge them into a single value according to
    the chunk spec.

    Args:
        chunks: list of chunks
        chunk_spec: Chunking spec for the chunks

    Returns:
        value: Merged value
    """
    # This is essentially the inverse of `split_args_kwargs_into_chunks`, so the
    # steps are similar to the steps in that function but in reverse. Given the
    # input values:
    #
    #       chunks = [
    #           ([A, [B, C_1]], D),
    #           ([A, [B, C_2]], D),
    #       ]
    #       args_spec = ([None, [None, TensorChunkSpec]], None)
    #
    # 1. Flatten the chunks according to the chunk_spec
    #
    #       chunks_flat = [
    #           ([A, B, C_1], D),
    #           ([A, B, C_2], D),
    #       ]
    #
    # 2. Rotate the nesting order such that chunks are the inner dimension
    #
    #       value_inner = ([A, B, [C_1, C_2]], D)
    #
    # 3. Concatenate sharded arguments
    #
    #       value_combined = ([A, B, C], D)
    #
    # 4. Unflatten the combined args given the spec
    #
    #       value = ([A, [B, C]], D)

    # Preliminary: flatten the chunk spec
    if chunk_spec is not None:
        spec_flattened, flatten_spec = tree_flatten(chunk_spec)
    else:
        # If chunk_spec is not provided, we will merge chunks along the default dimension (0), for all output fields
        # We obtain the output structure by flattening chunk 0 and generate the chunk_spec
        chunk0_flat, flatten_spec = tree_flatten(chunks[0])
        spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat)

    # Stage 1: flatten chunks
    # chunks_flattened : [num chunks, num args]
    chunks_flattened = []

    for chunk in chunks:
        chunk_flattened, _ = tree_flatten(chunk)
        if len(chunk_flattened) != len(spec_flattened):
            raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}")

        chunks_flattened.append(chunk_flattened)

    # Stage 2 and 3: Rotate nesting order s.t. chunks are inner dimension and
    # concatenate sharded operands
    # args_flattened : [num args]
    args_flattened = []
    for arg_idx, arg in enumerate(spec_flattened):
        if isinstance(arg, TensorChunkSpec):
            partial_values = [
                chunks_flattened[chunk_idx][arg_idx]
                for chunk_idx in range(len(chunks_flattened))
            ]

            if _debug_mask_minibatches:
                # Infer size of individual chunks by running `tensor_split` again
                overall_shape = partial_values[0].shape
                for val in partial_values[1:]:
                    assert val.shape == overall_shape
                meta_chunks = torch.tensor_split(
                    torch.empty(*overall_shape, device="meta"),
                    sections=len(partial_values),
                    dim=arg.split_dim,
                )

                values_to_cat = []
                chunk_start_idx = 0
                assert len(partial_values) == len(meta_chunks)
                for partial_value, meta_chunk in zip(partial_values, meta_chunks):
                    chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim)

                    slice_indices = [slice(None, None, None)] * partial_value.ndim
                    slice_indices[arg.split_dim] = slice(chunk_start_idx, chunk_end_idx)
                    sliced = partial_value[slice_indices]
                    values_to_cat.append(sliced)

                    chunk_start_idx = chunk_end_idx

            else:
                values_to_cat = partial_values

            args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim))
        elif isinstance(arg, _CustomReducer):
            reduced_val = arg.init_value

            for chunk_idx in range(len(chunks_flattened)):
                reduced_val = arg.reduce_fn(
                    reduced_val, chunks_flattened[chunk_idx][arg_idx]
                )

            args_flattened.append(reduced_val)
        else:
            value = chunks_flattened[0][arg_idx]
            for chunk_idx in range(1, len(chunks_flattened)):
                assert chunks_flattened[chunk_idx][arg_idx] == value
            args_flattened.append(value)

    # Stage 4: Unflatten combined args
    return tree_unflatten(args_flattened, flatten_spec)
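A sketch of the inverse direction with made-up per-microbatch outputs: tensor entries are concatenated along the spec'd dimension, while a `_CustomReducer` entry (here the module-level `sum_reducer`) reduces loss values:

# Hypothetical example: merge two microbatch outputs of the form (logits, loss).
out0 = (torch.ones(2, 3), torch.tensor(0.5))
out1 = (torch.ones(2, 3), torch.tensor(0.25))
merged = merge_chunks([out0, out1], (TensorChunkSpec(0), sum_reducer))
assert merged[0].shape == (4, 3)                     # concatenated along dim 0
assert torch.isclose(merged[1], torch.tensor(0.75))  # losses summed by reducer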