def update_dataloader(dataloader: DataLoader, new_batch_sampler: BatchSampler) -> DataLoader:
    """Helper function to replace the current batch sampler of the dataloader with a new batch sampler.
    Returns a new dataloader with the new batch sampler.

    Args:
        dataloader: input dataloader
        new_batch_sampler: new batch sampler to use

    Returns:
        DataLoader
    """
    params_keys = [k for k in dataloader.__dict__.keys() if not k.startswith("_")]
    for k in ["batch_size", "sampler", "drop_last", "batch_sampler", "dataset_kind"]:
        if k in params_keys:
            params_keys.remove(k)
    params = {k: getattr(dataloader, k) for k in params_keys}
    params["batch_sampler"] = new_batch_sampler
    return type(dataloader)(**params)

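
# Illustrative usage sketch (added for documentation; not part of the original module): swap the
# batch sampler of an existing dataloader while keeping its remaining construction arguments.
# The helper name, dataset and batch sizes below are arbitrary example values.
def _example_update_dataloader() -> None:
    import torch
    from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

    dataset = TensorDataset(torch.arange(100.0))
    dataloader = DataLoader(dataset, batch_size=4)
    new_batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=10, drop_last=False)
    dataloader = update_dataloader(dataloader, new_batch_sampler)
    # batches are now produced by the new sampler: the first batch holds 10 samples
    assert next(iter(dataloader))[0].shape[0] == 10
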
class ReproducibleBatchSampler(BatchSampler):
    """Reproducible batch sampler. This class internally iterates and stores indices of the input batch sampler.
    This helps to start providing data batches from an iteration in a deterministic way.

    Args:
        batch_sampler: batch sampler same as used with `torch.utils.data.DataLoader`.
        start_iteration: optional start iteration.

    Examples:
        Setup dataloader with `ReproducibleBatchSampler` and start providing data batches from an iteration

        .. code-block:: python

            from ignite.engine.deterministic import update_dataloader

            dataloader = update_dataloader(dataloader, ReproducibleBatchSampler(dataloader.batch_sampler))
            # rewind dataloader to a specific iteration:
            dataloader.batch_sampler.start_iteration = start_iteration
    """

    def __init__(self, batch_sampler: BatchSampler, start_iteration: Optional[int] = None):
        if not isinstance(batch_sampler, BatchSampler):
            raise TypeError("Argument batch_sampler should be torch.utils.data.sampler.BatchSampler")

        self.batch_indices: List = []
        self.batch_sampler = batch_sampler
        self.start_iteration = start_iteration
        self.sampler = self.batch_sampler.sampler

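
# Conceptual sketch (added for illustration; not the class's actual methods): the idea described
# in the docstring above is to materialize the wrapped sampler's batch indices once and then
# replay them starting from ``start_iteration``, so a resumed run fetches the same batches.
def _sketch_prefetch_and_replay(batch_sampler: BatchSampler, start_iteration: int = 0) -> Iterator:
    batch_indices = list(batch_sampler)  # prefetch and store all batch indices for this pass
    return iter(batch_indices[start_iteration:])  # replay from the requested iteration
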
def keep_random_state(func: Callable) -> Callable:
    """Helper decorator to keep random state of torch, numpy and random intact while executing a function.
    For more details on usage, please see :ref:`Dataflow synchronization`.

    Args:
        func: function to decorate
    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> None:
        rng_states = _get_rng_states()
        func(*args, **kwargs)
        _set_rng_states(rng_states)

    return wrapper

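
# Illustrative usage sketch (added for documentation): random numbers consumed inside the
# decorated function do not advance the global RNG state seen by the caller. The function
# name below is an arbitrary example.
def _example_keep_random_state() -> None:
    import torch

    @keep_random_state
    def consume_randomness() -> None:
        torch.rand(10)  # advances the torch RNG only inside the wrapped call

    state_before = torch.get_rng_state()
    consume_randomness()
    # the torch RNG state has been restored by the decorator
    assert torch.equal(state_before, torch.get_rng_state())
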
class DeterministicEngine(Engine):
    """Deterministic engine derived from :class:`~ignite.engine.engine.Engine`.

    "Deterministic" run is done by adding additional handlers to synchronize the dataflow and overriding some methods
    of :class:`~ignite.engine.engine.Engine`:

    .. code-block:: python

        for e in range(num_epochs):
            set_seed(seed_offset + e)
            if resume:
                setup_saved_rng_states()
            do_single_epoch_iterations(dataloader)

    If the input data provider is a `DataLoader`, its batch sampler is replaced by
    :class:`~ignite.engine.deterministic.ReproducibleBatchSampler`.

    .. code-block:: python

        for e in range(num_epochs):
            set_seed(seed_offset + e)
            setup_sampling(dataloader)
            if resume:
                setup_saved_rng_states()
            do_single_epoch_iterations(dataloader)

    Internally, `torch.backends.cudnn.deterministic = True` and `torch.backends.cudnn.benchmark = False` are also
    applied.

    For more details about dataflow synchronization, please see :ref:`Dataflow synchronization`.

    .. note::

        This class can produce exactly the same dataflow when resuming the run from an epoch (or more precisely from
        dataflow restart) and using torch `DataLoader` with `num_workers > 1` as data provider.

    Args:
        process_function: A function receiving a handle to the engine and the current batch in each iteration,
            and returning data to be stored in the engine's state.
    """

    def __init__(self, process_function: Callable[[Engine, Any], Any]):
        super(DeterministicEngine, self).__init__(process_function)
        self.state_dict_user_keys.append("rng_states")
        if not hasattr(self.state, "rng_states"):
            setattr(self.state, "rng_states", None)
        # draw the base seed (and, on CUDA, enable deterministic algorithms) when the run starts
        self.add_event_handler(Events.STARTED, self._init_run)
        # re-seed the dataflow whenever it is restarted (dataloader exhausted or epoch terminated early)
        self.add_event_handler(Events.DATALOADER_STOP_ITERATION | Events.TERMINATE_SINGLE_EPOCH, self._setup_seed)

    def _init_run(self) -> None:
        self.state.seed = int(torch.randint(0, int(1e9), (1,)).item())
        if torch.cuda.is_available():
            if hasattr(torch, "use_deterministic_algorithms"):
                torch.use_deterministic_algorithms(True, warn_only=True)
            else:
                torch.backends.cudnn.deterministic = True
                torch.backends.cudnn.benchmark = False

    def _setup_engine(self) -> None:
        if self.state.dataloader is None:
            raise ValueError(
                "Deterministic engine does not support the option of data=None. Please, provide data as iterable"
            )

        self._dataloader_len = self._get_data_length(self.state.dataloader)

        # if input data is a torch DataLoader, we replace its batch sampler by a batch sampler
        # such that its random sampling indices are reproducible by prefetching them before data iteration
        if isinstance(self.state.dataloader, DataLoader):
            # attribute _dataset_kind is introduced since 1.3.0 => before 1.3.0 all datasets are map-like
            can_patch_dataloader = True
            if hasattr(self.state.dataloader, "_dataset_kind"):
                from torch.utils.data.dataloader import _DatasetKind

                _dataloader_kind = self.state.dataloader._dataset_kind
                can_patch_dataloader = _dataloader_kind == _DatasetKind.Map
            if can_patch_dataloader:
                if self._dataloader_len is not None and hasattr(self.state.dataloader.sampler, "epoch"):
                    if self._dataloader_len != self.state.epoch_length:
                        warnings.warn(
                            "When the defined engine's epoch length is different from the input dataloader length, "
                            "distributed sampler indices cannot be set up in a reproducible manner"
                        )

                batch_sampler = self.state.dataloader.batch_sampler
                if not (batch_sampler is None or isinstance(batch_sampler, ReproducibleBatchSampler)):
                    self.state.dataloader = update_dataloader(
                        self.state.dataloader, ReproducibleBatchSampler(batch_sampler)  # type: ignore[arg-type]
                    )

        iteration = self.state.iteration
        self._dataloader_iter = self._from_iteration(iteration)

        # Below we define initial counter value for _run_once_on_dataset to measure a single epoch
        if self.state.epoch_length is not None:
            iteration %= self.state.epoch_length
        self._init_iter = iteration

        # restore rng state if in the middle of the dataflow
        in_the_middle = self.state.iteration % self._dataloader_len > 0 if self._dataloader_len is not None else False
        rng_states = getattr(self.state, "rng_states", None)
        if rng_states is not None and in_the_middle:
            _set_rng_states(rng_states)
            setattr(self.state, "rng_states", None)

    def _from_iteration(self, iteration: int) -> Iterator:
        if self.state.dataloader is None:
            raise RuntimeError(
                "Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
            )
        data = self.state.dataloader
        if isinstance(data, DataLoader):
            try:
                # following is unsafe for IterableDatasets
                iteration %= len(data.batch_sampler)  # type: ignore[arg-type]
                # Synchronize dataflow according to state.iteration
                self._setup_seed()
                if iteration > 0:
                    # batch sampler is ReproducibleBatchSampler
                    data.batch_sampler.start_iteration = iteration  # type: ignore[union-attr]
                return iter(data)
            except TypeError:
                # Probably we can do nothing with a DataLoader built upon an IterableDataset
                pass

        self.logger.info("Resuming from iteration for provided data will fetch data until the required iteration ...")

        if hasattr(data, "__len__"):
            iteration %= len(data)  # type: ignore[arg-type]
        # Synchronize dataflow from the beginning
        self._setup_seed(iteration=0)
        data_iter = iter(data)
        counter = 0
        while counter < iteration:
            try:
                next(data_iter)
                counter += 1
            except StopIteration:
                data_iter = iter(data)

        return data_iter

    def _setup_seed(self, _: Any = None, iter_counter: Optional[int] = None, iteration: Optional[int] = None) -> None:
        if iter_counter is None:
            le = self._dataloader_len if self._dataloader_len is not None else 1
        elif not iter_counter > 0:
            raise ValueError("iter_counter should be a positive value")
        else:
            le = iter_counter
        if iteration is None:
            iteration = self.state.iteration
        manual_seed(self.state.seed + iteration // le)  # type: ignore[operator]
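

# Illustrative usage sketch (added for documentation; the processing function and data below are
# placeholders): DeterministicEngine is used exactly like Engine. The base seed is drawn when the
# run starts, and the dataflow is re-seeded as ``seed + iteration // epoch_length``, so resuming
# from a given iteration replays the same random stream and the same batches.
def _example_deterministic_engine() -> None:
    def process_batch(engine: Engine, batch: Any) -> float:
        return float(sum(batch))  # placeholder processing step

    engine = DeterministicEngine(process_batch)
    engine.run([[1, 2], [3, 4], [5, 6]], max_epochs=2)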