Class MNIST

Inheritance Relationships

Base Type

public torch::data::datasets::Dataset<MNIST>

Class Documentation

class MNIST : public torch::data::datasets::Dataset<MNIST>

The MNIST dataset of handwritten digits: 28×28 grayscale images labeled 0 to 9.
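Because `MNIST` derives from `Dataset<MNIST>`, it only has to supply `get()` and `size()`. In libtorch, the base class's default `get_batch()` calls `get()` once per requested index, and the `Self` template argument lets the base name the concrete type (for example when chaining transforms with `map()`). The sketch below illustrates that pattern in dependency-free C++. `DatasetSketch`, `ToyExample`, and `ToyDigits` are hypothetical names; this is a simplification, not the libtorch class hierarchy.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical, simplified stand-in for torch::data::datasets::Dataset<Self>.
// It is NOT the libtorch class; it only mirrors the get()/size() contract.
// `Self` is kept to show the pattern but is unused in this sketch.
template <typename Self, typename ExampleType>
class DatasetSketch {
 public:
  virtual ~DatasetSketch() = default;

  // Subclasses return one example per index, like MNIST::get().
  virtual ExampleType get(std::size_t index) = 0;

  // Subclasses report how many examples they hold, like MNIST::size().
  virtual std::optional<std::size_t> size() const = 0;

  // Generic batching written once in the base, on top of get().
  std::vector<ExampleType> get_batch(const std::vector<std::size_t>& indices) {
    std::vector<ExampleType> batch;
    batch.reserve(indices.size());
    for (std::size_t i : indices) {
      batch.push_back(get(i));
    }
    return batch;
  }
};

// Hypothetical toy example type: an integer "image" and its label.
struct ToyExample {
  int data;
  int target;
};

// Toy subclass following the same shape as `class MNIST : public Dataset<MNIST>`.
class ToyDigits : public DatasetSketch<ToyDigits, ToyExample> {
 public:
  ToyExample get(std::size_t index) override {
    assert(index < 5);  // only five toy examples exist
    return {static_cast<int>(index) * 10, static_cast<int>(index % 10)};
  }
  std::optional<std::size_t> size() const override { return 5; }
};
```

A call such as `ToyDigits{}.get_batch({0, 3})` returns two examples, the second with `data == 30` and `target == 3`, without `ToyDigits` implementing any batching itself.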

Public Types

enum class Mode

Selects which subset of MNIST (training or test) is loaded.

Values:

enumerator kTrain
enumerator kTest
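Each `Mode` value corresponds to one pair of files from the MNIST distribution: the 60,000-example training set is published as `train-images-idx3-ubyte` and `train-labels-idx1-ubyte`, and the 10,000-example test set as `t10k-images-idx3-ubyte` and `t10k-labels-idx1-ubyte` (the names of the decompressed archives on the MNIST site). The hypothetical helper `mnist_files` below maps a mode to those paths. It re-declares `Mode` locally so the sketch compiles on its own; it is not part of the libtorch API.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Local copy of the documented enum so this sketch is self-contained.
enum class Mode { kTrain, kTest };

// Hypothetical helper: the {images, labels} file paths a given mode reads
// from `root`. File names are those of the decompressed MNIST archives.
std::pair<std::string, std::string> mnist_files(const std::string& root, Mode mode) {
  const bool train = (mode == Mode::kTrain);
  // Join root and file name with exactly one '/' separator.
  std::string dir = root;
  if (!dir.empty() && dir.back() != '/') {
    dir += '/';
  }
  const std::string images = train ? "train-images-idx3-ubyte" : "t10k-images-idx3-ubyte";
  const std::string labels = train ? "train-labels-idx1-ubyte" : "t10k-labels-idx1-ubyte";
  return {dir + images, dir + labels};
}
```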

Public Functions

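explicit MNIST(const std::string &root, Mode mode = Mode::kTrain)

Loads the subset of the MNIST dataset selected by mode from the root path.

The supplied root path must contain the decompressed MNIST files available from http://yann.lecun.com/exdb/mnist: train-images-idx3-ubyte and train-labels-idx1-ubyte for Mode::kTrain, or t10k-images-idx3-ubyte and t10k-labels-idx1-ubyte for Mode::kTest.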
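The files under `root` use the IDX format described on the MNIST page: a 32-bit big-endian magic number (2051, i.e. `0x00000803`, for image files; 2049, i.e. `0x00000801`, for label files), then one 32-bit big-endian size per dimension, then the raw bytes. As a sketch of the kind of validation a loader performs before reading pixels, the hypothetical helpers `read_be_u32` and `read_image_header` parse and check an image-file header. They are illustrations only; libtorch's own implementation may differ in its details.

```cpp
#include <cassert>
#include <cstdint>
#include <istream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Reads one 32-bit unsigned integer stored big-endian (MSB first), as in IDX headers.
std::uint32_t read_be_u32(std::istream& in) {
  unsigned char bytes[4];
  if (!in.read(reinterpret_cast<char*>(bytes), 4)) {
    throw std::runtime_error("truncated IDX header");
  }
  return (std::uint32_t{bytes[0]} << 24) | (std::uint32_t{bytes[1]} << 16) |
         (std::uint32_t{bytes[2]} << 8) | std::uint32_t{bytes[3]};
}

// Hypothetical helper: validates an IDX image-file header and returns
// {count, rows, cols}. Throws if the magic number is not 2051.
std::vector<std::uint32_t> read_image_header(std::istream& in) {
  const std::uint32_t magic = read_be_u32(in);
  if (magic != 2051) {
    throw std::runtime_error("unexpected magic number " + std::to_string(magic));
  }
  const std::uint32_t count = read_be_u32(in);
  const std::uint32_t rows = read_be_u32(in);
  const std::uint32_t cols = read_be_u32(in);
  return {count, rows, cols};
}
```

Passing the header of `train-images-idx3-ubyte` should yield `{60000, 28, 28}`; a label file fed to `read_image_header` is rejected because its magic number is 2049.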

virtual Example get(size_t index) override

Returns the Example (an image and its target) at the given index.
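Conceptually, `get(index)` pairs the `index`-th image with the `index`-th target. The sketch below shows that slicing over plain `std::vector` storage, assuming images are stored back to back as 28×28 row-major floats. `DigitExample` and `get_example` are hypothetical and only stand in for libtorch's `Example` and tensor indexing.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical plain-C++ analogue of one MNIST example: a 28x28 image plus its label.
struct DigitExample {
  std::vector<float> pixels;  // 28 * 28 values, row-major
  std::int64_t target;        // class index 0..9
};

constexpr std::size_t kPixelsPerImage = 28 * 28;

// Returns example `index` from images stored back to back in `all_pixels`
// and labels stored in `all_targets`, both in the same order.
DigitExample get_example(const std::vector<float>& all_pixels,
                         const std::vector<std::int64_t>& all_targets,
                         std::size_t index) {
  assert(index < all_targets.size());
  assert(all_pixels.size() == all_targets.size() * kPixelsPerImage);
  const auto first = all_pixels.begin() + static_cast<std::ptrdiff_t>(index * kPixelsPerImage);
  const auto last = first + static_cast<std::ptrdiff_t>(kPixelsPerImage);
  return {std::vector<float>(first, last), all_targets[index]};
}
```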

virtual std::optional<size_t> size() const override

Returns the number of examples in the loaded subset.
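`size()` returns `std::optional<size_t>` because the `Dataset` interface also covers datasets whose length is not known in advance; for MNIST the value is always present (60,000 training or 10,000 test examples). A typical use is working out how many batches one epoch produces. The helper below is a hypothetical sketch of that arithmetic, with an assumed `drop_last` flag that discards a final partial batch.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>

// Hypothetical helper: how many batches one pass over a dataset yields.
// Returns std::nullopt when the dataset size is unknown (size() has no value).
std::optional<std::size_t> batches_per_epoch(std::optional<std::size_t> dataset_size,
                                             std::size_t batch_size,
                                             bool drop_last) {
  assert(batch_size > 0);
  if (!dataset_size) {
    return std::nullopt;
  }
  const std::size_t n = *dataset_size;
  // Count full batches only, or round up so the final partial batch is kept.
  return drop_last ? n / batch_size : (n + batch_size - 1) / batch_size;
}
```

With 60,000 training examples and a batch size of 64, this gives 938 batches, or 937 when the last 32 examples are dropped.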

bool is_train() const noexcept

Returns true if this is the training subset of MNIST.

const Tensor &images() const

Returns all images stacked into a single tensor.
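The sketch below assumes the stacked images tensor is laid out as `[N, 1, 28, 28]` (examples, channels, rows, columns) and holds the raw IDX bytes scaled from 0..255 to floats in [0, 1]. That layout and scaling are assumptions of this sketch, so confirm them with `images().sizes()` and `images().dtype()` in your build. `nchw_offset` and `to_unit_range` are hypothetical helpers showing the flat-index arithmetic and the scaling.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Assumed per-image shape: 1 channel of 28 x 28 pixels.
constexpr std::size_t kChannels = 1, kRows = 28, kCols = 28;

// Flat offset of element (n, c, r, col) in a contiguous [N, C, H, W] buffer.
constexpr std::size_t nchw_offset(std::size_t n, std::size_t c, std::size_t r, std::size_t col) {
  return ((n * kChannels + c) * kRows + r) * kCols + col;
}

// Converts raw IDX bytes (0..255) to floats in [0, 1], preserving order.
std::vector<float> to_unit_range(const std::vector<std::uint8_t>& raw) {
  std::vector<float> out;
  out.reserve(raw.size());
  for (std::uint8_t b : raw) {
    out.push_back(static_cast<float>(b) / 255.0f);
  }
  return out;
}
```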

const Tensor &targets() const

Returns all targets stacked into a single tensor.
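Assuming the targets tensor holds one integer class index (0 to 9) per image, in the same order as `images()`, a quick sanity check on a loaded subset is a per-digit count. `label_histogram` is a hypothetical helper that operates on a plain `std::vector<std::int64_t>` rather than a tensor.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: counts how many examples carry each digit label.
// Assumes every target is a class index in [0, 9].
std::array<std::size_t, 10> label_histogram(const std::vector<std::int64_t>& targets) {
  std::array<std::size_t, 10> counts{};  // zero-initialized
  for (std::int64_t t : targets) {
    assert(t >= 0 && t < 10);
    ++counts[static_cast<std::size_t>(t)];
  }
  return counts;
}
```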