Class AnyModule
Defined in File any.h
Class Documentation
-
class AnyModule
Stores a type-erased Module.

The PyTorch C++ API does not impose an interface on the signature of forward() in Module subclasses. This gives you complete freedom to design your forward() methods to your liking. However, it also means there is no unified base type you could store in order to call forward() polymorphically for any module. This is where AnyModule comes in: instead of inheritance, it relies on type erasure for polymorphism.

An AnyModule can store any nn::Module subclass that provides a forward() method. This forward() may accept any types and return any type. Once stored in an AnyModule, you can invoke the underlying module's forward() by calling AnyModule::forward() with the arguments you would supply to the stored module (though see one important limitation below). Example:

    struct GenericTrainer {
      torch::nn::AnyModule module;

      void train(torch::Tensor input) {
        module.forward(input);
      }
    };

    // AnyModule's constructors are explicit, so each module is wrapped
    // explicitly; brace-initializing the member directly from a Linear
    // would not compile.
    GenericTrainer trainer1{torch::nn::AnyModule(torch::nn::Linear(3, 4))};
    GenericTrainer trainer2{torch::nn::AnyModule(torch::nn::Conv2d(3, 4, 2))};
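The type-erasure technique described above can be sketched in plain C++17. The following is a minimal, hypothetical model written only for illustration (ToyAnyModule, Placeholder, Holder, Scale and Repeat are invented names, and this is not LibTorch's implementation): a non-template base class provides the single stored type, while a templated holder recovers the concrete forward() signature and unpacks std::any arguments at run time. It assumes each stored type has exactly one non-const, non-overloaded forward() that returns a value.

```cpp
#include <any>
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical toy model of type erasure (not LibTorch's implementation).
class ToyAnyModule {
 public:
  template <typename M>
  explicit ToyAnyModule(M module)
      : content_(std::make_unique<Holder<M, decltype(&M::forward)>>(
            std::move(module))) {}

  // Arguments are erased into std::any; their types are checked at run time.
  template <typename... Args>
  std::any any_forward(Args&&... args) {
    std::vector<std::any> erased{std::any(std::forward<Args>(args))...};
    return content_->forward(erased);
  }

 private:
  // Common, non-template interface: the "unified base type".
  struct Placeholder {
    virtual ~Placeholder() = default;
    virtual std::any forward(std::vector<std::any>& args) = 0;
  };

  template <typename M, typename Fn>
  struct Holder;

  // Recovers the parameter types of M::forward from its member pointer type.
  template <typename M, typename R, typename... Params>
  struct Holder<M, R (M::*)(Params...)> : Placeholder {
    explicit Holder(M m) : module(std::move(m)) {}

    std::any forward(std::vector<std::any>& args) override {
      if (args.size() != sizeof...(Params)) {
        throw std::invalid_argument("wrong number of arguments");
      }
      return call(args, std::index_sequence_for<Params...>{});
    }

    template <std::size_t... I>
    std::any call(std::vector<std::any>& args, std::index_sequence<I...>) {
      // std::any_cast throws std::bad_any_cast unless the stored type matches exactly.
      return module.forward(std::any_cast<std::decay_t<Params>>(args[I])...);
    }

    M module;
  };

  std::unique_ptr<Placeholder> content_;
};

// Two unrelated "modules" with different forward() signatures.
struct Scale {
  double forward(double x) { return 3.0 * x; }
};

struct Repeat {
  std::string forward(std::string s, int n) {
    std::string out;
    for (int i = 0; i < n; ++i) out += s;
    return out;
  }
};
```

Because ToyAnyModule only owns a std::unique_ptr to its base-class content, a std::vector<ToyAnyModule> can hold a Scale and a Repeat side by side, which is the property GenericTrainer relies on.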
As AnyModule erases the static type of the stored module (and its forward() method) to achieve polymorphism, type checking of arguments is moved to runtime. That is, passing an argument with an incorrect type to an AnyModule will compile, but throw an exception at runtime:

    torch::nn::AnyModule module(torch::nn::Linear(3, 4));
    // Linear takes a tensor as input, but we are passing an integer.
    // This will compile, but throw a `torch::Error` exception at runtime.
    module.forward(123);
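The runtime check can be illustrated with std::any standing in for AnyModule's erased arguments. In this hypothetical sketch, checked_argument, erased_forward and throws_at_runtime are invented names and the error text is only illustrative: the stored std::type_info is compared with the expected type when the call happens, so a call with the wrong argument type compiles and then throws.

```cpp
#include <any>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <typeinfo>

// Hypothetical helper: fetch argument `index` as exactly type T, or throw.
template <typename T>
T checked_argument(const std::any& value, std::size_t index) {
  if (value.type() != typeid(T)) {
    throw std::invalid_argument(
        "Expected argument #" + std::to_string(index) + " to be of type " +
        typeid(T).name() + ", but received a value of type " +
        value.type().name());
  }
  return std::any_cast<T>(value);
}

// Erased entry point for a module whose concrete forward() is int(int).
std::any erased_forward(const std::any& arg) {
  const int x = checked_argument<int>(arg, 0);
  return x + 1;
}

// Returns true if calling erased_forward with `arg` throws the type error.
bool throws_at_runtime(const std::any& arg) {
  try {
    erased_forward(arg);
  } catch (const std::invalid_argument&) {
    return true;
  }
  return false;
}
```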
Attention
One noteworthy limitation of AnyModule is that its forward() method does not support implicit conversion of argument types. For example, if the stored module’s forward() method accepts a float and you call any_module.forward(3.4) (where 3.4 is a double), this will throw an exception.
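The limitation follows from storing arguments by their exact type. The sketch below uses std::any as a stand-in for that storage (an illustrative assumption, not LibTorch code; float_cast_fails and exact_type_roundtrip are hypothetical names): the literal 3.4 is recorded as a double, and asking for a float does not convert but throws std::bad_any_cast.

```cpp
#include <any>
#include <cassert>

// The literal 3.4 is a double, so std::any records the type double.
// std::any_cast<float> does not convert; it throws std::bad_any_cast.
bool float_cast_fails() {
  std::any arg = 3.4;
  try {
    (void)std::any_cast<float>(arg);
  } catch (const std::bad_any_cast&) {
    return true;
  }
  return false;
}

// Storing a float in the first place makes the exact-type lookup succeed.
float exact_type_roundtrip() {
  std::any arg = 3.4f;
  return std::any_cast<float>(arg);
}
```

With a real AnyModule whose stored forward() takes a float, passing a float literal such as 3.4f avoids the exception.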
The return type of an AnyModule's forward() method is controlled via the first template argument to AnyModule::forward(). It defaults to torch::Tensor. To change it, you can write any_module.forward<int>(), for example.

    torch::nn::AnyModule tensor_module(torch::nn::Linear(3, 4));
    // ReturnType defaults to torch::Tensor.
    auto tensor_output = tensor_module.forward(torch::ones({2, 3}));

    // Only nn::Module subclasses can be stored in an AnyModule.
    struct IntModule : torch::nn::Module {
      int forward(int x) { return x; }
    };
    torch::nn::AnyModule int_module(IntModule{});
    // Request an int return value explicitly.
    int int_output = int_module.forward<int>(5);
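The way a defaulted first template argument selects the return type can be modeled in a few lines. In this hypothetical sketch (ToyErased is an invented name, not a LibTorch type), any_forward() returns an untyped std::any and forward<ReturnType>() casts it; std::string plays the role that torch::Tensor plays as the default in AnyModule.

```cpp
#include <any>
#include <cassert>
#include <functional>
#include <string>
#include <utility>

// Hypothetical toy: an erased callable plus a typed convenience wrapper.
struct ToyErased {
  std::function<std::any(std::any)> fn;

  // Untyped result, analogous to any_forward().
  std::any any_forward(std::any arg) { return fn(std::move(arg)); }

  // The first template argument selects the return type; std::string is the
  // default here, standing in for torch::Tensor.
  template <typename ReturnType = std::string>
  ReturnType forward(std::any arg) {
    // Throws std::bad_any_cast unless ReturnType is exactly the stored type.
    return std::any_cast<ReturnType>(any_forward(std::move(arg)));
  }
};
```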
The only other method an AnyModule provides access to on the stored module is clone(). However, you may acquire a handle on the module via .ptr(), which returns a shared_ptr<nn::Module>. Further, if you know the concrete type of the stored module, you can get a concrete handle to it using .get<T>(), where T is the concrete module type.

    torch::nn::AnyModule module(torch::nn::Linear(3, 4));
    std::shared_ptr<torch::nn::Module> ptr = module.ptr();
    torch::nn::Linear linear(module.get<torch::nn::Linear>());
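The difference between ptr() and get<T>() can be sketched with ordinary inheritance. ToyModule, ToyLinear, ToyConv and ToyHandle are hypothetical stand-ins, not LibTorch types: ptr() hands back a base-class std::shared_ptr, while get<T>() returns a concrete reference and, in this sketch, throws unless the stored object's dynamic type is exactly T.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <typeinfo>
#include <utility>

// Hypothetical stand-in for nn::Module: a polymorphic base class.
struct ToyModule {
  virtual ~ToyModule() = default;
};
struct ToyLinear : ToyModule {
  int in = 3;
  int out = 4;
};
struct ToyConv : ToyModule {};

class ToyHandle {
 public:
  explicit ToyHandle(std::shared_ptr<ToyModule> module)
      : module_(std::move(module)) {}

  // Like ptr(): a handle typed only as the common base class.
  std::shared_ptr<ToyModule> ptr() const { return module_; }

  // Like get<T>(): a concrete reference, or an exception on a type mismatch.
  template <typename T>
  T& get() const {
    if (!module_) {
      throw std::runtime_error("handle is empty");
    }
    ToyModule& stored = *module_;
    if (typeid(stored) != typeid(T)) {
      throw std::runtime_error("stored module is not of the requested type");
    }
    return static_cast<T&>(stored);
  }

 private:
  std::shared_ptr<ToyModule> module_;
};
```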
Public Functions
Constructs an AnyModule from a shared_ptr to a concrete module object.
-
template<typename ModuleType, typename = torch::detail::enable_if_module_t<ModuleType>>
explicit AnyModule(ModuleType &&module)
Constructs an AnyModule from a concrete module object.
-
template<typename ModuleType>
explicit AnyModule(const ModuleHolder<ModuleType> &module_holder)
Constructs an AnyModule from a module holder.
-
AnyModule(AnyModule&&) = default
Move construction and assignment are allowed and follow the default move behavior of std::unique_ptr.
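The move semantics described above can be illustrated with a wrapper whose only member is a std::unique_ptr. ToyHolder and is_empty() are hypothetical names, and the sketch assumes, as the description suggests, that AnyModule keeps its content behind a std::unique_ptr: the moved-to object takes ownership and the moved-from object is left empty.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical toy whose only member is a std::unique_ptr, so its defaulted
// move operations inherit unique_ptr's ownership-transfer semantics.
// Copying is implicitly deleted because the move operations are declared.
struct ToyHolder {
  std::unique_ptr<int> content;

  ToyHolder() = default;
  explicit ToyHolder(int value) : content(std::make_unique<int>(value)) {}
  ToyHolder(ToyHolder&&) = default;
  ToyHolder& operator=(ToyHolder&&) = default;

  bool is_empty() const { return content == nullptr; }
};
```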
-
inline AnyModule clone(std::optional<Device> device = std::nullopt) const
Creates a deep copy of an AnyModule if it contains a module, else an empty AnyModule if it is empty.
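A deep copy that also handles the empty case can be sketched with a virtual clone() on the type-erased content. Placeholder, Holder and ToyAny are hypothetical names, and the sketch ignores the device argument that AnyModule::clone() accepts: a non-empty ToyAny copies the stored object itself, while an empty one yields another empty ToyAny.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical type-erased content with a virtual deep-copy hook.
struct Placeholder {
  virtual ~Placeholder() = default;
  virtual std::unique_ptr<Placeholder> clone() const = 0;
};

template <typename M>
struct Holder : Placeholder {
  explicit Holder(M m) : module(std::move(m)) {}

  // Copies the stored object itself, not just the pointer to it.
  std::unique_ptr<Placeholder> clone() const override {
    return std::make_unique<Holder>(module);
  }

  M module;
};

struct ToyAny {
  std::unique_ptr<Placeholder> content;

  // Deep copy if non-empty; otherwise another empty ToyAny.
  ToyAny clone() const {
    return ToyAny{content ? content->clone() : nullptr};
  }
};
```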
Assigns a module to the AnyModule (to circumvent the explicit constructor).
-
template<typename ...ArgumentTypes>
AnyValue any_forward(ArgumentTypes&&... arguments)
Invokes forward() on the contained module with the given arguments, and returns the return value as an AnyValue. Use this method when chaining AnyModules in a loop.
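Chaining in a loop works because every erased call produces the same untyped result type, which can be fed straight into the next call. In this hypothetical sketch, std::any stands in for AnyValue and ErasedStage and run_chain are invented names; only the caller casts the final result to a concrete type.

```cpp
#include <any>
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy: each stage is a type-erased callable that takes and
// returns std::any, so stages with different concrete types can be chained.
using ErasedStage = std::function<std::any(std::any)>;

// Feeds each stage's erased output into the next stage without knowing any
// intermediate type; the caller casts only the final result.
std::any run_chain(const std::vector<ErasedStage>& stages, std::any value) {
  for (const auto& stage : stages) {
    value = stage(std::move(value));
  }
  return value;
}
```

The intermediate values change type from int to std::string along the chain, yet the loop itself never names either type.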
-
template<typename ReturnType = torch::Tensor, typename ...ArgumentTypes>
ReturnType forward(ArgumentTypes&&... arguments)
Invokes forward() on the contained module with the given arguments, and casts the returned AnyValue to the supplied ReturnType (which defaults to torch::Tensor).
-
template<typename T, typename = torch::detail::enable_if_module_t<T>>
T &get()
Attempts to cast the underlying module to the given module type. Throws an exception if the types do not match.
-
template<typename T, typename = torch::detail::enable_if_module_t<T>>
const T &get() const
Attempts to cast the underlying module to the given module type. Throws an exception if the types do not match.
-
template<typename T, typename ContainedType = typename T::ContainedType>
T get() const
Returns the contained module in an nn::ModuleHolder subclass if possible (i.e. if T has a constructor for the underlying module type).
-
inline std::shared_ptr<Module> ptr() const
Returns a std::shared_ptr whose dynamic type is that of the underlying module.
Like ptr(), but casts the pointer to the given type.
-
inline const std::type_info &type_info() const
Returns the type_info object of the contained value.