Source code for torch.distributed.pipelining.schedules
# mypy: allow-untyped-defs# Copyright (c) Meta Platforms, Inc. and affiliatesimportcopyimportcsvimportitertoolsimportloggingimportrefromabcimportABC,abstractmethodfromcollectionsimportCounter,defaultdictfromenumimportEnumfromtypingimport(Any,Callable,Dict,List,NamedTuple,Optional,Set,Tuple,TYPE_CHECKING,Union,)importtorchimporttorch.distributedasdistfromtorch.distributed.fsdpimportFSDPModule,UnshardHandlefromtorch.profilerimportrecord_functionfrom.microbatchimportmerge_chunks,split_args_kwargs_into_chunks,TensorChunkSpecfrom.stageimport_PipelineStageBaseifTYPE_CHECKING:fromtorch.distributedimportWork__all__=["get_schedule_class","PipelineScheduleSingle","PipelineScheduleMulti","Schedule1F1B","ScheduleGPipe","ScheduleInterleaved1F1B","ScheduleLoopedBFS","ScheduleInterleavedZeroBubble","ScheduleZBVZeroBubble",]logger=logging.getLogger(__name__)class_ComputationType(Enum):# TODO(whc) rename to _ActType?FORWARD=1BACKWARD_INPUT=2BACKWARD_WEIGHT=3UNSHARD=4RESHARD=5SEND_F=6RECV_F=7SEND_B=8RECV_B=9FULL_BACKWARD=10def__str__(self):str_map={_ComputationType.FORWARD:"F",_ComputationType.BACKWARD_INPUT:"I",_ComputationType.BACKWARD_WEIGHT:"W",_ComputationType.UNSHARD:"UNSHARD",_ComputationType.RESHARD:"RESHARD",_ComputationType.SEND_F:"SEND_F",_ComputationType.RECV_F:"RECV_F",_ComputationType.SEND_B:"SEND_B",_ComputationType.RECV_B:"RECV_B",_ComputationType.FULL_BACKWARD:"B",}returnstr_map[self]@staticmethoddeffrom_str(action):ifaction=="F":return_ComputationType.FORWARDelifaction=="I":return_ComputationType.BACKWARD_INPUTelifaction=="W":return_ComputationType.BACKWARD_WEIGHTelifaction=="UNSHARD":return_ComputationType.UNSHARDelifaction=="RESHARD":return_ComputationType.RESHARDelifaction=="SEND_F":return_ComputationType.SEND_Felifaction=="RECV_F":return_ComputationType.RECV_Felifaction=="SEND_B":return_ComputationType.SEND_Belifaction=="RECV_B":return_ComputationType.RECV_Belifaction=="B":return_ComputationType.FULL_BACKWARDelse:raiseRuntimeError(f"Invalid computation type {action}")FORWARD=_ComputationType.FORWARDBACKWARD_INPUT=_ComputationType.BACKWARD_INPUTBACKWARD_WEIGHT=_ComputationType.BACKWARD_WEIGHTUNSHARD=_ComputationType.UNSHARDRESHARD=_ComputationType.RESHARDSEND_F=_ComputationType.SEND_FRECV_F=_ComputationType.RECV_FSEND_B=_ComputationType.SEND_BRECV_B=_ComputationType.RECV_BFULL_BACKWARD=_ComputationType.FULL_BACKWARD# Convenience shorthand for compute actions only since they are used in 'simple schedule format'F=FORWARDI=BACKWARD_INPUTW=BACKWARD_WEIGHTB=FULL_BACKWARD# Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index)_action_regex=re.compile(r"(\d+)(F|I|B|W|UNSHARD|RESHARD|SEND_F|RECV_F|SEND_B|RECV_B)(\d*)")class_Action(NamedTuple):stage_index:intcomputation_type:_ComputationTypemicrobatch_index:Optional[int]=Nonedef__repr__(self):repr=str(self.stage_index)repr+=str(self.computation_type)ifself.microbatch_indexisnotNone:repr+=str(self.microbatch_index)returnrepr@staticmethoddeffrom_str(action_string:str):""" Reverse of __repr__ String should be formatted as [stage][action type][(microbatch)] e.g. 
`2F0`, `1UNSHARD`, `3SEND_F1` """action_string=action_string.strip()ifmatch:=_action_regex.match(action_string):stage_index,computation_type,microbatch_index=match.groups()return_Action(int(stage_index),_ComputationType.from_str(computation_type),int(microbatch_index)iflen(microbatch_index)elseNone,)elifaction_string=="":returnNoneraiseRuntimeError(f"Invalid action string: {action_string}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0")def_format_pipeline_order(pipeline_order:Dict[int,List[Optional[_Action]]],error_step_number:Optional[int]=None,)->str:""" Formats the pipeline order in a timestep (row) x rank (column) grid of actions and returns the formatted string. If `error_step_number` is passed in, an additional label will be added to signify which step that it is erroring on. """# don't mutate the originalpipeline_order=copy.deepcopy(pipeline_order)# Replace None with ""forrankinpipeline_order:foriinrange(len(pipeline_order[rank])):ifpipeline_order[rank][i]isNone:# TODO make a real 'None action' that prints as empty string and make mypy happypipeline_order[rank][i]=""# type: ignore[call-overload]# Calculate the maximum number of steps across all ranksnum_steps=max(len(actions)foractionsinpipeline_order.values())step_labels=["Step "+str(i).zfill(len(str(num_steps-1)))foriinrange(num_steps)]# Sorting the dictionary by keys and retrieving values in that orderrank_actions=[pipeline_order.get(key,[""]*num_steps)forkeyinsorted(pipeline_order)]# Transpose the list of lists (rows to columns)transposed_actions=list(itertools.zip_longest(*rank_actions,fillvalue=""))# Generate column labels for ranksnum_ranks=len(pipeline_order)rank_labels=["Rank "+str(i)foriinrange(num_ranks)]# Calculate the maximum length of each column, considering labelsmax_lengths=[max(len(str(item))ifitemisnotNoneelse0foritemincol)forcolinzip(step_labels,*transposed_actions)]# Format the header row with rank labelsheader_row=" "*(len(step_labels[0])+2)+" ".join(f"{label:<{max_lengths[i]}}"fori,labelinenumerate(rank_labels))# Format each row with its corresponding labelformatted_rows=[f"{label}: "+" ".join(f"{str(item):<{max_lengths[i]}}"fori,iteminenumerate(row))+(" <-- ERROR HERE"iferror_step_numberisnotNoneandint(label.split()[1])==error_step_numberelse"")forlabel,rowinzip(step_labels,transposed_actions)]# Join the rows into a single stringformatted_table=header_row+"\n"+"\n".join(formatted_rows)+"\n"returnformatted_tableclass_PipelineSchedule(ABC):def__init__(self,n_microbatches:int,loss_fn:Optional[Callable[...,torch.Tensor]]=None,args_chunk_spec:Optional[Tuple[TensorChunkSpec,...]]=None,kwargs_chunk_spec:Optional[Dict[str,TensorChunkSpec]]=None,output_merge_spec:Optional[Union[Dict[str,Any],Tuple[Any]]]=None,):# From argumentsself._n_microbatches=n_microbatchesself._loss_fn=loss_fn# Chunking specification for positional inputs. (default: `None`)self._args_chunk_spec=args_chunk_spec# Chunking specification for keyword inputs. (default: `None`)self._kwargs_chunk_spec=kwargs_chunk_specself._output_merge_spec=output_merge_spec""" # args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs. # They are used to convert batch to microbatches in `step(x)`. See # `TensorChunkSpec` for helper methods for creating them. 
"""# Derivedself._has_backward=self._loss_fnisnotNone# Holds the losses for each microbatch.self._internal_losses:List[torch.Tensor]=[]logger.info("Using %s",self.__class__.__name__)def_maybe_compute_loss(self,stage,output,target_mbs,mb_index):ifstage.is_lastandself._has_backward:loss=self._compute_loss(output,target_mbs[mb_index])# type: ignore[index]self._internal_losses.append(loss)def_maybe_get_loss(self,stage,mb_index):valid_index=0<=mb_index<len(self._internal_losses)ifstage.is_lastandself._has_backwardandvalid_index:returnself._internal_losses[mb_index]eliflen(self._internal_losses)!=0andnotvalid_index:raiseRuntimeError(f"Loss for microbatch {mb_index} is not available. "f"Available losses for microbatches: {self._internal_losses}")else:returnNonedef_update_losses(self,stages,losses):""" Update the losses to those in the internal state """# if stages not a list turn into a listifnotisinstance(stages,list):stages=[stages]contains_last_stage=any(stage.is_lastforstageinstages)# Return losses if there is a container passed inifcontains_last_stageandlossesisnotNone:iflen(self._internal_losses)!=self._n_microbatches:raiseRuntimeError(f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}")# Clean external container firstlosses.clear()# Copy internal losses to external containerlosses.extend(self._internal_losses)self._internal_losses.clear()@abstractmethoddef_step_microbatches(self,arg_mbs:Optional[List]=None,kwarg_mbs:Optional[List]=None,target_mbs:Optional[List]=None,losses:Optional[List]=None,):""" Run one iteration of the pipeline schedule with list of microbatches. Will go through all the microbatches according to the schedule implementation. Args: microbatches: list of microbatch args. """raiseNotImplementedError@abstractmethoddefstep(self,*args,target=None,losses:Optional[List]=None,**kwargs):""" Run one iteration of the pipeline schedule with *whole-batch* input. Will chunk the input into microbatches automatically, and go through the microbatches according to the schedule implementation. args: positional arguments to the model (as in non-pipeline case). kwargs: keyword arguments to the model (as in non-pipeline case). target: target for the loss function. losses: a list to store the losses for each microbatch. """raiseNotImplementedErrordef_check_inputs(self,arg_mbs:Optional[List]=None,kwarg_mbs:Optional[List]=None,target_mbs:Optional[List]=None,losses:Optional[List]=None,):""" Pre-process/check inputs """defcheck_type_and_len(mbs,name:str):ifnotisinstance(mbs,list):raiseTypeError(f"{name} must be a list but got a {type(mbs)}")iflen(mbs)!=self._n_microbatches:raiseValueError(f"Expecting {self._n_microbatches}{name} but got {len(mbs)}")ifarg_mbsisnotNone:check_type_and_len(arg_mbs,"arg_mbs")else:arg_mbs=[()]*self._n_microbatchesifkwarg_mbsisnotNone:check_type_and_len(kwarg_mbs,"kwarg_mbs")else:kwarg_mbs=[{}]*self._n_microbatchesiftarget_mbsisnotNone:check_type_and_len(target_mbs,"target_mbs")iflossesisnotNone:ifnotisinstance(losses,list):raiseTypeError(f"losses must be a list but got a {type(losses)}")returnarg_mbs,kwarg_mbsdef_compute_loss(self,output,target):returnself._loss_fn(output,target)# type: ignore[misc]def_split_inputs(self,args:Tuple[Any,...],kwargs:Optional[Dict[str,Any]]=None,):""" Splits a full-batch input into chunks (i.e. 
microbatches) and returns the chunks """ifargsorkwargs:args_split,kwargs_split=split_args_kwargs_into_chunks(args,kwargs,self._n_microbatches,self._args_chunk_spec,self._kwargs_chunk_spec,)returnargs_split,kwargs_splitelse:# Empty inputs (e.g. when called on middle stages)# Return a list of empty tuples/dicts with matching length as chunksreturn[()]*self._n_microbatches,[{}]*self._n_microbatchesdef_merge_outputs(self,output_chunks:List[Any])->Any:""" Merge output chunks back to a batch state. If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim). """returnmerge_chunks(output_chunks,self._output_merge_spec,)def_batch_p2p(p2p_ops:List[dist.P2POp],desc:Optional[str]=None):""" Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top. """iflen(p2p_ops)==0:returnNonedesc_str=f"{desc}, "ifdescelse""logger.debug("batch_p2p %s%s",desc_str,p2p_ops)returndist.batch_isend_irecv(p2p_ops).pop()def_sorted_batch_p2p(p2p_ops:List[dist.P2POp],desc:Optional[str]=None)->Dict[int,dist.Work]:""" Sorts the list of P2P ops by the peer rank, and then calls batch_isend_irecv. Return a dictionary of works by peer rank. This function helps us avoid hangs in case of skip connections. """# Arrange p2p_ops by peer rank:# int is the peer rank;# List is the list of ops towards the peerops_by_peer:Dict[int,List[dist.P2POp]]=defaultdict(list)work_by_peer:Dict[int,dist.Work]={}iflen(p2p_ops)==0:returnwork_by_peer# Classify the ops by peer rankforopinp2p_ops:ops_by_peer[op.peer].append(op)# Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs)forpeer,opsinsorted(ops_by_peer.items()):work_by_peer[peer]=_batch_p2p(ops,desc=desc)returnwork_by_peer
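# Example: a minimal sketch showing how the action helpers above compose.
# `_Action.from_str` parses strings like "2F0" (stage 2, FORWARD, microbatch 0),
# and `_format_pipeline_order` renders a rank -> action-list dict as a timestep
# x rank grid. The toy 2-rank order below is chosen for illustration only; it
# is not produced by any schedule class in this file.
def _example_action_helpers():
    toy_order: Dict[int, List[Optional[_Action]]] = {
        0: [_Action.from_str(s) for s in ("0F0", "0F1", "0B0", "0B1")],
        1: [None, _Action(1, F, 0), _Action(1, F, 1), _Action(1, B, 0), _Action(1, B, 1)],
    }
    # Prints a "Step x Rank" table with one action per cell.
    print(_format_pipeline_order(toy_order))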
class PipelineScheduleSingle(_PipelineSchedule):
    """
    Base class for single-stage schedules.
    Implements the `step` method.
    Derived classes should implement `_step_microbatches`.
    """

    def __init__(
        self,
        stage: _PipelineStageBase,
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        # Init parent
        super().__init__(
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )
        # Self attributes
        self._stage = stage
        self._num_stages = stage.num_stages
        # Set the same has_backward flag for stage object
        self._stage.has_backward = self._has_backward
        self._stage_initialized = False

    def _initialize_stage(self, args, kwargs):
        self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs)
        if self._has_backward:
            self._stage._prepare_backward_infra(self._n_microbatches)
        self._stage_initialized = True
    def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.
        """
        # Clean per iteration
        self._stage.clear_runtime_states()

        # Split inputs into microbatches
        args_split, kwargs_split = self._split_inputs(args, kwargs)

        # Split target into microbatches
        if target is not None:
            targets_split = list(torch.tensor_split(target, self._n_microbatches))
        else:
            targets_split = None

        # Run microbatches
        self._step_microbatches(args_split, kwargs_split, targets_split, losses)

        # Return merged results per original format
        if self._stage.is_last:
            return self._merge_outputs(self._stage.output_chunks)
        else:
            return None
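# Calling convention sketch for `step` (assumes `schedule` wraps the stage owned
# by this rank and that the input/target tensors live on the stage's device):
# only the first stage receives the model inputs, only the last stage receives
# `target`/`losses`, and intermediate stages call `step()` with no arguments.
def _example_single_stage_step(schedule, stage, x, y):
    losses: List[torch.Tensor] = []
    if stage.is_first:
        schedule.step(x)
    elif stage.is_last:
        out = schedule.step(target=y, losses=losses)
        return out, losses
    else:
        schedule.step()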
class _ScheduleForwardOnly(PipelineScheduleSingle):
    """
    The forward-only schedule.
    Will go through all the microbatches and perform only the forward pass
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule
        """
        if target_mbs is not None or losses is not None:
            raise RuntimeError(
                "Forward-only schedule does not support loss computation"
            )

        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        if not self._stage_initialized:
            self._initialize_stage(arg_mbs[0], kwarg_mbs[0])

        # Delay send waits
        fwd_sends_to_wait: List[dist.Work] = []

        # Run microbatches
        for i in range(self._n_microbatches):
            with record_function(f"Forward {i}"):
                ops = self._stage.get_fwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_recv")
                for work in works.values():
                    work.wait()

                self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i])  # type: ignore[index]

                ops = self._stage.get_fwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_send")
                fwd_sends_to_wait.extend(works.values())

            logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i)

        # Wait for all forward sends to finish
        # This should not have performance impact because by the time the first
        # backward arrives all the forward sends should have been finished.
        for work in fwd_sends_to_wait:
            work.wait()
class ScheduleGPipe(PipelineScheduleSingle):
    """
    The GPipe schedule.
    Will go through all the microbatches in a fill-drain manner.
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the GPipe schedule.

        Args:
            microbatches: list of microbatch args.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        if not self._stage_initialized:
            self._initialize_stage(arg_mbs[0], kwarg_mbs[0])

        # Delay send waits
        fwd_sends_to_wait: List[dist.Work] = []

        # Run microbatches
        for i in range(self._n_microbatches):
            with record_function(f"Forward {i}"):
                ops = self._stage.get_fwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_recv")
                for work in works.values():
                    work.wait()

                output = self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i])  # type: ignore[index]

                ops = self._stage.get_fwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_send")
                fwd_sends_to_wait.extend(works.values())

            logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i)

            self._maybe_compute_loss(self._stage, output, target_mbs, i)

        # Wait for all forward sends to finish
        # This should not have performance impact because by the time the first
        # backward arrives all the forward sends should have been finished.
        for work in fwd_sends_to_wait:
            work.wait()

        # No loss function, no need to run backward
        if not self._has_backward:
            return

        # Run backward
        # Delay send waits
        bwd_sends_to_wait: List[dist.Work] = []
        for i in range(self._n_microbatches):
            with record_function(f"Backward {i}"):
                ops = self._stage.get_bwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="bwd_recv")
                for work in works.values():
                    work.wait()

                loss = self._maybe_get_loss(self._stage, i)
                self._stage.backward_one_chunk(
                    i, loss=loss, last_backward=i == self._n_microbatches - 1
                )

                ops = self._stage.get_bwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="bwd_send")
                bwd_sends_to_wait.extend(works.values())

            logger.debug("[%s] Backwarded microbatch %s", self._stage.stage_index, i)

        # Return losses if there is a container passed in
        self._update_losses(self._stage, losses)

        # Wait for all backward sends to finish
        for work in bwd_sends_to_wait:
            work.wait()
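# End-to-end construction sketch for ScheduleGPipe (assumptions: the process
# group is already initialized, `stage_module` is this rank's slice of the
# model, and `PipelineStage` from torch.distributed.pipelining wraps it; the
# exact constructor arguments may differ slightly across releases).
def _example_gpipe(stage_module, rank, world_size, device):
    from torch.distributed.pipelining import PipelineStage

    stage = PipelineStage(
        stage_module, stage_index=rank, num_stages=world_size, device=device
    )
    schedule = ScheduleGPipe(stage, n_microbatches=4, loss_fn=torch.nn.MSELoss())
    return schedule  # drive it per rank as in the `step` sketch above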
class Schedule1F1B(PipelineScheduleSingle):
    """
    The 1F1B schedule.
    Will perform one forward and one backward on the microbatches in steady state.
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the 1F1B schedule.

        Args:
            microbatches: list of microbatch args.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        if not self._stage_initialized:
            self._initialize_stage(arg_mbs[0], kwarg_mbs[0])

        # Last stage has 1 warmup, second-to-last 2 warmups, ...
        # first stage `num_stages` warmups
        warmup_chunks = min(
            self._n_microbatches,
            self._num_stages - self._stage.stage_index,
        )

        # Chunk counters
        fwd_mb_index = 0
        bwd_mb_index = 0

        # Warmup phase
        send_work = None
        fwd_sends = []
        for _ in range(warmup_chunks):
            # Receive activations
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
            if recv_work := _batch_p2p(fwd_recvs, desc="fwd_recv"):
                recv_work.wait()

            # Compute
            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]

            # Clear previous chunk's forward sends (hopefully they have well
            # finished, otherwise, we are heavily communication bound, in which
            # case it doesn't create a lot of benefit to compute next chunk
            # eagerly either)
            if send_work:
                send_work.wait()

            # Send activations
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            if fwd_mb_index != warmup_chunks - 1:
                # Safe to fire
                send_work = _batch_p2p(fwd_sends, desc="fwd_send")
            # otherwise:
            #   The last forward send is left for fuse with first 1B in 1B1F below

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
            fwd_mb_index += 1

        # Now we should have send ops left over, to be fused with first 1B of 1B1F phase below.

        # 1B1F phase
        while True:  # Don't worry, we have a break inside
            # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops
            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)

            # Now, we need to fire the fwd_sends and bwd_recvs together
            if fuse_work := _batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"):
                fuse_work.wait()

            # Backward one chunk
            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
            self._stage.backward_one_chunk(
                bwd_mb_index,
                loss=loss,
                last_backward=bwd_mb_index == self._n_microbatches - 1,
            )

            # Get the bwd send ops, but don't fire, to be fused with the 1F below
            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
            bwd_mb_index += 1

            if fwd_mb_index == self._n_microbatches:
                # We are done with 1B1F, so break with some left-over bwd_sends
                break

            # We prepare 1F of the `1B1F`
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)

            # Fuse it with bwd_sends above
            if fuse_work := _batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"):
                fuse_work.wait()

            # Now do the fwd
            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)

            # Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around)
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            fwd_mb_index += 1

        # Remember we still have some bwd_sends left over after the break? Now it is time to fire it
        send_work = _batch_p2p(bwd_sends, desc="bwd_send")

        # Cooldown
        while bwd_mb_index < self._n_microbatches:
            # prepare bwd recv ops
            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
            if recv_work := _batch_p2p(bwd_recvs, desc="bwd_recv"):
                recv_work.wait()

            # Backward one chunk
            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
            self._stage.backward_one_chunk(
                bwd_mb_index,
                loss=loss,
                last_backward=bwd_mb_index == self._n_microbatches - 1,
            )

            # Clear previous chunk's backward sends (hopefully they have well finished)
            if send_work:
                send_work.wait()

            # Get the bwd send ops, fire it
            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
            send_work = _batch_p2p(bwd_sends, desc="bwd_send")
            bwd_mb_index += 1

        # Wait for the last backward send to finish
        if send_work:
            send_work.wait()

        # Return losses if there is a container passed in
        self._update_losses(self._stage, losses)
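# Sketch of the warmup sizing used above: with `num_stages` pipeline stages the
# last stage does 1 warmup forward, the second-to-last 2, and so on, capped by
# the number of microbatches. The numbers below are purely illustrative.
def _example_1f1b_warmup(num_stages: int = 4, n_microbatches: int = 8):
    for stage_index in range(num_stages):
        warmup_chunks = min(n_microbatches, num_stages - stage_index)
        print(
            f"stage {stage_index}: {warmup_chunks} warmup forwards, "
            f"then {n_microbatches - warmup_chunks} steady-state 1B1F steps, then cooldown"
        )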
def _add_unshard_reshard(
    compute_actions: List[Optional[_Action]],
    max_active_stages: int = 3,
) -> List[_Action]:
    """Given a basic schedule involving only compute actions (F,B,W), add UNSHARD/RESHARD actions for FSDP.

    UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation.
    RESHARD does the opposite, releasing memory (but doing no communication)

    We abandon the "timestep lock" during lowering

    max_active_stages controls how many prefetches we allow. It should be measured in mb and tuneable but in practice
    3 stages is probably the thing we want?
    (to account for having one f and one b active, and something else prefetching?)
    """

    def next_stage_indices(
        count: int, next_actions: List[Optional[_Action]]
    ) -> List[int]:
        """Remove duplicates (same stage, different microbatch), find next 'count' stages that will do compute."""
        seen: Set[int] = set()
        ret: List[int] = []

        for a in next_actions:
            if a is not None and a.stage_index not in seen:
                seen.add(a.stage_index)
                ret.append(a.stage_index)
                if len(ret) == count:
                    break
        return ret

    active_stages: Set[int] = set()
    fsdp_aware_actions: List[_Action] = []

    def _unshard(stage_index: int):
        active_stages.add(stage_index)
        fsdp_aware_actions.append(_Action(stage_index, UNSHARD, None))

    def _reshard(stage_index: int):
        active_stages.remove(stage_index)
        fsdp_aware_actions.append(_Action(stage_index, RESHARD, None))

    for i, action in enumerate(compute_actions):
        if action is None:
            continue

        # We prefetch the next N stages we'll see, dropping existing stages to make room
        next_n = next_stage_indices(max_active_stages, compute_actions[i:])
        # Fetch needs to be ordered correctly, so don't use a set
        fetch = list(filter(lambda s: s not in active_stages, next_n))
        # Unclear what the best policy is for eviction, but we can maintain order so we do
        evict = list(filter(lambda s: s not in next_n, active_stages))

        # logger.debug(
        #     "_add_unshard_reshard Step %d active: %s fetch %s, evict %s",
        #     i,
        #     active_stages,
        #     fetch,
        #     evict,
        # )

        for stage in evict:
            _reshard(stage)
        for stage in fetch:
            _unshard(stage)
        fsdp_aware_actions.append(action)

    return fsdp_aware_actions


def _merge_bw(
    compute_actions: List[Optional[_Action]],
) -> List[_Action]:
    """Given a basic schedule involving only compute actions (F,I,W), merge adjacent I and W ops into B ops.
    (note: I = BACKWARD_INPUT, W = BACKWARD_WEIGHT, B = FULL_BACKWARD)

    B refers to running the whole backward (not separating grad_input and grad_weight), which can be more efficient
    in some cases.
    """
    merged_actions = []
    while compute_actions:
        action = compute_actions.pop(0)
        if action is None:
            continue

        while len(compute_actions) and (next_action := compute_actions[0]) is None:
            # remove any None actions between 'action' and 'next_action'
            compute_actions.pop(0)

        if (
            action.computation_type == BACKWARD_INPUT
            and next_action is not None
            and next_action.computation_type == BACKWARD_WEIGHT
            and action.stage_index == next_action.stage_index
            and action.microbatch_index == next_action.microbatch_index
        ):
            merged_actions.append(
                _Action(action.stage_index, FULL_BACKWARD, action.microbatch_index)
            )
            compute_actions.pop(0)
        else:
            merged_actions.append(action)
    return merged_actions


def _add_send_recv(
    compute_actions: Dict[int, List[_Action]],
    stage_to_rank: Callable[[int], int],
    num_stages: int,
) -> Dict[int, List[_Action]]:
    comm_actions: Dict[int, List[_Action]] = {rank: [] for rank in compute_actions}
    prev_actions: Dict[int, Set[_Action]] = {rank: set() for rank in compute_actions}

    def _has_comms(action: _Action) -> bool:
        if action.computation_type == F:
            return action.stage_index != num_stages - 1 and stage_to_rank(
                action.stage_index + 1
            ) != stage_to_rank(action.stage_index)
        elif action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD):
            return action.stage_index != 0 and stage_to_rank(
                action.stage_index - 1
            ) != stage_to_rank(action.stage_index)
        return False

    def _get_comms(action: _Action) -> Tuple[_Action, _Action]:
        assert _has_comms(action), f"{action} is not a valid comm action"
        stage_idx = action.stage_index
        ctype = action.computation_type
        mb_idx = action.microbatch_index
        send = _Action(stage_idx, SEND_F if ctype == F else SEND_B, mb_idx)
        recv_stage_idx = stage_idx + 1 if ctype == F else stage_idx - 1
        recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx)
        return send, recv

    def _ready_to_schedule(
        action: Optional[_Action], prev_actions: Set[_Action]
    ) -> bool:
        """We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place.
        This helps ensure a sane (non-hanging) ordering of sends and recvs.
        But it also means we might not be able to schedule our next compute action yet.
        """
        if action is None:
            return True
        elif action.computation_type == F and not action.stage_index == 0:
            if (
                _Action(action.stage_index, RECV_F, action.microbatch_index)
                in prev_actions
            ):
                return True
            elif (
                _Action(action.stage_index - 1, F, action.microbatch_index)
                in prev_actions
            ):
                return True
            return False
        elif (
            action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD)
            and not action.stage_index == num_stages - 1
        ):
            if (
                _Action(action.stage_index, RECV_B, action.microbatch_index)
                in prev_actions
            ):
                return True
            elif (
                _Action(action.stage_index + 1, BACKWARD_INPUT, action.microbatch_index)
                in prev_actions
            ):
                return True
            elif (
                _Action(action.stage_index + 1, FULL_BACKWARD, action.microbatch_index)
                in prev_actions
            ):
                return True
            return False
        else:
            return True

    while compute_actions:
        progress = False
        # go in order of ranks even if dict keys aren't ordered
        for rank in sorted(compute_actions):
            assert (
                len(compute_actions[rank]) > 0
            ), f"{rank=}, {len(compute_actions[rank])=}"
            action = compute_actions[rank][0]

            if not _ready_to_schedule(action, prev_actions[rank]):
                continue

            if action is not None:
                comm_actions[rank].append(action)
                prev_actions[rank].add(action)
                if _has_comms(action):
                    send, recv = _get_comms(action)
                    # TODO we can avoid send/recv if the 2 stages are on the same rank.
                    # should we avoid that in the runtime or here?
                    comm_actions[rank].append(send)
                    prev_actions[rank].add(send)
                    comm_actions[stage_to_rank(recv.stage_index)].append(recv)
                    prev_actions[stage_to_rank(recv.stage_index)].add(recv)

            compute_actions[rank].pop(0)
            if len(compute_actions[rank]) == 0:
                del compute_actions[rank]
            progress = True
        assert progress, "Malformed compute schedule, can't schedule sends/recvs"
    return comm_actions


def _validate_schedule(
    actions: Dict[int, List[Optional[_Action]]],
    pp_group_size: int,
    num_stages: int,
    num_microbatches: int,
):
    assert (
        len(actions) == pp_group_size
    ), f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}"
    for rank in range(pp_group_size):
        assert rank in actions, f"Schedule is missing actions for rank {rank}"

    # We will count all the actions per stage and ensure they happen in a valid order
    # (e.g. F before (B, I) before W for a given microbatch)
    stage_actions: Dict[int, Dict[_ComputationType, Set]] = {
        stage_id: {
            F: set(),
            B: set(),
            I: set(),
            W: set(),
        }
        for stage_id in range(num_stages)
    }
    for rank in actions:
        for action in actions[rank]:
            if action is None:
                continue
            assert isinstance(
                action, _Action
            ), f"Got an invalid action: {action}, expected instance of _Action"
            s_id = action.stage_index
            ctype = action.computation_type
            mb_id = action.microbatch_index
            if ctype == F:
                stage_actions[s_id][F].add(mb_id)
            elif ctype == B:
                assert (
                    mb_id in stage_actions[s_id][F]
                ), f"Running Full Backward for stage {s_id}, microbatch {mb_id} without first running Forward"
                stage_actions[s_id][B].add(mb_id)
            elif ctype == I:
                assert (
                    mb_id in stage_actions[s_id][F]
                ), f"Running Backward Input for stage {s_id}, microbatch {mb_id} without first running Forward"
                stage_actions[s_id][I].add(mb_id)
            elif ctype == W:
                assert (
                    mb_id in stage_actions[s_id][I]
                ), f"Running Backward Weight for stage {s_id}, microbatch {mb_id} without first running Backward Input"
                stage_actions[s_id][W].add(mb_id)

    for s_id in stage_actions:
        f_mb = len(stage_actions[s_id][F])
        b_mb = len(stage_actions[s_id][B])
        i_mb = len(stage_actions[s_id][I])
        w_mb = len(stage_actions[s_id][W])

        assert (
            f_mb == num_microbatches
        ), f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}"

        assert (
            b_mb + (i_mb + w_mb) // 2 == num_microbatches
        ), f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, \
        but got B={b_mb}, I={i_mb}, W={w_mb}"
class PipelineScheduleMulti(_PipelineSchedule):
    """
    Base class for multi-stage schedules.
    Implements the `step` method.
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
        stage_index_to_group_rank: Optional[Dict[int, int]] = None,
        use_full_backward: Optional[bool] = None,
    ):
        # Init parent
        super().__init__(
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )
        # Self attributes
        self._stages = stages
        self._num_stages = stages[0].num_stages
        self.pp_group_size = stages[0].group_size
        self.rank = stages[0].group_rank
        # Set the pipeline stage states
        if stage_index_to_group_rank is not None:
            for stage in self._stages:
                stage.stage_index_to_group_rank = stage_index_to_group_rank
        self.stage_index_to_group_rank = stages[0].stage_index_to_group_rank

        # Set the same has_backward flag for stage object
        for stage in self._stages:
            stage.has_backward = self._has_backward
        self._stages_initialized = False

        # avoid putting a reference to 'self' inside the lambda, it creates a ref cycle
        has_loss: bool = self._loss_fn is not None
        self._should_compute_loss = lambda stage: stage.is_last and has_loss

        # This will be set during init of derived schedules
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}

        if use_full_backward is not None:
            logger.warning(
                "Deprecation warning: 'use_full_backward' is no longer supported. "
                "Simply stop passing it, and everything should still work fine."
            )

    def _initialize_stages(self, args: Tuple[Any, ...], kwargs):
        # may be 'none' value (if this stage sends its output shapes to the next stage via P2P)
        # or real value (if this stage and next stage are on the same device)
        next_stage_args: Tuple[Any, ...] = tuple()
        for stage in self._stages:
            if stage.is_first:
                next_stage_args = stage._prepare_forward_infra(
                    self._n_microbatches, args, kwargs
                )
            else:
                next_stage_args = stage._prepare_forward_infra(
                    self._n_microbatches, next_stage_args, kwargs
                )

            if self._has_backward:
                stage._prepare_backward_infra(self._n_microbatches)
        self._stages_initialized = True

    def _dump_csv(self, filename):
        """Dump a CSV representation of the schedule into a file with the provided filename."""
        with open(filename, "w", newline="") as csvfile:
            writer = csv.writer(csvfile)
            for rank in self.pipeline_order:
                writer.writerow(self.pipeline_order[rank])

    def _load_csv(self, filename, format="compute_only"):
        """Load a CSV representation of the schedule from a file with the provided filename.
        This API will most likely get renamed/refactored so is marked as internal for now.

        format must be "compute_only" for PipelineScheduleMulti
        """
        assert format == "compute_only"
        with open(filename, newline="") as csvfile:
            reader = csv.reader(csvfile)
            for rank, row in enumerate(reader):
                self.pipeline_order[rank] = [_Action.from_str(s) for s in row]
        _validate_schedule(
            self.pipeline_order,
            self.pp_group_size,
            self._num_stages,
            self._n_microbatches,
        )
    def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.
        """
        # Clean per iteration
        for stage in self._stages:
            stage.clear_runtime_states()

        # Split inputs into microbatches
        args_split, kwargs_split = self._split_inputs(args, kwargs)

        # Split target into microbatches
        if target is not None:
            targets_split = list(torch.tensor_split(target, self._n_microbatches))
        else:
            targets_split = None

        # Run microbatches
        self._step_microbatches(args_split, kwargs_split, targets_split, losses)

        # Return merged results per original format
        for stage in self._stages:
            if stage.is_last:
                return self._merge_outputs(stage.output_chunks)
        # Does not contain the last stage
        return None
def_step_microbatches(self,arg_mbs:Optional[List]=None,kwarg_mbs:Optional[List]=None,target_mbs:Optional[List]=None,losses:Optional[List]=None,):""" Operate on the microbatches for looped schedules (multiple stages on each rank). TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does not support models with skip connections. """arg_mbs,kwarg_mbs=self._check_inputs(arg_mbs,kwarg_mbs,target_mbs,losses)ifnotself._stages_initialized:self._initialize_stages(arg_mbs[0],kwarg_mbs[0])# Based on the plan in Step 1 created in __init__:# 2. Perform communication based on the pipeline_orderstage_index_to_stage:Dict[int,_PipelineStageBase]={stage.stage_index:stageforstageinself._stages}# determine prev_rank and next_rank based on which ranks are next to# the stages in the pipeline_orderall_prev_ranks:Set[int]=set()all_next_ranks:Set[int]=set()forstage_indexinstage_index_to_stage.keys():# TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections)ifstage_index>0:all_prev_ranks.add(self.stage_index_to_group_rank[stage_index-1])ifstage_index<self._num_stages-1:all_next_ranks.add(self.stage_index_to_group_rank[stage_index+1])# count either full_backward or backward_weight together, to determine when to sync DP gradsbackward_counter:Counter[int]=Counter()fortime_step,actioninenumerate(self.pipeline_order[self.rank]):try:ops:List[dist.P2POp]=[]ifactionisnotNone:computation_type=action.computation_typemb_index=action.microbatch_indexstage_index=action.stage_indexassert(mb_indexisnotNone),"All currently supported action types require valid microbatch_index"ifcomputation_type==_ComputationType.FORWARD:# perform forward computationstage=stage_index_to_stage[stage_index]output=stage.forward_one_chunk(mb_index,arg_mbs[mb_index],kwarg_mbs[mb_index])self._maybe_compute_loss(stage,output,target_mbs,mb_index)ops.extend(stage.get_fwd_send_ops(mb_index))elifcomputation_type==_ComputationType.FULL_BACKWARD:# perform backward computationstage=stage_index_to_stage[stage_index]loss=self._maybe_get_loss(stage,mb_index)backward_counter[stage_index]+=1stage.backward_one_chunk(mb_index,loss=loss,full_backward=True,last_backward=backward_counter[stage_index]==self._n_microbatches,)ops.extend(stage.get_bwd_send_ops(mb_index))elifcomputation_type==_ComputationType.BACKWARD_INPUT:# perform backward computationstage=stage_index_to_stage[stage_index]loss=self._maybe_get_loss(stage,mb_index)stage.backward_one_chunk(mb_index,loss=loss,full_backward=False,last_backward=False,)ops.extend(stage.get_bwd_send_ops(mb_index))elifcomputation_type==_ComputationType.BACKWARD_WEIGHT:# perform weight updatestage=stage_index_to_stage[stage_index]backward_counter[stage_index]+=1stage.backward_weight_one_chunk(mb_index,last_backward=backward_counter[stage_index]==self._n_microbatches,)else:raiseValueError(f"Unknown computation type {computation_type}")# Look at the neighboring ranks for this current timestep and determine whether# this current rank needs to do any recv communicationforprev_rankinall_prev_ranks:prev_rank_ops=self.pipeline_order[prev_rank]prev_rank_action=Noneiftime_step<len(prev_rank_ops):prev_rank_action=prev_rank_ops[time_step]ifprev_rank_actionisnotNone:computation_type=prev_rank_action.computation_typemb_index=prev_rank_action.microbatch_indexstage_index=prev_rank_action.stage_indexassert(mb_indexisnotNone),"All currently supported action types require valid microbatch_index"# Only handle sends for the forward from a previous rankifcomputation_type==_ComputationType.FORWARD:# 
If not the last stage, then receive fwd activationsifstage_index+1instage_index_to_stage:# TODO: We are assuming that stage will always receive from stage-1# however that is not necessarily true of get_fwd_recv_opsstage=stage_index_to_stage[stage_index+1]ops.extend(stage.get_fwd_recv_ops(mb_index))elifcomputation_typein(FULL_BACKWARD,BACKWARD_INPUT,BACKWARD_WEIGHT,):# Previous rank doing backward has no influence for the current rank forward recvpasselse:raiseValueError(f"Unknown computation type {computation_type}")fornext_rankinall_next_ranks:next_rank_ops=self.pipeline_order[next_rank]next_rank_action=Noneiftime_step<len(next_rank_ops):next_rank_action=next_rank_ops[time_step]ifnext_rank_actionisnotNone:computation_type=next_rank_action.computation_typemb_index=next_rank_action.microbatch_indexstage_index=next_rank_action.stage_indexassert(mb_indexisnotNone),"All currently supported action types require valid microbatch_index"# Only handle receives for the backwards from a next rankifcomputation_typein(FORWARD,BACKWARD_WEIGHT):# Next rank doing forward or weight update has no influence for the current rank backward recvpasselifcomputation_typein(BACKWARD_INPUT,FULL_BACKWARD):# If not the first stage, then receive bwd gradientsifstage_index-1instage_index_to_stage:# TODO: We are assuming that stage will always receive from stage+1# however that is not necessarily true of get_bwd_recv_opsstage=stage_index_to_stage[stage_index-1]ops.extend(stage.get_bwd_recv_ops(mb_index))else:raiseValueError(f"Unknown computation type {computation_type}")# do the communicationifops:_batch_p2p(ops).wait()exceptExceptionase:logger.error("[Rank %s] pipeline schedule %s caught the following exception \ at time_step %s when running action %s",self.rank,self.__class__.__name__,time_step,action,)logger.error("%s",_format_pipeline_order(self.pipeline_order,error_step_number=time_step),)raisee# Return losses if there is a container passed inself._update_losses(self._stages,losses)
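# Sketch of the "compute_only" CSV layout consumed by `_load_csv` above: one row
# per rank, one cell per timestep, each cell an action string such as "0F0"
# (stage 0, forward, microbatch 0) or "" for an idle step. The toy 2-rank,
# 2-microbatch order below is for illustration only; parsing mirrors what
# `_load_csv` does internally, and `_validate_schedule` checks the result.
def _example_compute_only_csv(path="toy_schedule.csv"):
    rows = [
        ["0F0", "0F1", "", "", "0B0", "0B1"],
        ["", "1F0", "1F1", "1B0", "1B1", ""],
    ]
    with open(path, "w", newline="") as f:
        csv.writer(f).writerows(rows)
    order = {r: [_Action.from_str(s) for s in row] for r, row in enumerate(rows)}
    _validate_schedule(order, pp_group_size=2, num_stages=2, num_microbatches=2)
    return order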
class_PipelineScheduleRuntime(PipelineScheduleMulti):""" Provides a simple runtime that requires a 'schedule IR' including specified communication operations. Can be instantiated directly by creating _PipelineScheduleRuntime and calling load_csv, or can be subclassed and the subclass can be responsible for creating a schedule IR. """def_load_actions(self,actions:Dict[int,List[Optional[_Action]]],format:str="compute_only",):""" Given an in-memory representation for a simple compute-only schedule, lower it to a complex schedule including communication actions. Stores the schedule in self, and must be called before running step_mo() """assert(self.stage_index_to_group_rankisnotNone),"stage_index_to_group_rank is required for PipelineScheduleRuntime"self.pipeline_order_with_comms:Dict[int,List[_Action]]={}ifformat=="compute_comms":forrankinactions:self.pipeline_order_with_comms[rank]=[]foractioninactions[rank]:assertactionisnotNoneself.pipeline_order_with_comms[rank].append(action)# TODO what level of validation should we offer for compute+comms schedule?elifformat=="compute_only":# Perform schedule loweringforrankinactions:self.pipeline_order_with_comms[rank]=_add_unshard_reshard(actions[rank])self.pipeline_order_with_comms=_add_send_recv(self.pipeline_order_with_comms,stage_to_rank=lambdas:self.stage_index_to_group_rank[s],num_stages=self._num_stages,)else:raiseNotImplementedError(f"{format=} is not implemented")def_load_csv(self,filename:str,format:str="compute_only"):"""Loads a csv in simple format and then lowers it to include comunication actions format must be either "compute_only" or "compute_comms". If compute_only, the lowering passes will automatically be run to generate a compute_comms schedule. """ifformat=="compute_only":# this will populate self.pipeline_ordersuper()._load_csv(filename)# this will populate self.pipeline_order_with_commsself._load_actions(self.pipeline_order)elifformat=="compute_comms":actions={}withopen(filename,newline="")ascsvfile:reader=csv.reader(csvfile)forrank,rowinenumerate(reader):actions[rank]=[_Action.from_str(s)forsinrow]self._load_actions(actions,format=format)else:raiseNotImplementedError(f"{format=} is not implemented")def_dump_csv(self,filename:str):"""Dump a CSV representation of the compute + comms schedule into a file with the provided filename."""# TODO should there be an option to dump the compute_only schedule from PipelineScheduleRuntime? It's possible# that it does not exist if it was created from a compute_comms schedule.assert(self.pipeline_order_with_commsisnotNone),"Must initialize compute_comms schedule before dump_csv"withopen(filename,"w",newline="")ascsvfile:writer=csv.writer(csvfile)forrankinself.pipeline_order_with_comms:writer.writerow(self.pipeline_order_with_comms[rank])def_simulate(self):return_simulate_comms_compute(self.pipeline_order_with_comms,lambdas:self.stage_index_to_group_rank[s],self._num_stages,)def_step_microbatches(self,arg_mbs:Optional[List]=None,kwarg_mbs:Optional[List]=None,target_mbs:Optional[List]=None,losses:Optional[List]=None,):""" Operate on the microbatches for looped schedules (multiple stages on each rank). TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does not support models with skip connections. """arg_mbs,kwarg_mbs=self._check_inputs(arg_mbs,kwarg_mbs,target_mbs,losses)ifnotself._stages_initialized:self._initialize_stages(arg_mbs[0],kwarg_mbs[0])# Based on the plan in Step 1 created in __init__:# 2. 
Perform communication based on the pipeline_orderstage_index_to_stage:Dict[int,_PipelineStageBase]={stage.stage_index:stageforstageinself._stages}assert(self.pipeline_order_with_commsisnotNone),"Must call _load_actions() before calling _step_microbatches()"# recv ops indexed by (stage_idx, mb_idx) need to be waited on before usebwd_recv_ops:Dict[Tuple[int,int],Work]={}fwd_recv_ops:Dict[Tuple[int,int],Work]={}# send ops should be waited on before step() exists, mainly for hygeinesend_ops:List[Work]=[]# we track which stages are 'active' when used with FSDP, and wait on unshard ops before computing on stagesunshard_ops:Dict[int,UnshardHandle]={}unsharded_stages=set()def_assert_unsharded(stage_idx:int):"""If an unshard is active for `stage_idx`, wait() it and mark `stage_idx` unshared."""ifstage_idxinunshard_ops:unshard_ops[stage_idx].wait()delunshard_ops[stage_idx]unsharded_stages.add(stage_idx)assert(stage_idxinunsharded_stages),f"Attempted to compute on sharded {stage_idx=}"# count either full_backward or backward_weight together, to determine when to sync DP gradsbackward_counter:Counter[int]=Counter()fortime_step,actioninenumerate(self.pipeline_order_with_comms[self.rank]):try:comp_type=action.computation_typemb_index:int=(action.microbatch_indexifaction.microbatch_indexisnotNoneelse-1)assertmb_index>=0orcomp_typein(UNSHARD,RESHARD,),f"{action=} missing mb_index"stage_idx=action.stage_indexstage=stage_index_to_stage[stage_idx]stage_uses_fsdp=isinstance(stage.submod,FSDPModule)# see [Note: V-schedule special case]is_next_stage_on_this_rank=stage_idx+1instage_index_to_stageis_prev_stage_on_this_rank=stage_idx-1instage_index_to_stagelogger.debug("_PipelineScheduleRuntime running time_step %d, action %s",time_step,action,)# TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections,# since we do not want to batch up ops between more than a pair of ranks. _sorted_batch_p2p would be# safe to use instead.# However, I was wondering if I should avoid calling batched operators at all in the case that there is# only one operator per batch. 
I could iterate through the 'fwd_send_ops' one by one and run them.ifcomp_type==SEND_F:send_ops.append(_batch_p2p(stage.get_fwd_send_ops(mb_index)))elifcomp_type==SEND_B:send_ops.append(_batch_p2p(stage.get_bwd_send_ops(mb_index)))elifcomp_type==RECV_F:assert(stage_idx,mb_index,)notinfwd_recv_ops,"Recv twice for {stage_idx=} {mb_index=} without executing forward"fwd_recv_ops[(stage_idx,mb_index)]=_batch_p2p(stage.get_fwd_recv_ops(mb_index))elifcomp_type==RECV_B:assert(stage_idx,mb_index,)notinbwd_recv_ops,"Recv twice for {stage_idx=} {mb_index=} without executing backward"bwd_recv_ops[(stage_idx,mb_index)]=_batch_p2p(stage.get_bwd_recv_ops(mb_index))elifcomp_type==UNSHARD:ifstage_uses_fsdp:assert(stage_idxnotinunsharded_stagesandstage_idxnotinunshard_ops),f"Unsharding the same {stage_idx=} twice"unshard_ops[stage_idx]=stage.submod.unshard(async_op=True)# type: ignore[operator]elifcomp_type==RESHARD:ifstage_uses_fsdp:assert(stage_idxinunsharded_stages),f"Resharding {stage_idx=} without unsharding"assert(stage_idxnotinunshard_ops),f"Resharding {stage_idx=} before finishing unshard"stage.submod.reshard()# type: ignore[operator]elifcomp_type==FORWARD:ifstage_uses_fsdp:_assert_unsharded(stage_idx)if(notstage.is_first# no recv op expected for V-schedule special case (see [Note: V-schedule special case])andnotis_prev_stage_on_this_rank):assert(stage_idx,mb_index,)infwd_recv_ops,f"Computing {action=} before receiving input"fwd_recv_ops.pop((stage_idx,mb_index)).wait()output=stage.forward_one_chunk(mb_index,arg_mbs[mb_index],kwarg_mbs[mb_index])self._maybe_compute_loss(stage,output,target_mbs,mb_index)# SEND/RECV op are avoided for special case with 2 adjacent stages on same rank# see [Note: V-schedule special case]ifis_next_stage_on_this_rank:stage_index_to_stage[stage_idx+1].set_local_fwd_input(output,mb_index)elifcomp_type==FULL_BACKWARD:ifstage_uses_fsdp:_assert_unsharded(stage_idx)if(notstage.is_last# no recv op expected for V-schedule special case (see [Note: V-schedule special case])andnotis_next_stage_on_this_rank):assert(stage_idx,mb_index,)inbwd_recv_ops,(f"Attempted to run compute {action=} before receiving input")bwd_recv_ops.pop((stage_idx,mb_index)).wait()loss=self._maybe_get_loss(stage,mb_index)backward_counter[stage_idx]+=1stage.backward_one_chunk(mb_index,loss=loss,full_backward=True,last_backward=backward_counter[stage_idx]==self._n_microbatches,)# SEND/RECV op are avoided for special case with 2 adjacent stages on same rank# see [Note: V-schedule special case]ifis_prev_stage_on_this_rank:stage_index_to_stage[stage_idx-1].set_local_bwd_input(stage.get_local_bwd_output(mb_index),mb_index)elifcomp_type==BACKWARD_INPUT:ifstage_uses_fsdp:_assert_unsharded(stage_idx)ifnotstage.is_lastandnotis_next_stage_on_this_rank:assert(stage_idx,mb_index,)inbwd_recv_ops,(f"Attempted to run compute {action=} before receiving input")bwd_recv_ops.pop((stage_idx,mb_index)).wait()loss=self._maybe_get_loss(stage,mb_index)stage.backward_one_chunk(mb_index,loss=loss,full_backward=False,last_backward=False,)# SEND/RECV op are avoided for special case with 2 adjacent stages on same rank# see [Note: V-schedule special case]ifis_prev_stage_on_this_rank:stage_index_to_stage[stage_idx-1].set_local_bwd_input(stage.get_local_bwd_output(mb_index),mb_index)elifcomp_type==BACKWARD_WEIGHT:ifstage_uses_fsdp:_assert_unsharded(stage_idx)backward_counter[stage_idx]+=1stage.backward_weight_one_chunk(mb_index,last_backward=backward_counter[stage_idx]==self._n_microbatches,)else:raiseValueError(f"{action=} is unknown or 
unsupported")exceptExceptionase:logger.error("_PipelineScheduleRuntime caught exception at step %s when running action %s. Full Schedule:",time_step,action,)# TODO(whc) what is the best practice for printing a multiline log?# logger will split it into multiple log lines, but this makes it hard to read (too wide)print(_format_pipeline_order(self.pipeline_order_with_comms,error_step_number=time_step))# type: ignore[arg-type]raisee# Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for themwhilelen(send_ops):send_ops.pop().wait()assertlen(unshard_ops)==0,"Unused unshard operations"# Return losses if there is a container passed inself._update_losses(self._stages,losses)
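# A small sketch of the lowering performed by `_load_actions` above: a
# compute-only order is first cleaned up (e.g. `_merge_bw` fuses adjacent I/W
# pairs into B), then `_add_send_recv` inserts SEND_/RECV_ actions wherever
# adjacent stages live on different ranks. Toy 2-stage, 1-microbatch example:
def _example_lowering():
    # I0 followed by W0 on the same stage/microbatch collapses into a single B0
    assert _merge_bw([_Action(0, I, 0), _Action(0, W, 0)]) == [_Action(0, B, 0)]

    compute = {
        0: [_Action(0, F, 0), _Action(0, B, 0)],
        1: [_Action(1, F, 0), _Action(1, B, 0)],
    }
    with_comms = _add_send_recv(compute, stage_to_rank=lambda s: s, num_stages=2)
    # Roughly: rank 0 -> [0F0, 0SEND_F0, 0RECV_B0, 0B0]
    #          rank 1 -> [1RECV_F0, 1F0, 1B0, 1SEND_B0]
    return with_comms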
class ScheduleLoopedBFS(PipelineScheduleMulti):
    """
    Breadth-First Pipeline Parallelism.
    See https://arxiv.org/abs/2211.05953 for details.
    Similar to Interleaved 1F1B, Looped BFS supports multiple stages per rank.
    What is different is that when microbatches are ready for multiple local
    stages, Looped BFS will prioritize the earlier stage, running all available
    microbatches at once.
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            output_merge_spec=output_merge_spec,
        )

        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
        # ========================================================================
        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

    def _calculate_single_rank_operations(self, rank):
        n_local_stages = len(self._stages)
        stage_indices = range(
            rank, self.pp_group_size * n_local_stages, self.pp_group_size
        )

        # Store the list of operations used for that rank
        # Pre-padding, rank starts with no-ops based on the warmup.
        rank_ops: List[Optional[_Action]] = [None for _ in range(rank)]

        for stage_index in stage_indices:
            rank_ops.extend(
                _Action(stage_index, _ComputationType.FORWARD, mb_index)
                for mb_index in range(self._n_microbatches)
            )

        # wait for the first backward to trickle up
        # which is 2 for every hop away
        post_warmup_ops = 2 * (self.pp_group_size - 1 - rank)
        rank_ops.extend([None] * post_warmup_ops)

        for stage_index in reversed(stage_indices):
            rank_ops.extend(
                _Action(stage_index, _ComputationType.FULL_BACKWARD, mb_index)
                for mb_index in reversed(range(self._n_microbatches))
            )
        return rank_ops
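# Construction sketch for a looped (multi-stage-per-rank) schedule: each rank
# wraps one stage object per local model chunk and passes the list in stage
# order. Stage indices are assigned round-robin across ranks, so with 2 ranks
# and 2 local chunks each, rank 0 owns stages 0 and 2. `PipelineStage` from
# torch.distributed.pipelining and its exact constructor arguments are
# assumptions here; every rank computes the same pipeline_order and can print it.
def _example_looped_bfs(stage_modules, rank, world_size, device, n_microbatches=8):
    from torch.distributed.pipelining import PipelineStage

    num_stages = world_size * len(stage_modules)
    stages = [
        PipelineStage(
            mod, stage_index=rank + i * world_size, num_stages=num_stages, device=device
        )
        for i, mod in enumerate(stage_modules)
    ]
    sched = ScheduleLoopedBFS(stages, n_microbatches=n_microbatches)
    print(_format_pipeline_order(sched.pipeline_order))
    return sched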
def_get_1f1b_rank_ops(n_local_stages,pp_group_size,warmup_ops,fwd_bwd_ops,cooldown_ops,rank,forward_stage_index,backward_stage_index,num_1f1b_microbatches=0,enable_zero_bubble=False,):# All stages start with handling microbatch 0fwd_stage_mb_index:Dict[int,int]=defaultdict(int)bwd_stage_mb_index:Dict[int,int]=defaultdict(int)weight_stage_mb_index:Dict[int,int]=defaultdict(int)# Store the list of operations used for that rank# Pre-padding, rank starts with no-ops based on the warmup.rank_ops:List[Optional[_Action]]=[Nonefor_inrange(rank)]# These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup# when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.# Formula:# pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward# post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding)# earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)]# warmup_ops = calculated abovepost_warmup_ops=(n_local_stages*pp_group_size+2*(pp_group_size-1-rank))-(warmup_ops+rank)ifenable_zero_bubble:post_warmup_ops=pp_group_size-rank-1total_ops=warmup_ops+fwd_bwd_ops+cooldown_opsbackward_op_ids=[]weight_op_count=0FULL_BACKWARD_OR_BACKWARD_INPUT=(BACKWARD_INPUTifenable_zero_bubbleelseFULL_BACKWARD)foropinrange(total_ops):# Warmup phaseifop<warmup_ops:fwd_stage_index=forward_stage_index(op)# This will assign the current microbatch index and update it as wellfwd_stage_mb_index[fwd_stage_index]=(mb_index:=fwd_stage_mb_index[fwd_stage_index])+1rank_ops.append(_Action(fwd_stage_index,_ComputationType.FORWARD,mb_index))ifop==warmup_ops-1:# This is the last step in the warmup phase, so we need to wait for the backward to trickle back uprank_ops.extend([None]*post_warmup_ops)# 1F1B Phase (forward and backward)elifwarmup_ops<=op<warmup_ops+fwd_bwd_ops:fwd_stage_index=forward_stage_index(op)fwd_stage_mb_index[fwd_stage_index]=(fwd_mb_index:=fwd_stage_mb_index[fwd_stage_index])+1rank_ops.append(_Action(fwd_stage_index,_ComputationType.FORWARD,fwd_mb_index))bwd_stage_index=backward_stage_index(op)bwd_stage_mb_index[bwd_stage_index]=(bwd_mb_index:=bwd_stage_mb_index[bwd_stage_index])+1rank_ops.append(_Action(bwd_stage_index,FULL_BACKWARD_OR_BACKWARD_INPUT,bwd_mb_index))backward_op_ids.append(op)ifenable_zero_bubbleandop-warmup_ops>=num_1f1b_microbatches:weight_stage_index=backward_stage_index(backward_op_ids[weight_op_count])weight_stage_mb_index[weight_stage_index]=(weight_mb_index:=weight_stage_mb_index[weight_stage_index])+1rank_ops.append(_Action(weight_stage_index,_ComputationType.BACKWARD_WEIGHT,weight_mb_index,))weight_op_count+=1# Cooldown phaseelse:# During cooldown phase, we need steps to align with 1f1b happening in other ranks# TODO: we don't need to always append, after all 1f1b are finished we can stop appending 
Noneifnotenable_zero_bubble:rank_ops.append(None)bwd_stage_index=backward_stage_index(op)bwd_stage_mb_index[bwd_stage_index]=(bwd_mb_index:=bwd_stage_mb_index[bwd_stage_index])+1rank_ops.append(_Action(bwd_stage_index,FULL_BACKWARD_OR_BACKWARD_INPUT,bwd_mb_index))backward_op_ids.append(op)ifenable_zero_bubbleandop-warmup_ops>=num_1f1b_microbatches:weight_stage_index=backward_stage_index(backward_op_ids[weight_op_count])weight_stage_mb_index[weight_stage_index]=(weight_mb_index:=weight_stage_mb_index[weight_stage_index])+1rank_ops.append(_Action(weight_stage_index,_ComputationType.BACKWARD_WEIGHT,weight_mb_index,))weight_op_count+=1whileenable_zero_bubbleandweight_op_count<len(backward_op_ids):weight_stage_index=backward_stage_index(backward_op_ids[weight_op_count])weight_stage_mb_index[weight_stage_index]=(weight_mb_index:=weight_stage_mb_index[weight_stage_index])+1rank_ops.append(_Action(weight_stage_index,_ComputationType.BACKWARD_WEIGHT,weight_mb_index))weight_op_count+=1returnrank_ops
class ScheduleInterleaved1F1B(PipelineScheduleMulti):
    """
    The Interleaved 1F1B schedule.
    See https://arxiv.org/pdf/2104.04473 for details.
    Will perform one forward and one backward on the microbatches in steady
    state and supports multiple stages per rank. When microbatches are ready for
    multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch
    (also called "depth first").

    This schedule is mostly similar to the original paper.
    It differs by relaxing the requirement of num_microbatch % pp_size == 0.
    Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and
    it works as long as n_microbatches % num_rounds is 0. As a few examples, suppose

    1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0.
    2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0.
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        self.pp_group_size = stages[0].group_size
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )
        self.n_local_stages = len(stages)
        self.rank = stages[0].group_rank
        self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
        self.microbatches_per_round = n_microbatches // self.number_of_rounds
        if n_microbatches % self.number_of_rounds != 0:
            raise ValueError(
                "Interleaved 1F1B requires the number of microbatches to be a "
                f"multiple of the number of rounds ({self.number_of_rounds}), "
                f"but got {n_microbatches}."
            )
        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}

        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

    def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
        def get_rank_warmup_ops(rank):
            # Warms up operations for last stage
            warmups_ops_last_stage = (
                self.n_local_stages - 1
            ) * self.microbatches_per_round
            # Increment warmup operations by 2 for each hop away from the last stage
            multiply_factor = 2
            warmup_ops = warmups_ops_last_stage + multiply_factor * (
                (self.pp_group_size - 1) - rank
            )

            # We cannot have more warmup operations than there are number of microbatches, so cap it there
            return min(warmup_ops, self._n_microbatches * self.n_local_stages)

        warmup_ops = get_rank_warmup_ops(rank)
        microbatch_ops = self.n_local_stages * self._n_microbatches
        # fwd_bwd_ops should encompass the remaining forwards
        fwd_bwd_ops = microbatch_ops - warmup_ops
        # cooldown_ops should encompass the remaining backwards
        cooldown_ops = microbatch_ops - fwd_bwd_ops
        # total ops encompass both forward and backward ops
        total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
        # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2
        logger.debug(
            "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
            rank,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            total_ops,
        )

        # Calculates the stage index based on step and pp_group_size
        def forward_stage_index(step):
            # Get the local index from 0 to n_local_stages-1
            local_index = (step // self.microbatches_per_round) % self.n_local_stages
            return (local_index * self.pp_group_size) + rank

        def backward_stage_index(step):
            local_index = (
                self.n_local_stages
                - 1
                - ((step - warmup_ops) // self.microbatches_per_round)
                % self.n_local_stages
            )
            return (local_index * self.pp_group_size) + rank

        return _get_1f1b_rank_ops(
            self.n_local_stages,
            self.pp_group_size,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            rank,
            forward_stage_index,
            backward_stage_index,
        )
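# Worked sketch of the rounds arithmetic from the docstring above: the schedule
# only requires n_microbatches % num_rounds == 0 rather than
# n_microbatches % pp_group_size == 0.
def _example_interleaved_rounds(pp_group_size: int = 4, n_microbatches: int = 10):
    number_of_rounds = max(1, n_microbatches // pp_group_size)   # -> 2
    microbatches_per_round = n_microbatches // number_of_rounds  # -> 5
    assert n_microbatches % number_of_rounds == 0
    return number_of_rounds, microbatches_per_round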
class ScheduleInterleavedZeroBubble(PipelineScheduleMulti):
    """
    The Interleaved Zero Bubble schedule.
    See https://arxiv.org/pdf/2401.10241 for details.
    Will perform one forward and one backward on inputs for the microbatches in steady
    state and supports multiple stages per rank.

    Uses the backward for weights to fill in the pipeline bubble.

    In particular this is implementing the ZB1P schedule in the paper.
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        self.pp_group_size = stages[0].group_size
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )
        self.n_local_stages = len(stages)
        self.rank = stages[0].group_rank
        self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
        self.microbatches_per_round = n_microbatches // self.number_of_rounds
        if n_microbatches % self.number_of_rounds != 0:
            raise ValueError(
                "Zero bubble requires the number of microbatches to be a "
                f"multiple of the number of rounds ({self.number_of_rounds}), "
                f"but got {n_microbatches}."
            )

        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

        # This function adds bubbles to the generated schedule based on dependencies of actions.
        # Note that the ZB1P schedule will not require bubbles to be manually added and it is
        # only useful when n_microbatches <= microbatches_per_round
        self.pipeline_order = self._add_bubbles_to_actions(
            self.n_local_stages * self.pp_group_size,
        )

    def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
        def get_rank_warmup_ops(rank):
            # Warmup operations for the last stage
            warmups_ops_last_stage = (
                self.n_local_stages - 1
            ) * self.microbatches_per_round
            # Increment warmup operations by 1 for each hop away from the last stage
            multiply_factor = 1
            warmup_ops = warmups_ops_last_stage + multiply_factor * (
                (self.pp_group_size - 1) - rank
            )
            # We cannot have more warmup operations than there are number of microbatches, so cap it there
            return min(warmup_ops, self._n_microbatches * self.n_local_stages)

        warmup_ops = get_rank_warmup_ops(rank)
        microbatch_ops = self.n_local_stages * self._n_microbatches
        # fwd_bwd_ops should encompass the remaining forwards
        fwd_bwd_ops = microbatch_ops - warmup_ops
        # cooldown_ops should encompass the remaining backwards
        cooldown_ops = microbatch_ops - fwd_bwd_ops
        # total ops encompass both forward and backward ops
        total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
        # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2
        logger.debug(
            "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
            rank,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            total_ops,
        )

        # Calculates the stage index based on step and pp_group_size
        def forward_stage_index(step):
            # Get the local index from 0 to n_local_stages-1
            local_index = (step // self.microbatches_per_round) % self.n_local_stages
            return (local_index * self.pp_group_size) + rank

        def backward_stage_index(step):
            local_index = (
                self.n_local_stages
                - 1
                - ((step - warmup_ops) // self.microbatches_per_round)
                % self.n_local_stages
            )
            return (local_index * self.pp_group_size) + rank

        num_1f1b_microbatches = rank

        return _get_1f1b_rank_ops(
            self.n_local_stages,
            self.pp_group_size,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            rank,
            forward_stage_index,
            backward_stage_index,
            num_1f1b_microbatches,
            enable_zero_bubble=True,
        )

    def _add_bubbles_to_actions(self, num_stages_global):
        actions = self.pipeline_order

        def need_bubble(stage, op, microbatch, num_stages_global, seen_ops):
            if op == _ComputationType.FORWARD:
                if stage != 0 and (stage - 1, op, microbatch) not in seen_ops:
                    return True
            elif op == _ComputationType.FULL_BACKWARD:
                if stage == num_stages_global - 1:
                    return (stage, _ComputationType.FORWARD, microbatch) not in seen_ops
                return (stage + 1, op, microbatch) not in seen_ops
            return False

        seen_ops: Set[Tuple[int, _ComputationType, int]] = set()
        result: Dict[int, List[Optional[_Action]]] = {}
        next_pointer: Dict[int, int] = {}
        bubbles_added: Dict[int, int] = {}
        total_bubbles_added = 0

        for rank in range(self.pp_group_size):
            result[rank] = []
            next_pointer[rank] = 0
            bubbles_added[rank] = 0

        while True:
            should_stop = True

            temp_seen_ops: Set[Tuple[int, _ComputationType, int]] = set()

            for rank in range(self.pp_group_size):
                timestamp = next_pointer[rank]
                if timestamp >= len(actions[rank]):
                    continue

                should_stop = False

                if actions[rank][timestamp] is not None:
                    temp_action = actions[rank][timestamp]
                    assert temp_action is not None
                    stage_index, op, microbatch = temp_action
                    if not need_bubble(
                        stage_index, op, microbatch, num_stages_global, seen_ops
                    ):
                        result[rank].append(actions[rank][timestamp])
                        if microbatch is not None:
                            temp_seen_ops.add((stage_index, op, microbatch))
                        next_pointer[rank] += 1
                    else:
                        result[rank].append(None)
                        bubbles_added[rank] += 1
                        total_bubbles_added += 1
                else:
                    next_pointer[rank] += 1
                    result[rank].append(None)

            seen_ops.update(temp_seen_ops)
            if should_stop:
                break

        if total_bubbles_added > 0:
            logger.warning(
                "Non zero bubbles added: total_bubbles_added=%s bubbles_added=%s",
                total_bubbles_added,
                bubbles_added,
            )
        return result
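
# Illustrative sketch (not part of this module): the warmup-op formula above, pulled
# out as a standalone helper so the difference from Interleaved1F1B is easy to see.
# With n_local_stages = 2, pp_group_size = 4, n_microbatches = 8 (so
# microbatches_per_round = 4), ranks 0..3 get 7, 6, 5, 4 warmup forwards here,
# versus 10, 8, 6, 4 under Interleaved1F1B (multiply_factor = 2).
def _example_zb_warmup_ops(rank, n_local_stages, pp_group_size, n_microbatches):
    number_of_rounds = max(1, n_microbatches // pp_group_size)
    microbatches_per_round = n_microbatches // number_of_rounds
    warmups_ops_last_stage = (n_local_stages - 1) * microbatches_per_round
    multiply_factor = 1  # 2 for Interleaved1F1B
    warmup_ops = warmups_ops_last_stage + multiply_factor * (
        (pp_group_size - 1) - rank
    )
    # Warmup forwards can never exceed the total number of forward ops on the rank
    return min(warmup_ops, n_microbatches * n_local_stages)
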
class ScheduleZBVZeroBubble(PipelineScheduleMulti):
    """
    The Zero Bubble schedule (ZBV variant).
    See https://arxiv.org/pdf/2401.10241 Section 6 for details.

    This schedule requires exactly two stages per rank.

    This schedule will perform one forward and one backward on inputs for the microbatches in steady
    state and supports multiple stages per rank. Uses backward with respect to weights to fill in
    the pipeline bubble.

    This ZB-V schedule would have the "zero bubble" property only if time forward == time backward input
    == time backward weights. In practice, this is not likely true for real models so alternatively
    a greedy scheduler could be implemented for unequal/unbalanced time.
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
        stage_index_to_group_rank: Optional[Dict[int, int]] = None,
    ):
        self.pp_group_size = stages[0].group_size
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            stage_index_to_group_rank=stage_index_to_group_rank,
        )
        self.n_local_stages = len(stages)
        if self.n_local_stages != 2:
            raise ValueError(
                "ZBV requires exactly 2 stages per rank, but got "
                f"{self.n_local_stages}."
            )

        self.rank = stages[0].group_rank
        self.num_stages = stages[0].num_stages

        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

    def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
        # max(2 * self.pp_group_size - 1, ...) ensures the number of microbatches is at least
        # as large as the number of microbatches needed to fully utilize the pipeline
        n_micro = max(2 * self.pp_group_size - 1, self._n_microbatches)
        rank_ops: List[Optional[_Action]] = [None for _ in range(rank)]

        # Forward and backward action counts for stage chunk 0 and chunk 1
        f0_cnt, f1_cnt, b0_cnt, b1_cnt = 0, 0, 0, 0
        # warm-up phase
        warmup_n1 = 2 * (self.pp_group_size - rank) - 1
        stage_id_chunk0 = rank
        stage_id_chunk1 = self.num_stages - 1 - rank

        for _ in range(warmup_n1):
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt)
            )
            f0_cnt += 1
        warmup_n2 = rank
        for _ in range(warmup_n2):
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt)
            )
            f1_cnt += 1
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt)
            )
            f0_cnt += 1
        warmup_n3 = self.pp_group_size - rank
        for _ in range(warmup_n3):
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt)
            )
            f1_cnt += 1
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt)
            )
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=W, microbatch_index=b1_cnt)
            )
            b1_cnt += 1
        # stable phase
        while f1_cnt < f0_cnt or f0_cnt < n_micro:
            if f0_cnt < n_micro:
                rank_ops.append(
                    _Action(
                        stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt
                    )
                )
                f0_cnt += 1
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt)
            )
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=W, microbatch_index=b0_cnt)
            )
            b0_cnt += 1

            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt)
            )
            f1_cnt += 1
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt)
            )
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=W, microbatch_index=b1_cnt)
            )
            b1_cnt += 1
        # cool-down phase
        w0_cnt, w1_cnt = b0_cnt, b1_cnt
        cooldown_n1 = rank
        for _ in range(cooldown_n1):
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt)
            )
            b0_cnt += 1
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt)
            )
            b1_cnt += 1
        cooldown_n2 = self.pp_group_size - rank
        for _ in range(cooldown_n2):
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt)
            )
            b0_cnt += 1
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=W, microbatch_index=w0_cnt)
            )
            w0_cnt += 1
        while w1_cnt < b1_cnt:
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=W, microbatch_index=w1_cnt)
            )
            w1_cnt += 1
        while w0_cnt < b0_cnt:
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=W, microbatch_index=w0_cnt)
            )
            w0_cnt += 1

        assert w0_cnt == b0_cnt and b0_cnt == f0_cnt
        assert w1_cnt == b1_cnt and b1_cnt == f1_cnt

        # We use max() in the n_micro computation above, so we may need to
        # remove redundant microbatches
        rank_ops = [
            (
                action
                if action is not None
                and action.microbatch_index is not None
                and action.microbatch_index < self._n_microbatches
                else None
            )
            for action in rank_ops
        ]
        return rank_ops
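
# Illustrative sketch (not part of this module): the "V" stage placement that the
# loop above assumes. Each rank owns stage `rank` (chunk 0) and stage
# `num_stages - 1 - rank` (chunk 1), so with pp_group_size = 4 (num_stages = 8)
# the mapping is rank 0 -> (0, 7), rank 1 -> (1, 6), rank 2 -> (2, 5), rank 3 -> (3, 4).
def _example_zbv_stage_placement(pp_group_size):
    num_stages = 2 * pp_group_size  # ZBV requires exactly two stages per rank
    return {rank: (rank, num_stages - 1 - rank) for rank in range(pp_group_size)}
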
def get_schedule_class(schedule_name: str):
    """
    Maps a schedule name (case insensitive) to its corresponding class object.

    Args:
        schedule_name (str): The name of the schedule.
    """
    schedule_map = {
        "1F1B": Schedule1F1B,
        "Interleaved1F1B": ScheduleInterleaved1F1B,
        "GPipe": ScheduleGPipe,
        "LoopedBFS": ScheduleLoopedBFS,
        "InterleavedZeroBubble": ScheduleInterleavedZeroBubble,
        "PipelineScheduleSingle": PipelineScheduleSingle,
        "PipelineScheduleMulti": PipelineScheduleMulti,
        "ZBVZeroBubble": ScheduleZBVZeroBubble,
    }
    lowercase_keys = {k.lower(): k for k in schedule_map.keys()}
    lowercase_schedule_name = schedule_name.lower()
    if lowercase_schedule_name not in lowercase_keys:
        raise ValueError(
            f"Unknown schedule name '{schedule_name}'. The valid options are {list(schedule_map.keys())}"
        )
    return schedule_map[lowercase_keys[lowercase_schedule_name]]


def _simulate_comms_compute(
    pipeline_order, stage_to_rank: Callable[[int], int], num_stages: int
):
    """This function dry-run simulates the actions in the schedule from the perspective of all ranks, and flags
    any deadlocks caused by missing or misordered communications. It also simulates any bubbles in time where a rank
    can not execute any action due to waiting for unmet dependencies.  The total number of simulator steps can be used
    as a metric for unit tests involving IR optimization passes as reordering and merging of IR can reduce the number
    of simulated steps.

    The simulation is not high-fidelity and does not model overlapping of compute and communication, or cuda streams.
    Future work may be to enhance this and model the compute time, comms overlap, and even memory.
    """
    pipeline_order = {
        rank: [a for a in pipeline_order[rank] if a is not None]
        for rank in sorted(pipeline_order)
    }
    _schedule: Dict[int, List[_Action | None]] = {
        rank: [] for rank in sorted(pipeline_order)
    }

    _prev_ops_rank: Dict[int, Set[_Action]] = {rank: set() for rank in _schedule}

    def add_to_schedule(rank: int, action: Optional[_Action]):
        _schedule[rank].append(action)
        if action is not None:
            _prev_ops_rank[rank].add(action)

    def _ready_to_schedule(action: Optional[_Action]) -> bool:
        if action is None:
            return True
        stage_idx = action.stage_index
        prev_ops = _prev_ops_rank[stage_to_rank(stage_idx)]
        if action.computation_type == F:
            if action.stage_index == 0:
                return True
            elif (
                _Action(action.stage_index, RECV_F, action.microbatch_index) in prev_ops
            ):
                return True
            elif (
                _Action(action.stage_index - 1, F, action.microbatch_index) in prev_ops
            ):
                return True
            return False
        elif action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD):
            if action.stage_index == num_stages - 1:
                return True
            if _Action(action.stage_index, RECV_B, action.microbatch_index) in prev_ops:
                return True
            if (
                _Action(action.stage_index + 1, BACKWARD_INPUT, action.microbatch_index)
                in prev_ops
            ):
                return True
            if (
                _Action(action.stage_index + 1, FULL_BACKWARD, action.microbatch_index)
                in prev_ops
            ):
                return True
            return False
        elif action.computation_type == BACKWARD_WEIGHT:
            return True
        elif action.computation_type == SEND_F:
            expected_f = _Action(action.stage_index, F, action.microbatch_index)
            return expected_f in prev_ops
        elif action.computation_type == RECV_F:
            peer_stage_idx = stage_idx - 1
            expected_send = _Action(peer_stage_idx, SEND_F, action.microbatch_index)
            return expected_send in _prev_ops_rank[stage_to_rank(peer_stage_idx)]
        elif action.computation_type == SEND_B:
            expected_b = _Action(
                action.stage_index, BACKWARD_INPUT, action.microbatch_index
            )
            expected_bw = _Action(
                action.stage_index, FULL_BACKWARD, action.microbatch_index
            )
            return expected_b in prev_ops or expected_bw in prev_ops
        elif action.computation_type == RECV_B:
            peer_stage_idx = stage_idx + 1
            expected_send = _Action(peer_stage_idx, SEND_B, action.microbatch_index)
            return expected_send in _prev_ops_rank[stage_to_rank(peer_stage_idx)]
        else:
            raise ValueError(f"Unsupported action type {action}")

    while pipeline_order:
        progress = False
        for rank in sorted(pipeline_order):
            if len(pipeline_order[rank]) == 0:
                continue

            action = pipeline_order[rank][0]
            if _ready_to_schedule(action):
                if action is not None:
                    add_to_schedule(rank, action)
                pipeline_order[rank].pop(0)
                progress = True
            else:
                add_to_schedule(rank, None)

        for i in sorted(pipeline_order, reverse=True):
            if len(pipeline_order[i]) == 0:
                del pipeline_order[i]

        # hacky, but do a second pass to replace any 'none' at this timestep with a real action, if it got unblocked
        # by one of the later ranks
        for rank in sorted(pipeline_order):
            if len(pipeline_order[rank]) == 0:
                continue

            if _schedule[rank][-1] is not None:
                continue

            action = pipeline_order[rank][0]
            if _ready_to_schedule(action):
                if action is not None:
                    _schedule[rank][-1] = action
                    _prev_ops_rank[rank].add(action)
                pipeline_order[rank].pop(0)

        for i in sorted(pipeline_order, reverse=True):
            if len(pipeline_order[i]) == 0:
                del pipeline_order[i]

        if not progress:
            print("WIP comms schedule:\n", _format_pipeline_order(_schedule))
            for rank in pipeline_order:
                print(f"{rank=} next action= {pipeline_order[rank][0]}")
            raise ValueError("Schedule is not progressing")

    return _schedule


def _dump_chrometrace(schedule, filename):
    """
    This function dumps a schedule IR into a chrometrace format so it can be visualized.

    It is currently very basic and only serves as a graphical alternative to dumping the schedule IR as text.

    As future work we may extend this to include more accurate heuristics for durations, or let users input durations,
    add 'flow events' to let the UI show the connection between sends and recvs, and model cuda streams for comm/compute
    as separate streams on the chrometrace view.
    """
    events = []
    for rank in sorted(schedule):
        for timestep, action in enumerate(schedule[rank]):
            if action is None:
                continue
            events.append(
                {
                    "name": str(action),
                    "cat": (
                        "computation"
                        if action.computation_type in (F, B, W)
                        else "communication"
                    ),
                    "ph": "X",
                    "pid": rank,
                    "tid": rank,
                    "ts": timestep,
                    "dur": 1,
                }
            )

    import json

    with open(filename, "w") as f:
        json.dump({"traceEvents": events}, f)