Program Listing for File init.h
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/init.h)
#pragma once
#include <torch/csrc/Export.h>
#include <torch/enum.h>
#include <torch/types.h>
// Weight-initialization entry points of the C++ frontend.
//
// NOTE(review): this header only declares the functions; the behavioural
// summaries below are inferred from the names and the matching upstream
// `torch.nn.init` documentation, so confirm them against the definitions.
namespace torch::nn::init {

/// Closed set of tags accepted as the `nonlinearity` argument of
/// `calculate_gain` and of the `kaiming_*` initializers declared below.
using NonlinearityType = std::variant<
    enumtype::kLinear,
    enumtype::kConv1D,
    enumtype::kConv2D,
    enumtype::kConv3D,
    enumtype::kConvTranspose1D,
    enumtype::kConvTranspose2D,
    enumtype::kConvTranspose3D,
    enumtype::kSigmoid,
    enumtype::kTanh,
    enumtype::kReLU,
    enumtype::kLeakyReLU>;

/// Closed set of tags accepted as the `mode` argument of the `kaiming_*`
/// initializers: either `kFanIn` or `kFanOut`.
using FanModeType = std::variant<enumtype::kFanIn, enumtype::kFanOut>;

/// Gain value associated with `nonlinearity`.
/// `param` defaults to 0.01; presumably the negative slope consulted for
/// `kLeakyReLU` -- TODO confirm against the definition.
TORCH_API double calculate_gain(
    NonlinearityType nonlinearity,
    double param = 0.01);

/// Presumably fills `tensor` with `value` and returns it -- TODO confirm.
TORCH_API Tensor constant_(Tensor tensor, Scalar value);

/// Presumably fills `tensor` with the Dirac delta function and returns it --
/// TODO confirm.
TORCH_API Tensor dirac_(Tensor tensor);

/// Presumably fills `matrix` with an identity matrix and returns it --
/// TODO confirm (including any rank requirement on `matrix`).
TORCH_API Tensor eye_(Tensor matrix);

/// Presumably fills `tensor` with samples from a normal distribution
/// described by `mean` (default 0) and `std` (default 1) -- TODO confirm.
TORCH_API Tensor normal_(Tensor tensor, double mean = 0, double std = 1);

/// Presumably fills `tensor` with ones and returns it -- TODO confirm.
TORCH_API Tensor ones_(Tensor tensor);

/// Presumably fills `tensor` with a (semi-)orthogonal matrix scaled by `gain`
/// (default 1.0) -- TODO confirm.
TORCH_API Tensor orthogonal_(Tensor tensor, double gain = 1.0);

/// Presumably fills `tensor` as a sparse matrix: `sparsity` controls the
/// share of zeroed entries and `std` (default 0.01) the spread of the
/// remaining ones -- TODO confirm.
TORCH_API Tensor sparse_(Tensor tensor, double sparsity, double std = 0.01);

/// Presumably fills `tensor` with samples from a uniform distribution
/// bounded by `low` (default 0) and `high` (default 1) -- TODO confirm.
TORCH_API Tensor uniform_(Tensor tensor, double low = 0, double high = 1);

/// Kaiming (He) initialization, normal-distribution variant (per the name).
/// Defaults: `a` = 0, `mode` = `torch::kFanIn`,
/// `nonlinearity` = `torch::kLeakyReLU`.
/// NOTE(review): `a` is presumably the leaky-ReLU negative slope -- confirm.
TORCH_API Tensor kaiming_normal_(
    Tensor tensor,
    double a = 0,
    FanModeType mode = torch::kFanIn,
    NonlinearityType nonlinearity = torch::kLeakyReLU);

/// Kaiming (He) initialization, uniform-distribution variant (per the name).
/// Takes the same parameters and defaults as `kaiming_normal_`.
TORCH_API Tensor kaiming_uniform_(
    Tensor tensor,
    double a = 0,
    FanModeType mode = torch::kFanIn,
    NonlinearityType nonlinearity = torch::kLeakyReLU);

/// Xavier (Glorot) initialization, normal-distribution variant (per the
/// name); `gain` defaults to 1.0.
TORCH_API Tensor xavier_normal_(Tensor tensor, double gain = 1.0);

/// Xavier (Glorot) initialization, uniform-distribution variant (per the
/// name); `gain` defaults to 1.0.
TORCH_API Tensor xavier_uniform_(Tensor tensor, double gain = 1.0);

/// Presumably fills `tensor` with zeros and returns it -- TODO confirm.
TORCH_API Tensor zeros_(Tensor tensor);

/// Returns a pair of 64-bit integers for `tensor`; by the name, the fan-in
/// followed by the fan-out. Unlike the initializers above, `tensor` is taken
/// by const reference.
TORCH_API std::tuple<int64_t, int64_t> _calculate_fan_in_and_fan_out(
    const Tensor& tensor);

} // namespace torch::nn::init