
Template Class DistributedSampler

Inheritance Relationships

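Base Type

public torch::data::samplers::Sampler< std::vector< size_t > >

Derived Types

public torch::data::samplers::DistributedRandomSampler
public torch::data::samplers::DistributedSequentialSampler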

Class Documentation

template<typename BatchRequest = std::vector<size_t>>
class DistributedSampler : public torch::data::samplers::Sampler<std::vector<size_t>>

A Sampler that selects a subset of indices to sample from and defines a sampling behavior.

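In a distributed setting, this selects a subset of the indices depending on the provided num_replicas and rank parameters. The Sampler decides the local sample count by rounding according to the allow_duplicates parameter. When allow_duplicates is true, size / num_replicas is rounded up, so every replica receives the same number of samples and some indices may be sampled more than once. When it is false, the quotient is rounded down, so any remainder indices are left out.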

Subclassed by torch::data::samplers::DistributedRandomSampler, torch::data::samplers::DistributedSequentialSampler

Public Functions

inline DistributedSampler(size_t size, size_t num_replicas = 1, size_t rank = 0, bool allow_duplicates = true)

inline void set_epoch(size_t epoch)

Set the epoch for the current enumeration.

This can be used to alter the sample selection and shuffling behavior.

inline size_t epoch() const

Protected Functions

inline size_t local_sample_count()

Protected Attributes

size_t size_
size_t num_replicas_
size_t rank_
size_t epoch_ = {0}
bool allow_duplicates_
