import warnings
from typing import Any, Iterator, List, Optional, Union

import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset, IterableDataset, Sampler
from torch.utils.data.distributed import DistributedSampler

from ignite.distributed import utils as idist
from ignite.distributed.comp_models import horovod as idist_hvd, native as idist_native, xla as idist_xla
from ignite.utils import setup_logger


def auto_dataloader(dataset: Dataset, **kwargs: Any) -> Union[DataLoader, "_MpDeviceLoader"]:
    """Helper method to create a dataloader adapted for non-distributed and distributed configurations
    (supporting all available backends from :meth:`~ignite.distributed.utils.available_backends()`).

    Internally, we create a dataloader with provided kwargs while applying the following updates:

    - batch size is scaled by world size: ``batch_size / world_size`` if ``batch_size`` is larger than or
      equal to the world size.
    - number of workers is scaled by the number of local processes: ``num_workers / nprocs`` if
      ``num_workers`` is larger than or equal to ``nprocs`` (number of processes per node).
    - if no sampler is provided by the user, a `torch DistributedSampler`_ is set up.
    - if a `torch DistributedSampler`_ is provided by the user, it is used without wrapping it.
    - if another sampler is provided, it is wrapped by :class:`~ignite.distributed.auto.DistributedProxySampler`.
    - if the default device is 'cuda', `pin_memory` is automatically set to `True`.

    .. warning::

        A custom batch sampler is not adapted for the distributed configuration. Please, make sure that the
        provided batch sampler is compatible with the distributed configuration.

    Args:
        dataset: input torch dataset. If the input dataset is a `torch IterableDataset`_ then the dataloader
            will be created without any distributed sampling. Please, make sure that the dataset itself
            produces different data on different ranks.
        kwargs: keyword arguments for `torch DataLoader`_.

    Returns:
        `torch DataLoader`_ or `XLA MpDeviceLoader`_ for XLA devices

    Examples:
        .. code-block:: python

            import ignite.distributed as idist

            train_loader = idist.auto_dataloader(
                train_dataset,
                batch_size=32,
                num_workers=4,
                shuffle=True,
                pin_memory="cuda" in idist.device().type,
                drop_last=True,
            )

    .. _torch DataLoader: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
    .. _XLA MpDeviceLoader: https://pytorch.org/xla/release/2.0/index.html#running-on-multiple-xla-devices-with-multi-processing
    .. _torch DistributedSampler: https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
    .. _torch IterableDataset: https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
    """
    rank = idist.get_rank()
    world_size = idist.get_world_size()

    logger = setup_logger(__name__ + ".auto_dataloader")
    if world_size > 1:
        if "batch_size" in kwargs and kwargs["batch_size"] >= world_size:
            kwargs["batch_size"] //= world_size

        nproc = idist.get_nproc_per_node()
        if "num_workers" in kwargs and kwargs["num_workers"] >= nproc:
            kwargs["num_workers"] = (kwargs["num_workers"] + nproc - 1) // nproc

        if "batch_sampler" not in kwargs:
            if isinstance(dataset, IterableDataset):
                logger.info(
                    "Found iterable dataset, dataloader will be created without any distributed sampling. "
                    "Please, make sure that the dataset itself produces different data on different ranks."
                )
            else:
                sampler: Optional[Union[DistributedProxySampler, DistributedSampler, Sampler]]
                sampler = kwargs.get("sampler", None)
                if isinstance(sampler, DistributedSampler):
                    if sampler.rank != rank:
                        warnings.warn(
                            f"Found distributed sampler with rank={sampler.rank}, but process rank is {rank}"
                        )
                    if sampler.num_replicas != world_size:
                        warnings.warn(
                            f"Found distributed sampler with num_replicas={sampler.num_replicas}, "
                            f"but world size is {world_size}"
                        )
                elif sampler is None:
                    # remove "shuffle" from kwargs if a sampler is used
                    shuffle = kwargs.pop("shuffle", True)
                    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)
                else:
                    sampler = DistributedProxySampler(sampler, num_replicas=world_size, rank=rank)
                kwargs["sampler"] = sampler
        else:
            warnings.warn(
                "Found batch_sampler in provided kwargs. Please, make sure that it is compatible "
                "with distributed configuration"
            )

    if idist.has_xla_support and idist.backend() == idist_xla.XLA_TPU and kwargs.get("pin_memory", False):
        # TODO: How about XLA GPU ?
        warnings.warn(
            "Found incompatible options: xla support and pin_memory args equal True. "
            "Argument `pin_memory=False` will be used to construct data loader."
        )
        kwargs["pin_memory"] = False
    else:
        kwargs["pin_memory"] = kwargs.get("pin_memory", "cuda" in idist.device().type)

    logger.info(f"Use data loader kwargs for dataset '{repr(dataset)[:20].strip()}': \n\t{kwargs}")
    dataloader = DataLoader(dataset, **kwargs)

    if idist.has_xla_support and idist.backend() == idist_xla.XLA_TPU and world_size > 1:
        logger.info("DataLoader is wrapped by `MpDeviceLoader` on XLA")

        mp_device_loader_cls = _MpDeviceLoader
        try:
            from torch_xla.distributed.parallel_loader import MpDeviceLoader

            mp_device_loader_cls = MpDeviceLoader
        except ImportError:
            pass

        mp_dataloader = mp_device_loader_cls(dataloader, idist.device())
        mp_dataloader.sampler = dataloader.sampler  # type: ignore[attr-defined]
        return mp_dataloader

    return dataloader
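
# Illustrative usage sketch, not part of the library API: the function name and the toy dataset/sampler
# below are made up for demonstration. It shows the sampler handling described in the docstring above:
# a plain torch sampler passed to ``auto_dataloader`` is wrapped by ``DistributedProxySampler`` when
# running in a distributed configuration, while a ``DistributedSampler`` would be used as-is.
def _example_auto_dataloader_with_custom_sampler() -> None:
    import torch
    from torch.utils.data import TensorDataset, WeightedRandomSampler

    import ignite.distributed as idist

    dataset = TensorDataset(torch.rand(100, 3), torch.randint(0, 2, (100,)))
    # plain (non-distributed) sampler; auto_dataloader adapts it per rank if world size > 1
    sampler = WeightedRandomSampler(weights=torch.ones(100), num_samples=100)
    loader = idist.auto_dataloader(dataset, batch_size=16, sampler=sampler, num_workers=2)
    # in a non-distributed run this is a regular DataLoader built with the provided sampler
    for _batch in loader:
        break
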
def auto_model(model: nn.Module, sync_bn: bool = False, **kwargs: Any) -> nn.Module:
    """Helper method to adapt provided model for non-distributed and distributed configurations
    (supporting all available backends from :meth:`~ignite.distributed.utils.available_backends()`).

    Internally, we perform the following:

    - send model to current :meth:`~ignite.distributed.utils.device()` if model's parameters are not on the device.
    - wrap the model to `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1.
    - wrap the model to `torch DataParallel`_ if no distributed context is found and more than one CUDA device is
      available.
    - broadcast the initial variable states from rank 0 to all other processes if Horovod distributed framework is used.

    Args:
        model: model to adapt.
        sync_bn: if True, applies `torch convert_sync_batchnorm`_ to the model for native torch
            distributed only. Default, False. Note, if using Nvidia/Apex, batchnorm conversion should be
            applied before calling ``amp.initialize``.
        kwargs: kwargs to model's wrapping class: `torch DistributedDataParallel`_ or `torch DataParallel`_
            if applicable. Please, make sure to use acceptable kwargs for given backend.

    Returns:
        torch.nn.Module

    Examples:
        .. code-block:: python

            import ignite.distributed as idist

            model = idist.auto_model(model)

        In addition, with NVidia/Apex, it can be used in the following way:

        .. code-block:: python

            import ignite.distributed as idist

            model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
            model = idist.auto_model(model)

    .. _torch DistributedDataParallel: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
    .. _torch DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
    .. _torch convert_sync_batchnorm: https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#torch.nn.SyncBatchNorm.convert_sync_batchnorm

    .. versionchanged:: 0.4.2

        - Added Horovod distributed framework.
        - Added ``sync_bn`` argument.

    .. versionchanged:: 0.4.3
        Added kwargs to ``idist.auto_model``.
    """
    logger = setup_logger(__name__ + ".auto_model")

    # Put model's parameters to device if its parameters are not on the device
    device = idist.device()
    if not all([p.device == device for p in model.parameters()]):
        model.to(device)

    # distributed data parallel model
    if idist.get_world_size() > 1:
        bnd = idist.backend()
        if idist.has_native_dist_support and bnd in (idist_native.NCCL, idist_native.GLOO, idist_native.MPI):
            if sync_bn:
                logger.info("Convert batch norm to sync batch norm")
                model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

            if torch.cuda.is_available():
                if "device_ids" in kwargs:
                    raise ValueError(f"Argument kwargs should not contain 'device_ids', but got {kwargs}")

                lrank = idist.get_local_rank()
                logger.info(f"Apply torch DistributedDataParallel on model, device id: {lrank}")
                kwargs["device_ids"] = [lrank]
            else:
                logger.info("Apply torch DistributedDataParallel on model")

            model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
        elif idist.has_hvd_support and bnd == idist_hvd.HOROVOD:
            import horovod.torch as hvd

            logger.info("Broadcast the initial variable states from rank 0 to all other processes")
            hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    # not distributed but multiple GPUs reachable so data parallel model
    elif torch.cuda.device_count() > 1 and "cuda" in idist.device().type:
        logger.info("Apply torch DataParallel on model")
        model = torch.nn.parallel.DataParallel(model, **kwargs)

    return model
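
# Illustrative usage sketch, not part of the library API: shows ``sync_bn`` and an extra wrapper kwarg.
# It assumes a native torch distributed launch (e.g. via torchrun or ``idist.Parallel``);
# ``find_unused_parameters`` is a DistributedDataParallel-only kwarg and, per the docstring above,
# kwargs must be acceptable for whichever wrapper actually applies.
def _example_auto_model() -> None:
    import torch.nn as nn

    import ignite.distributed as idist

    model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
    # BatchNorm layers are converted to SyncBatchNorm only under a native distributed backend
    model = idist.auto_model(model, sync_bn=True, find_unused_parameters=True)
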
def auto_optim(optimizer: Optimizer, **kwargs: Any) -> Optimizer:
    """Helper method to adapt optimizer for non-distributed and distributed configurations (supporting
    all available backends from :meth:`~ignite.distributed.utils.available_backends()`).

    Internally, this method is a no-op for non-distributed and torch native distributed configurations.

    For XLA distributed configuration, we create a new class that inherits from provided optimizer.
    The goal is to override the `step()` method with specific `xm.optimizer_step`_ implementation.

    For Horovod distributed configuration, optimizer is wrapped with Horovod Distributed Optimizer and
    its state is broadcasted from rank 0 to all other processes.

    Args:
        optimizer: input torch optimizer
        kwargs: kwargs to Horovod backend's DistributedOptimizer.

    Returns:
        Optimizer

    Examples:
        .. code-block:: python

            import ignite.distributed as idist

            optimizer = idist.auto_optim(optimizer)

    .. _xm.optimizer_step: https://pytorch.org/xla/release/1.5/index.html#torch_xla.core.xla_model.optimizer_step

    .. versionchanged:: 0.4.2
        Added Horovod distributed optimizer.

    .. versionchanged:: 0.4.7
        Added kwargs to ``idist.auto_optim``.
    """
    bnd = idist.backend()
    if idist.has_xla_support and bnd == idist_xla.XLA_TPU:
        cls = type(optimizer.__class__.__name__, (optimizer.__class__,), dict(_XLADistributedOptimizer.__dict__))
        return cls(optimizer)

    if idist.has_hvd_support and bnd == idist_hvd.HOROVOD:
        import horovod.torch as hvd

        optimizer = hvd.DistributedOptimizer(optimizer, **kwargs)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
        return optimizer

    return optimizer
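
# End-to-end illustrative sketch, not part of the library API: ``auto_dataloader``, ``auto_model`` and
# ``auto_optim`` used together inside a training function launched by ``ignite.distributed.Parallel``.
# Backend "gloo", ``nproc_per_node=2`` and the toy dataset/model are arbitrary example values.
def _example_training_fn(local_rank: int) -> None:
    import torch
    import torch.nn as nn
    from torch.optim import SGD
    from torch.utils.data import TensorDataset

    import ignite.distributed as idist

    dataset = TensorDataset(torch.rand(64, 10), torch.rand(64, 1))
    loader = idist.auto_dataloader(dataset, batch_size=8, shuffle=True)  # batch size is split across ranks
    model = idist.auto_model(nn.Linear(10, 1))  # wrapped by DistributedDataParallel since world size is 2
    optimizer = idist.auto_optim(SGD(model.parameters(), lr=0.01))

    device = idist.device()
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()


def _example_run_with_parallel() -> None:
    import ignite.distributed as idist

    # spawns 2 processes on this node; each one calls _example_training_fn(local_rank)
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(_example_training_fn)
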
class DistributedProxySampler(DistributedSampler):
    """Distributed sampler proxy to adapt user's sampler for distributed data parallelism configuration.

    Code is based on https://github.com/pytorch/pytorch/issues/23430#issuecomment-562350407

    Args:
        sampler: Input torch data sampler.
        num_replicas: Number of processes participating in distributed training.
        rank: Rank of the current process within ``num_replicas``.

    .. note::
        Input sampler is assumed to have a constant size.
    """

    def __init__(self, sampler: Sampler, num_replicas: Optional[int] = None, rank: Optional[int] = None) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError(f"Argument sampler should be instance of torch Sampler, but given: {type(sampler)}")

        if isinstance(sampler, DistributedSampler):
            raise TypeError("Argument sampler must not be a distributed sampler already")

        if not hasattr(sampler, "__len__"):
            raise TypeError("Argument sampler should have length")

        super(DistributedProxySampler, self).__init__(
            sampler, num_replicas=num_replicas, rank=rank, shuffle=False  # type: ignore[arg-type]
        )
        self.sampler = sampler

    def __iter__(self) -> Iterator:
        # deterministically shuffle based on epoch
        torch.manual_seed(self.epoch)

        indices: List = []
        while len(indices) < self.total_size:
            indices += list(self.sampler)

        if len(indices) > self.total_size:
            indices = indices[: self.total_size]

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        if len(indices) != self.num_samples:
            raise RuntimeError(f"{len(indices)} vs {self.num_samples}")

        return iter(indices)

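
# Illustrative usage sketch, not part of the library API: wrapping a plain torch sampler by hand.
# ``num_replicas=2`` and ``rank=0`` are arbitrary example values; ``auto_dataloader`` normally performs
# this wrapping automatically.
def _example_distributed_proxy_sampler() -> None:
    import torch
    from torch.utils.data import WeightedRandomSampler

    base_sampler = WeightedRandomSampler(weights=torch.ones(100), num_samples=100)
    sampler = DistributedProxySampler(base_sampler, num_replicas=2, rank=0)
    sampler.set_epoch(0)  # seeds the deterministic shuffling used in __iter__
    rank0_indices = list(sampler)  # 50 indices: rank 0's share of the base sampler's output
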
if idist.has_xla_support:
    import torch_xla.core.xla_model as xm
    from torch_xla.distributed.parallel_loader import ParallelLoader

    class _MpDeviceLoader:
        # https://github.com/pytorch/xla/pull/2117
        # From pytorch/xla if `torch_xla.distributed.parallel_loader.MpDeviceLoader` is not available
        def __init__(self, loader: Any, device: torch.device, **kwargs: Any) -> None:
            self._loader = loader
            self._device = device
            self._parallel_loader_kwargs = kwargs

        def __iter__(self) -> Iterator:
            parallel_loader = ParallelLoader(self._loader, [self._device], **self._parallel_loader_kwargs)
            return parallel_loader.per_device_loader(self._device)

        def __len__(self) -> int:
            return len(self._loader)

    class _XLADistributedOptimizer(Optimizer):
        def __init__(self, optimizer: Optimizer) -> None:
            super(self.__class__, self).__init__(optimizer.param_groups)  # type: ignore[call-arg]
            self.wrapped_optimizer = optimizer

        def step(self, closure: Any = None) -> Any:
            xm.optimizer_step(self.wrapped_optimizer, barrier=True)