Source code for torch.distributed.pipelining.schedules
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch.profiler import record_function

from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec
from .stage import _PipelineStageBase

__all__ = [
    "PipelineScheduleSingle",
    "PipelineScheduleMulti",
    "Schedule1F1B",
    "ScheduleGPipe",
    "ScheduleInterleaved1F1B",
    "ScheduleLoopedBFS",
]

logger = logging.getLogger(__name__)


class _ComputationType(Enum):
    FORWARD = 1
    BACKWARD = 2

    def __str__(self):
        if self == _ComputationType.FORWARD:
            return "F"
        else:
            return "B"


class _Action(NamedTuple):
    computation_type: _ComputationType
    microbatch_index: int
    stage_index: int

    def __repr__(self):
        return f"{self.computation_type}{self.microbatch_index}_s{self.stage_index}"


class _PipelineSchedule(ABC):
    def __init__(
        self,
        n_microbatches: int,
        loss_fn: Optional[Callable[..., torch.Tensor]] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        # From arguments
        self._n_microbatches = n_microbatches
        self._loss_fn = loss_fn
        # Chunking specification for positional inputs. (default: `None`)
        self._args_chunk_spec = args_chunk_spec
        # Chunking specification for keyword inputs. (default: `None`)
        self._kwargs_chunk_spec = kwargs_chunk_spec
        self._output_merge_spec = output_merge_spec
        """
        # args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs.
        # They are used to convert batch to microbatches in `step(x)`.  See
        # `TensorChunkSpec` for helper methods for creating them.
        """

        # Derived
        self._has_backward = self._loss_fn is not None

        # Holds the losses for each microbatch.
        self._internal_losses: List[torch.Tensor] = []
        logger.info(f"Using {self.__class__.__name__}")  # noqa: G004

    def _maybe_compute_loss(self, stage, output, target_mbs, mb_index):
        if stage.is_last and self._has_backward:
            loss = self._compute_loss(output, target_mbs[mb_index])  # type: ignore[index]
            self._internal_losses.append(loss)

    def _maybe_get_loss(self, stage, mb_index):
        valid_index = 0 <= mb_index < len(self._internal_losses)
        if stage.is_last and self._has_backward and valid_index:
            return self._internal_losses[mb_index]
        elif len(self._internal_losses) != 0 and not valid_index:
            raise RuntimeError(
                f"Loss for microbatch {mb_index} is not available. "
                f"Available losses for microbatches: {self._internal_losses}"
            )
        else:
            return None

    def _update_losses(self, stages, losses):
        """
        Update the losses to those in the internal state
        """
        # if stages not a list turn into a list
        if not isinstance(stages, list):
            stages = [stages]
        contains_last_stage = any(stage.is_last for stage in stages)

        # Return losses if there is a container passed in
        if contains_last_stage and losses is not None:
            if len(self._internal_losses) != self._n_microbatches:
                raise RuntimeError(
                    f"Expecting {self._n_microbatches} losses but got "
                    f"{len(self._internal_losses)}"
                )

            # Clean external container first
            losses.clear()
            # Copy internal losses to external container
            losses.extend(self._internal_losses)

        self._internal_losses.clear()

    @abstractmethod
    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the schedule
        implementation.

        Args:
            microbatches: list of microbatch args.
        """
        raise NotImplementedError

    @abstractmethod
    def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.
        """
        raise NotImplementedError

    def _check_inputs(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Pre-process/check inputs
        """

        def check_type_and_len(mbs, name: str):
            if not isinstance(mbs, list):
                raise TypeError(f"{name} must be a list but got a {type(mbs)}")
            if len(mbs) != self._n_microbatches:
                raise ValueError(
                    f"Expecting {self._n_microbatches} {name} but got {len(mbs)}"
                )

        if arg_mbs is not None:
            check_type_and_len(arg_mbs, "arg_mbs")
        else:
            arg_mbs = [()] * self._n_microbatches

        if kwarg_mbs is not None:
            check_type_and_len(kwarg_mbs, "kwarg_mbs")
        else:
            kwarg_mbs = [{}] * self._n_microbatches

        if target_mbs is not None:
            check_type_and_len(target_mbs, "target_mbs")

        if losses is not None:
            if not isinstance(losses, list):
                raise TypeError(f"losses must be a list but got a {type(losses)}")

        return arg_mbs, kwarg_mbs

    def _compute_loss(self, output, target):
        return self._loss_fn(output, target)  # type: ignore[misc]

    def _split_inputs(
        self,
        args: Tuple[Any, ...],
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Splits a full-batch input into chunks (i.e. microbatches) and returns
        the chunks
        """
        if args or kwargs:
            args_split, kwargs_split = split_args_kwargs_into_chunks(
                args,
                kwargs,
                self._n_microbatches,
                self._args_chunk_spec,
                self._kwargs_chunk_spec,
            )
            return args_split, kwargs_split
        else:
            # Empty inputs (e.g. when called on middle stages)
            # Return a list of empty tuples/dicts with matching length as chunks
            return [()] * self._n_microbatches, [{}] * self._n_microbatches

    def _merge_outputs(self, output_chunks: List[Any]) -> Any:
        """
        Merge output chunks back to a batch state.
        If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim).
        """
        return merge_chunks(
            output_chunks,
            self._output_merge_spec,
        )


def _batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None):
    """
    Simple wrapper over batch_isend_irecv from torch.distributed, which just
    adds a descriptive logger on top.
    """
    if len(p2p_ops) == 0:
        return None
    desc_str = f"{desc}, " if desc else ""
    logger.debug(f"batch_p2p {desc_str}{p2p_ops}")  # noqa: G004
    return dist.batch_isend_irecv(p2p_ops).pop()


def _sorted_batch_p2p(
    p2p_ops: List[dist.P2POp], desc: Optional[str] = None
) -> Dict[int, dist.Work]:
    """
    Sorts the list of P2P ops by the peer rank, and then calls
    batch_isend_irecv. Return a dictionary of works by peer rank. This function
    helps us avoid hangs in case of skip connections.
    """
    # Arrange p2p_ops by peer rank:
    #   int is the peer rank;
    #   List is the list of ops towards the peer
    ops_by_peer: Dict[int, List[dist.P2POp]] = defaultdict(list)
    work_by_peer: Dict[int, dist.Work] = {}
    if len(p2p_ops) == 0:
        return work_by_peer

    # Classify the ops by peer rank
    for op in p2p_ops:
        ops_by_peer[op.peer].append(op)

    # Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs)
    for peer, ops in sorted(ops_by_peer.items()):
        work_by_peer[peer] = _batch_p2p(ops, desc=desc)

    return work_by_peer

class PipelineScheduleSingle(_PipelineSchedule):
    """
    Base class for single-stage schedules.
    Implements the `step` method.
    Derived classes should implement `_step_microbatches`.
    """

    def __init__(
        self,
        stage: _PipelineStageBase,
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        # Init parent
        super().__init__(
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )
        # Self attributes
        self._stage = stage
        self._num_stages = stage.num_stages
        # Set the same has_backward flag for stage object
        self._stage.has_backward = self._has_backward

        # TODO: later replace this with lazy shape inference during forward
        # Prepare forward send/recv infrastructure for stage
        stage._prepare_forward_infra(n_microbatches)
        if self._has_backward:
            stage._prepare_backward_infra(n_microbatches)
    def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.
        """

        # Clean per iteration
        self._stage.clear_runtime_states()

        # Split inputs into microbatches
        args_split, kwargs_split = self._split_inputs(args, kwargs)

        # Split target into microbatches
        if target is not None:
            targets_split = list(torch.tensor_split(target, self._n_microbatches))
        else:
            targets_split = None

        # Run microbatches
        self._step_microbatches(args_split, kwargs_split, targets_split, losses)

        # Return merged results per original format
        if self._stage.is_last:
            return self._merge_outputs(self._stage.output_chunks)
        else:
            return None
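
As a small self-contained illustration of the default chunking that `step` relies on (the shapes below are made up): targets are split with `torch.tensor_split`, and args/kwargs go through `split_args_kwargs_into_chunks`, which by default also chunks tensors along dimension 0.

    import torch

    n_microbatches = 4
    x = torch.randn(8, 16)                 # whole-batch input (batch size 8)
    target = torch.randint(0, 10, (8,))    # whole-batch target

    x_chunks = torch.tensor_split(x, n_microbatches)
    target_chunks = list(torch.tensor_split(target, n_microbatches))

    print([c.shape for c in x_chunks])       # four chunks of shape [2, 16]
    print([t.shape for t in target_chunks])  # four chunks of shape [2]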

class ScheduleGPipe(PipelineScheduleSingle):
    """
    The GPipe schedule.
    Will go through all the microbatches in a fill-drain manner.
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the GPipe schedule.

        Args:
            microbatches: list of microbatch args.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Delay send waits
        fwd_sends_to_wait: List[dist.Work] = []

        # Run microbatches
        for i in range(self._n_microbatches):
            with record_function(f"Forward {i}"):
                ops = self._stage.get_fwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_recv")
                for work in works.values():
                    work.wait()

                output = self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i])  # type: ignore[index]

                ops = self._stage.get_fwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_send")
                fwd_sends_to_wait.extend(works.values())

            logger.debug(
                f"[{self._stage.stage_index}] Forwarded microbatch {i}"  # noqa: G004
            )

            self._maybe_compute_loss(self._stage, output, target_mbs, i)

        # Wait for all forward sends to finish
        # This should not have performance impact because by the time the first
        # backward arrives all the forward sends should have been finished.
        for work in fwd_sends_to_wait:
            work.wait()

        # No loss function, no need to run backward
        if not self._has_backward:
            return

        # Run backward
        # Delay send waits
        bwd_sends_to_wait: List[dist.Work] = []
        for i in range(self._n_microbatches):
            with record_function(f"Backward {i}"):
                ops = self._stage.get_bwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="bwd_recv")
                for work in works.values():
                    work.wait()

                loss = self._maybe_get_loss(self._stage, i)
                self._stage.backward_one_chunk(i, loss=loss)

                ops = self._stage.get_bwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="bwd_send")
                bwd_sends_to_wait.extend(works.values())

            logger.debug(
                f"[{self._stage.stage_index}] Backwarded microbatch {i}"  # noqa: G004
            )

        # Return losses if there is a container passed in
        self._update_losses(self._stage, losses)

        # Wait for all backward sends to finish
        for work in bwd_sends_to_wait:
            work.wait()
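
A hedged usage sketch, not taken from this file: assuming a process group is already initialized (e.g. under `torchrun`) and `stage` is a ready-made pipeline stage for the local rank (built elsewhere, for instance with `torch.distributed.pipelining.PipelineStage`), a GPipe run could look roughly like this. The model, shapes, and `stage` construction are placeholders.

    import torch.nn.functional as F
    from torch.distributed.pipelining import ScheduleGPipe

    # `stage` is assumed to exist for this rank; its construction is omitted.
    schedule = ScheduleGPipe(stage, n_microbatches=4, loss_fn=F.cross_entropy)

    losses = []  # populated only on the rank holding the last stage
    if stage.is_first:
        out = schedule.step(x)                        # whole batch in; chunked internally
    elif stage.is_last:
        out = schedule.step(target=y, losses=losses)  # loss computed per microbatch
    else:
        out = schedule.step()                         # middle stages only relay activations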

class Schedule1F1B(PipelineScheduleSingle):
    """
    The 1F1B schedule.
    Will perform one forward and one backward on the microbatches in steady state.
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the 1F1B schedule.

        Args:
            microbatches: list of microbatch args.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Last stage has 1 warmup, second-to-last 2 warmups, ...
        # first stage `num_stages` warmups
        warmup_chunks = min(
            self._n_microbatches,
            self._num_stages - self._stage.stage_index,
        )

        # Chunk counters
        fwd_mb_index = 0
        bwd_mb_index = 0

        # Warmup phase
        send_work = None
        fwd_sends = []
        for _ in range(warmup_chunks):
            # Receive activations
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
            if recv_work := _batch_p2p(fwd_recvs, desc="fwd_recv"):
                recv_work.wait()

            # Compute
            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]

            # Clear previous chunk's forward sends (hopefully they have well
            # finished, otherwise, we are heavily communication bound, in which
            # case it doesn't create a lot of benefit to compute next chunk
            # eagerly either)
            if send_work:
                send_work.wait()

            # Send activations
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            if fwd_mb_index != warmup_chunks - 1:
                # Safe to fire
                send_work = _batch_p2p(fwd_sends, desc="fwd_send")
            # otherwise:
            #   The last forward send is left for fuse with first 1B in 1B1F below

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
            fwd_mb_index += 1

        # Now we should have send ops left over, to be fused with first 1B of 1B1F phase below.

        # 1B1F phase
        while True:  # Don't worry, we have a break inside
            # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops
            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)

            # Now, we need to fire the fwd_sends and bwd_recvs together
            if fuse_work := _batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"):
                fuse_work.wait()

            # Backward one chunk
            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
            self._stage.backward_one_chunk(bwd_mb_index, loss=loss)

            # Get the bwd send ops, but don't fire, to be fused with the 1F below
            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
            bwd_mb_index += 1

            if fwd_mb_index == self._n_microbatches:
                # We are done with 1B1F, so break with some left-over bwd_sends
                break

            # We prepare 1F of the `1B1F`
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)

            # Fuse it with bwd_sends above
            if fuse_work := _batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"):
                fuse_work.wait()

            # Now do the fwd
            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)

            # Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around)
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            fwd_mb_index += 1

        # Remember we still have some bwd_sends left over after the break? Now it is time to fire it
        send_work = _batch_p2p(bwd_sends, desc="bwd_send")

        # Cooldown
        while bwd_mb_index < self._n_microbatches:
            # prepare bwd recv ops
            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
            if recv_work := _batch_p2p(bwd_recvs, desc="bwd_recv"):
                recv_work.wait()

            # Backward one chunk
            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
            self._stage.backward_one_chunk(bwd_mb_index, loss=loss)

            # Clear previous chunk's backward sends (hopefully they have well finished)
            if send_work:
                send_work.wait()

            # Get the bwd send ops, fire it
            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
            send_work = _batch_p2p(bwd_sends, desc="bwd_send")
            bwd_mb_index += 1

        # Wait for the last backward send to finish
        if send_work:
            send_work.wait()

        # Return losses if there is a container passed in
        self._update_losses(self._stage, losses)
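
To make the warmup arithmetic above concrete, here is a tiny standalone calculation with illustrative numbers: with 4 stages and 8 microbatches, the last stage does a single warmup forward while the first stage does `num_stages` of them before entering the 1B1F loop.

    num_stages = 4
    n_microbatches = 8

    for stage_index in range(num_stages):
        # Same formula as warmup_chunks in Schedule1F1B._step_microbatches
        warmup_chunks = min(n_microbatches, num_stages - stage_index)
        print(f"stage {stage_index}: {warmup_chunks} warmup forwards")
    # stage 0: 4 warmup forwards
    # stage 1: 3 warmup forwards
    # stage 2: 2 warmup forwards
    # stage 3: 1 warmup forwards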

class PipelineScheduleMulti(_PipelineSchedule):
    """
    Base class for multi-stage schedules.
    Implements the `step` method.
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        if len(stages) <= 1:
            raise ValueError(
                f"Multi-stage schedule expects at least two stages but got {len(stages)}"
            )
        # Init parent
        super().__init__(
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )
        # Self attributes
        self._stages = stages
        self._num_stages = stages[0].num_stages
        self.pp_group_size = stages[0].group_size
        self.rank = stages[0].group_rank
        # Set the same has_backward flag for stage object
        for stage in self._stages:
            stage.has_backward = self._has_backward

        self._should_compute_loss = (
            lambda stage: stage.is_last and self._loss_fn is not None
        )

        # This will be set during init of derived schedules
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}

        # TODO: later replace this with lazy shape inference during forward
        # Prepare forward send/recv infrastructure for stage
        for stage in self._stages:
            stage._prepare_forward_infra(n_microbatches)
            if self._has_backward:
                stage._prepare_backward_infra(n_microbatches)
    def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.
        """

        # Clean per iteration
        for stage in self._stages:
            stage.clear_runtime_states()

        # Split inputs into microbatches
        args_split, kwargs_split = self._split_inputs(args, kwargs)

        # Split target into microbatches
        if target is not None:
            targets_split = list(torch.tensor_split(target, self._n_microbatches))
        else:
            targets_split = None

        # Run microbatches
        self._step_microbatches(args_split, kwargs_split, targets_split, losses)

        # Return merged results per original format
        for stage in self._stages:
            if stage.is_last:
                return self._merge_outputs(stage.output_chunks)
        # Does not contain the last stage
        return None
    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Operate on the microbatches for looped schedules (multiple stages on each rank).

        TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does
        not support models with skip connections.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Based on the plan in Step 1 created in __init__:
        # 2. Perform communication based on the pipeline_order
        stage_index_to_stage: Dict[int, _PipelineStageBase] = {
            stage.stage_index: stage for stage in self._stages
        }
        prev_rank: int = (self.rank - 1) % self.pp_group_size
        next_rank: int = (self.rank + 1) % self.pp_group_size

        for time_step, action in enumerate(self.pipeline_order[self.rank]):
            prev_rank_ops = self.pipeline_order[prev_rank]
            next_rank_ops = self.pipeline_order[next_rank]
            ops: List[dist.P2POp] = []
            if action is not None:
                computation_type, mb_index, stage_index = action
                if computation_type == _ComputationType.FORWARD:
                    # perform forward computation
                    stage = stage_index_to_stage[stage_index]
                    output = stage.forward_one_chunk(
                        mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index]
                    )
                    self._maybe_compute_loss(stage, output, target_mbs, mb_index)
                    ops.extend(stage.get_fwd_send_ops(mb_index))
                elif computation_type == _ComputationType.BACKWARD:
                    # perform backward computation
                    stage = stage_index_to_stage[stage_index]
                    loss = self._maybe_get_loss(stage, mb_index)
                    stage.backward_one_chunk(mb_index, loss=loss)
                    ops.extend(stage.get_bwd_send_ops(mb_index))
                else:
                    raise ValueError(f"Unknown computation type {computation_type}")

            # Look at the neighboring ranks for this current timestep and determine whether
            # this current rank needs to do any recv communication
            prev_rank_action = None
            if time_step < len(prev_rank_ops):
                prev_rank_action = prev_rank_ops[time_step]
            if prev_rank_action is not None:
                computation_type, mb_index, stage_index = prev_rank_action
                # Only handle sends for the forward from a previous rank
                if computation_type == _ComputationType.FORWARD:
                    # If not the last stage, then receive fwd activations
                    if stage_index != self._num_stages - 1:
                        # TODO: We are assuming that stage will always receive from stage-1
                        # however that is not necessarily true of get_fwd_recv_ops
                        stage = stage_index_to_stage[stage_index + 1]
                        ops.extend(stage.get_fwd_recv_ops(mb_index))
                elif computation_type == _ComputationType.BACKWARD:
                    # Previous rank doing backward has no influence for the current rank forward recv
                    pass
                else:
                    raise ValueError(f"Unknown computation type {computation_type}")

            next_rank_action = None
            if time_step < len(next_rank_ops):
                next_rank_action = next_rank_ops[time_step]
            if next_rank_action is not None:
                computation_type, mb_index, stage_index = next_rank_action
                # Only handle receives for the backwards from a next rank
                if computation_type == _ComputationType.FORWARD:
                    # Next rank doing forward has no influence for the current rank backward recv
                    pass
                elif computation_type == _ComputationType.BACKWARD:
                    # If not the first stage, then receive bwd gradients
                    if stage_index != 0:
                        # TODO: We are assuming that stage will always receive from stage+1
                        # however that is not necessarily true of get_bwd_recv_ops
                        stage = stage_index_to_stage[stage_index - 1]
                        ops.extend(stage.get_bwd_recv_ops(mb_index))
                else:
                    raise ValueError(f"Unknown computation type {computation_type}")

            # do the communication
            if ops:
                _batch_p2p(ops).wait()

        # Return losses if there is a container passed in
        self._update_losses(self._stages, losses)

class ScheduleLoopedBFS(PipelineScheduleMulti):
    """
    Breadth-First Pipeline Parallelism.
    See https://arxiv.org/abs/2211.05953 for details.
    Similar to Interleaved 1F1B, Looped BFS supports multiple stages per rank.
    What is different is that, when microbatches are ready for multiple local
    stages, Looped BFS prioritizes the earlier stage, running all available
    microbatches at once.
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            output_merge_spec=output_merge_spec,
        )

        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
        # ========================================================================
        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

    def _calculate_single_rank_operations(self, rank):
        n_local_stages = len(self._stages)
        stage_indices = range(
            rank, self.pp_group_size * n_local_stages, self.pp_group_size
        )

        # Store the list of operations used for that rank
        rank_ops: List[Optional[_Action]] = []
        # Pre-padding, rank starts with no-ops based on the warmup.
        for _ in range(rank):
            rank_ops.append(None)

        for stage_index in stage_indices:
            for mb_index in range(self._n_microbatches):
                rank_ops.append(
                    _Action(_ComputationType.FORWARD, mb_index, stage_index)
                )

        # wait for the first backward to trickle up
        # which is 2 for every hop away
        post_warmup_ops = 2 * (self.pp_group_size - 1 - rank)
        rank_ops.extend([None] * post_warmup_ops)

        for stage_index in reversed(stage_indices):
            for mb_index in reversed(range(self._n_microbatches)):
                rank_ops.append(
                    _Action(_ComputationType.BACKWARD, mb_index, stage_index)
                )
        return rank_ops
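
The per-rank order produced by `_calculate_single_rank_operations` can be previewed without any distributed setup. The following self-contained sketch re-derives it for a toy configuration (2 ranks, 2 local stages, 2 microbatches); the function name and tuple encoding are local to the sketch, not part of the module.

    def looped_bfs_rank_ops(rank, group_size, n_local_stages, n_microbatches):
        # Mirrors ScheduleLoopedBFS._calculate_single_rank_operations for illustration.
        stage_indices = range(rank, group_size * n_local_stages, group_size)
        ops = [None] * rank  # pre-padding
        for stage_index in stage_indices:
            for mb in range(n_microbatches):
                ops.append(("F", mb, stage_index))
        ops.extend([None] * (2 * (group_size - 1 - rank)))  # wait for first backward
        for stage_index in reversed(stage_indices):
            for mb in reversed(range(n_microbatches)):
                ops.append(("B", mb, stage_index))
        return ops

    for rank in range(2):
        print(rank, looped_bfs_rank_ops(rank, group_size=2, n_local_stages=2, n_microbatches=2))
    # 0 [('F', 0, 0), ('F', 1, 0), ('F', 0, 2), ('F', 1, 2), None, None,
    #    ('B', 1, 2), ('B', 0, 2), ('B', 1, 0), ('B', 0, 0)]
    # 1 [None, ('F', 0, 1), ('F', 1, 1), ('F', 0, 3), ('F', 1, 3),
    #    ('B', 1, 3), ('B', 0, 3), ('B', 1, 1), ('B', 0, 1)]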

class ScheduleInterleaved1F1B(PipelineScheduleMulti):
    """
    The Interleaved 1F1B schedule.
    See https://arxiv.org/pdf/2104.04473 for details.
    Will perform one forward and one backward on the microbatches in steady
    state and supports multiple stages per rank. When microbatches are ready for
    multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch
    (also called "depth first").
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        self.pp_group_size = stages[0].group_size
        # TODO: is this limitation a must?
        if n_microbatches % self.pp_group_size != 0:
            raise ValueError(
                f"Interleaved 1F1B schedule requires the number of microbatches ({n_microbatches}) "
                f"to be a multiple of the number of pipeline ranks ({self.pp_group_size})."
            )

        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )

        self.n_local_stages = len(stages)
        self.rank = stages[0].group_rank
        self.group = stages[0].group

        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}

        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

    def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
        def get_rank_warmup_ops(rank):
            # Warms up operations for last stage
            warmups_ops_last_stage = (self.n_local_stages - 1) * self.pp_group_size
            # Increment warmup operations by 2 for each hop away from the last stage
            warmup_ops = warmups_ops_last_stage + 2 * ((self.pp_group_size - 1) - rank)
            # We cannot have more warmup operations than there are number of microbatches, so cap it there
            return min(warmup_ops, self._n_microbatches * self.n_local_stages)

        warmup_ops = get_rank_warmup_ops(rank)
        microbatch_ops = self.n_local_stages * self._n_microbatches
        # fwd_bwd_ops should encompass the remaining forwards
        fwd_bwd_ops = microbatch_ops - warmup_ops
        # cooldown_ops should encompass the remaining backwards
        cooldown_ops = microbatch_ops - fwd_bwd_ops
        # total ops encompass both forward and backward ops
        total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
        # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2

        logger.debug(
            "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
            rank,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            total_ops,
        )

        # Calculates the stage index based on step and pp_group_size
        def forward_stage_index(step):
            # Get the local index from 0 to n_local_stages-1
            local_index = (step // self.pp_group_size) % self.n_local_stages
            return (local_index * self.pp_group_size) + rank

        def backward_stage_index(step):
            local_index = (
                self.n_local_stages
                - 1
                - ((step - warmup_ops) // self.pp_group_size) % self.n_local_stages
            )
            return (local_index * self.pp_group_size) + rank

        # Dictionary for tracking {stage index : current microbatch index}
        # All stages start with handling microbatch 0
        fwd_stage_mb_index: Dict[int, int] = defaultdict(int)
        bwd_stage_mb_index: Dict[int, int] = defaultdict(int)

        # Store the list of operations used for that rank
        rank_ops: List[Optional[_Action]] = []
        # Pre-padding, rank starts with no-ops based on the warmup.
        for _ in range(rank):
            rank_ops.append(None)

        # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
        # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
        # Formula:
        # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward
        # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding)
        # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)]
        # warmup_ops = calculated above
        post_warmup_ops = (
            self.n_local_stages * self.pp_group_size
            + 2 * (self.pp_group_size - 1 - rank)
        ) - (warmup_ops + rank)

        for op in range(total_ops):
            # Warmup phase
            if op < warmup_ops:
                fwd_stage_index = forward_stage_index(op)
                # This will assign the current microbatch index and update it as well
                fwd_stage_mb_index[fwd_stage_index] = (
                    mb_index := fwd_stage_mb_index[fwd_stage_index]
                ) + 1
                rank_ops.append(
                    _Action(_ComputationType.FORWARD, mb_index, fwd_stage_index)
                )
                if op == warmup_ops - 1:
                    # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up
                    rank_ops.extend([None] * post_warmup_ops)
            # 1F1B Phase (forward and backward)
            elif warmup_ops <= op < warmup_ops + fwd_bwd_ops:
                fwd_stage_index = forward_stage_index(op)
                fwd_stage_mb_index[fwd_stage_index] = (
                    fwd_mb_index := fwd_stage_mb_index[fwd_stage_index]
                ) + 1
                rank_ops.append(
                    _Action(_ComputationType.FORWARD, fwd_mb_index, fwd_stage_index)
                )

                bwd_stage_index = backward_stage_index(op)
                bwd_stage_mb_index[bwd_stage_index] = (
                    bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
                ) + 1
                rank_ops.append(
                    _Action(_ComputationType.BACKWARD, bwd_mb_index, bwd_stage_index)
                )
            # Cooldown phase
            else:
                # During cooldown phase, we need steps to align with 1f1b happening in other ranks
                # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None
                rank_ops.append(None)
                bwd_stage_index = backward_stage_index(op)
                bwd_stage_mb_index[bwd_stage_index] = (
                    bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
                ) + 1
                rank_ops.append(
                    _Action(_ComputationType.BACKWARD, bwd_mb_index, bwd_stage_index)
                )

        # Post padding
        for _ in range(self.pp_group_size - rank - 1):
            rank_ops.append(None)
        return rank_ops
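
For a quick sanity check of the warmup bookkeeping above, this standalone snippet evaluates the same formula as `get_rank_warmup_ops` for a toy configuration (illustrative numbers: 4 ranks, 2 local stages, 8 microbatches).

    pp_group_size = 4
    n_local_stages = 2
    n_microbatches = 8

    for rank in range(pp_group_size):
        # Same arithmetic as get_rank_warmup_ops in ScheduleInterleaved1F1B
        warmups_ops_last_stage = (n_local_stages - 1) * pp_group_size
        warmup_ops = warmups_ops_last_stage + 2 * ((pp_group_size - 1) - rank)
        warmup_ops = min(warmup_ops, n_microbatches * n_local_stages)
        print(f"rank {rank}: warmup_ops = {warmup_ops}")
    # rank 0: warmup_ops = 10
    # rank 1: warmup_ops = 8
    # rank 2: warmup_ops = 6
    # rank 3: warmup_ops = 4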