Template Function torch::nn::parallel::data_parallel¶
Defined in File data_parallel.h
Function Documentation¶
-
template<typename ModuleType>
Tensor torch::nn::parallel::data_parallel(ModuleType module, Tensor input, std::optional<std::vector<Device>> devices = std::nullopt, std::optional<Device> output_device = std::nullopt, int64_t dim = 0)¶

Evaluates module(input) in parallel across the given devices. If devices is not supplied, the invocation is parallelized across all available CUDA devices. If output_device is supplied, the final, combined tensor will be placed on this device. If not, it defaults to the first device in devices.

In detail, this method performs the following four distinct steps:
1. Scatter the input to the given devices,
2. Replicate (deep clone) the model on each device,
3. Evaluate each module with its input on its device,
4. Gather the outputs of each replica into a single output tensor, located on the output_device.