Source code for torch.distributed.pipeline.sync.skip.skippable
# -*- coding: utf-8 -*-
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""The user interface to define skip connections."""
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ClassVar,
    Dict,
    FrozenSet,
    Generator,
    Iterable,
    List,
    Optional,
    Set,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

from torch import Tensor, nn

from ..microbatch import Batch
from .namespace import Namespace
from .tracker import current_skip_tracker

__all__ = ["skippable", "stash", "pop", "verify_skippables"]


Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

StashPop = Union["stash", "pop"]
StashPopGenerator = Generator[StashPop, Optional[Tensor], TensorOrTensors]
if TYPE_CHECKING:
    # Typechecking: nn.Module is not a Generic
    SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]]  # type: ignore[type-arg]
else:
    SkippableModule = nn.Module

T = TypeVar("T", bound="Skippable")


class Skippable(nn.Module):
    """The base class for skippable modules.

    Do not use this class directly. Define a subclass by :func:`skippable`
    instead.
    """

    # Filled in by the :func:`skippable` decorator on each generated subclass.
    module_cls: ClassVar[Type[SkippableModule]]
    stashable_names: ClassVar[FrozenSet[str]]
    poppable_names: ClassVar[FrozenSet[str]]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__()
        # Instantiate the wrapped user module with the given arguments.
        self.module = self.module_cls(*args, **kwargs)  # type: ignore[call-arg]
        # Maps a skip name to the namespace it has been isolated into.
        self.namespaces: Dict[str, Namespace] = {}

    def __repr__(self) -> str:
        return f"@skippable({self.module})"

    def namespaced(self, name: str) -> Tuple[Namespace, str]:
        """Prepends namespace for the given skip name."""
        # NOTE(review): a name never passed to isolate() yields None here;
        # the cast keeps the declared type but does not add a real namespace.
        ns = self.namespaces.get(name)
        ns = cast(Namespace, ns)
        return (ns, name)

    def stashable(self) -> Iterable[Tuple[Namespace, str]]:
        """Iterates over namespaced skip names to be stashed."""
        for name in self.stashable_names:
            yield self.namespaced(name)

    def poppable(self) -> Iterable[Tuple[Namespace, str]]:
        """Iterates over namespaced skip names to be popped."""
        for name in self.poppable_names:
            yield self.namespaced(name)

    def isolate(self: T, ns: Namespace, *, only: Optional[Iterable[str]] = None) -> T:
        r"""Isolates a specified subset or the whole set of skip tensors into a
        namespace. In a single sequential module, skip tensors with the same
        name are not allowed unless they are isolated by different namespaces.

        Here's an example using the same name for skip tensors twice. Each pair
        of ``Layer1`` and ``Layer2`` is isolated with its own namespace ``ns1``
        and ``ns2``. There is no conflict anymore::

            ns1 = Namespace()
            ns2 = Namespace()

            model = nn.Sequential(
                Layer1().isolate(ns1),
                Layer1().isolate(ns2),
                Layer2(),
                Layer3().isolate(ns2),
                Layer3().isolate(ns1),
            )

        When `only` parameter is omitted, all skip tensors are isolated. You
        can isolate a subset of skip tensors by passing `only` parameter::

            ns_alice = Namespace()
            ns_bob = Namespace()

            model = nn.Sequential(
                ...
                StashStashPop().isolate(ns_alice, only=['alice']) \
                               .isolate(ns_bob, only=['bob']),
                ...
            )

        Args:
            ns (Namespace):
                namespace for isolation

        Keyword Args:
            only (iterable of strs):
                names of specific skip tensors to be isolated (omit this option
                to isolate all skip tensors declared in this module)

        Returns:
            this module itself

        """
        names: Iterable[str]

        if only is None:
            names = self.stashable_names | self.poppable_names
        else:
            names = set(only)

        for name in names:
            self.namespaces[name] = ns

        return self

    def dispatch(
        self,
        input,
        handle_stash: Callable[[str, Optional[Tensor]], None],
        handle_pop: Callable[[str], Optional[Tensor]],
    ):
        """Dispatches :class:`stash` or :class:`pop` commands generated by the
        module's ``forward()``.
        """
        generator = self.module(input)

        if not isinstance(generator, Generator):
            # The underlying module returned output without any yield.
            output = generator
            return output

        try:
            op = next(generator)

            while True:
                if isinstance(op, stash):
                    handle_stash(op.name, op.tensor)
                    op = next(generator)
                    continue

                if isinstance(op, pop):
                    tensor = handle_pop(op.name)
                    # Feed the popped tensor back as the value of the yield.
                    op = generator.send(tensor)
                    continue

                raise TypeError("%r is not a command from @skippable" % op)

        except StopIteration as stop:
            # A generator's return value travels in StopIteration.args[0].
            output = stop.args[0]
            return output

    def forward(self, input: Union[List[Any], Tensor]) -> TensorOrTensors:
        """Performs the forward propagation. :class:`stash` or :class:`pop`
        commands will be handled by portals silently. The portals won't be
        exposed to users.

        Raises:
            RuntimeError:
                illegal 'stash' or 'pop' is found.

        """
        skip_tracker = current_skip_tracker()
        stashed_tensors: Dict[str, Optional[Tensor]] = {}

        # Load skip tensors that might be popped.
        poppable_tensors = {}
        batch = Batch(input)
        for ns, name in self.poppable():
            try:
                poppable_tensors[name] = skip_tracker.load(batch, ns, name)
            except KeyError:
                raise RuntimeError(f"'{name}' has not been stashed")
        input = batch.values

        # Handle skip commands.
        def handle_stash(name: str, tensor: Optional[Tensor]) -> None:
            if name not in self.stashable_names:
                raise RuntimeError(f"'{name}' has not been declared as stashable")
            stashed_tensors[name] = tensor

        def handle_pop(name: str) -> Optional[Tensor]:
            if name not in self.poppable_names:
                raise RuntimeError(f"'{name}' has not been declared as poppable")
            return poppable_tensors.pop(name)

        output = self.dispatch(input, handle_stash, handle_pop)

        # All declared skips must be stashed or popped.
        not_stashed = self.stashable_names - stashed_tensors.keys()
        if not_stashed:
            comma_names = ", ".join("'%s'" % n for n in not_stashed)
            raise RuntimeError(f"{comma_names} must be stashed but have not")

        not_popped = poppable_tensors.keys()
        if not_popped:
            comma_names = ", ".join("'%s'" % n for n in not_popped)
            raise RuntimeError(f"{comma_names} must be popped but have not")

        # Save stashed skip tensors.
        batch = Batch(output)
        for ns, name in self.stashable():
            tensor = stashed_tensors[name]
            skip_tracker.save(batch, ns, name, tensor)
        output = batch.values

        return output


# TODO(sublee): Move to above of Skippable class for better read flow.
def skippable(
    stash: Iterable[str] = (),
    pop: Iterable[str] = (),
) -> Callable[[Type[SkippableModule]], Type[Skippable]]:
    """The decorator to define a :class:`nn.Module <torch.nn.Module>` with skip
    connections. Decorated modules are called "skippable". This functionality
    works perfectly fine even when the module is not wrapped by
    :class:`~torch.distributed.pipeline.sync.Pipe`.

    Each skip tensor is managed by its name. Before manipulating skip tensors,
    a skippable module must statically declare the names for skip tensors by
    `stash` and/or `pop` parameters. Skip tensors with pre-declared name can be
    stashed by ``yield stash(name, tensor)`` or popped by ``tensor = yield
    pop(name)``.

    Here is an example with three layers. A skip tensor named "1to3" is stashed
    and popped at the first and last layer, respectively::

        @skippable(stash=['1to3'])
        class Layer1(nn.Module):
            def forward(self, input):
                yield stash('1to3', input)
                return f1(input)

        class Layer2(nn.Module):
            def forward(self, input):
                return f2(input)

        @skippable(pop=['1to3'])
        class Layer3(nn.Module):
            def forward(self, input):
                skip_1to3 = yield pop('1to3')
                return f3(input) + skip_1to3

        model = nn.Sequential(Layer1(), Layer2(), Layer3())

    One skippable module can stash or pop multiple skip tensors::

        @skippable(stash=['alice', 'bob'], pop=['carol'])
        class StashStashPop(nn.Module):
            def forward(self, input):
                yield stash('alice', f_alice(input))
                yield stash('bob', f_bob(input))
                carol = yield pop('carol')
                return input + carol

    Every skip tensor must be associated with exactly one pair of `stash` and
    `pop`. :class:`~torch.distributed.pipeline.sync.Pipe` checks this
    restriction automatically when wrapping a module. You can also check the
    restriction by :func:`verify_skippables` without
    :class:`~torch.distributed.pipeline.sync.Pipe`.

    """
    stashable_names = frozenset(stash)
    poppable_names = frozenset(pop)

    def extend_skippable(module_cls: Type[SkippableModule]) -> Type[Skippable]:
        # Build a Skippable subclass that wraps the user's module class and
        # carries the declared skip names as class attributes.
        name = module_cls.__name__
        bases = (Skippable,)
        attrs = {"module_cls": module_cls, "stashable_names": stashable_names, "poppable_names": poppable_names}
        return type(name, bases, attrs)

    return extend_skippable
class stash:
    """The command to stash a skip tensor.

    ::

        def forward(self, input):
            yield stash('name', input)
            return f(input)

    Args:
        name (str): name of skip tensor
        input (torch.Tensor or None): tensor to pass to the skip connection

    """

    # __slots__ keeps these light-weight command objects dict-free.
    __slots__ = ("name", "tensor")

    def __init__(self, name: str, tensor: Optional[Tensor]) -> None:
        self.name = name
        self.tensor = tensor
class pop:
    """The command to pop a skip tensor.

    ::

        def forward(self, input):
            skip = yield pop('name')
            return f(input) + skip

    Args:
        name (str): name of skip tensor

    Returns:
        the skip tensor previously stashed by another layer under the same name

    """

    # __slots__ keeps these light-weight command objects dict-free.
    __slots__ = ("name",)

    def __init__(self, name: str) -> None:
        self.name = name
def verify_skippables(module: nn.Sequential) -> None:
    """Verifies if the underlying skippable modules satisfy integrity.

    Every skip tensor must have only one pair of `stash` and `pop`. If there
    are one or more unmatched pairs, it will raise :exc:`TypeError` with the
    detailed messages.

    Here are a few failure cases. :func:`verify_skippables` will report failure
    for these cases::

        # Layer1 stashes "1to3".
        # Layer3 pops "1to3".

        nn.Sequential(Layer1(), Layer2())
        #               └──── ?

        nn.Sequential(Layer2(), Layer3())
        #                   ? ────┘

        nn.Sequential(Layer1(), Layer2(), Layer3(), Layer3())
        #               └───────────────────┘       ^^^^^^
        nn.Sequential(Layer1(), Layer1(), Layer2(), Layer3())
        #             ^^^^^^      └───────────────────┘

    To use the same name for multiple skip tensors, they must be isolated by
    different namespaces. See :meth:`isolate()
    <torchpipe.skip.skippable.Skippable.isolate>`.

    Raises:
        TypeError:
            one or more pairs of `stash` and `pop` are not matched.

    """
    stashed: Set[Tuple[Namespace, str]] = set()
    popped: Set[Tuple[Namespace, str]] = set()
    msgs: List[str] = []

    for layer_name, layer in module.named_children():
        if not isinstance(layer, Skippable):
            continue

        # A name declared in both sets can never form a valid cross-layer pair.
        for name in layer.stashable_names & layer.poppable_names:
            msg = f"'{layer_name}' declared '{name}' both as stashable and as poppable"
            msgs.append(msg)

        for ns, name in layer.stashable():
            if name in layer.poppable_names:
                continue

            if (ns, name) in stashed:
                msg = f"'{layer_name}' redeclared '{name}' as stashable " "but not isolated by namespace"
                msgs.append(msg)
                continue

            stashed.add((ns, name))

        for ns, name in layer.poppable():
            if name in layer.stashable_names:
                continue

            if (ns, name) in popped:
                msg = f"'{layer_name}' redeclared '{name}' as poppable " "but not isolated by namespace"
                msgs.append(msg)
                continue

            # A pop is only legal after a matching stash in an earlier layer.
            if (ns, name) not in stashed:
                msg = f"'{layer_name}' declared '{name}' as poppable but it was not stashed"
                msgs.append(msg)
                continue

            popped.add((ns, name))

    for (_, name) in stashed - popped:
        msg = f"no module declared '{name}' as poppable but stashed"
        msgs.append(msg)

    if msgs:
        raise TypeError(
            "one or more pairs of stash and pop do not match:\n\n%s" % "\n".join("* %s" % x for x in msgs)
        )
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.