import math

import torch
from torch.distributed import get_rank, get_world_size
from torch.utils.data.sampler import Sampler


class DistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        if num_replicas is None:
            num_replicas = get_world_size()
        if rank is None:
            rank = get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # each replica draws ceil(N / num_replicas) samples, so the padded
        # total divides evenly across replicas
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = list(torch.randperm(len(self.dataset), generator=g))

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample: each rank takes its own contiguous slice of the
        # shuffled index list
        offset = self.num_samples * self.rank
        indices = indices[offset:offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        # called by the training loop so each epoch uses a different
        # deterministic shuffle
        self.epoch = epoch
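

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes torch.distributed.init_process_group(...) has already been called
# in every participating process; `train_dataset`, `num_epochs`, and
# `batch_size` are hypothetical placeholders.
# ---------------------------------------------------------------------------
def _example_usage(train_dataset, num_epochs=10, batch_size=32):
    from torch.utils.data import DataLoader

    # one sampler per process; num_replicas and rank default to the values
    # of the current process group
    sampler = DistributedSampler(train_dataset)
    loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)

    for epoch in range(num_epochs):
        # re-seed the deterministic shuffle so each epoch sees a new ordering
        sampler.set_epoch(epoch)
        for batch in loader:
            pass  # forward / backward / optimizer step would go here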