Class AnyModule

Class Documentation

class torch::nn::AnyModule

Stores a type-erased Module.

The PyTorch C++ API does not impose an interface on the signature of forward() in Module subclasses. This gives you complete freedom to design your forward() methods to your liking. However, this also means there is no unified base type you could store in order to call forward() polymorphically for any module. This is where the AnyModule comes in. Instead of inheritance, it relies on type erasure for polymorphism.
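To make the idea of type erasure concrete, here is a minimal standard-library sketch. It is not how AnyModule is implemented: MiniAny, Doubler, Greeter and demo_type_erasure are hypothetical names, and the sketch only handles non-overloaded, non-const forward() methods that take a single argument. It shows how two unrelated types with different forward() signatures can sit in one container without a common base class.

```cpp
#include <any>
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical miniature of type erasure; NOT the libtorch implementation.
class MiniAny {
 public:
  template <typename M>
  explicit MiniAny(M module) {
    bind(std::make_shared<M>(std::move(module)), &M::forward);
  }

  // The caller names the return type; the argument type is checked at runtime.
  template <typename Ret, typename Arg>
  Ret forward(Arg argument) {
    std::any result = invoke_(std::any(std::move(argument)));
    return std::any_cast<Ret>(result);
  }

 private:
  // Captures the concrete module and its forward() in a closure whose
  // signature no longer mentions the module type.
  template <typename M, typename Ret, typename Arg>
  void bind(std::shared_ptr<M> module, Ret (M::*method)(Arg)) {
    invoke_ = [module, method](std::any argument) -> std::any {
      return ((*module).*method)(std::any_cast<Arg>(argument));
    };
  }

  std::function<std::any(std::any)> invoke_;
};

// Two unrelated types with different forward() signatures and no common base.
struct Doubler {
  int forward(int x) { return 2 * x; }
};
struct Greeter {
  std::string forward(std::string name) { return "hello " + name; }
};

inline void demo_type_erasure() {
  std::vector<MiniAny> modules;
  modules.emplace_back(Doubler{});
  modules.emplace_back(Greeter{});
  assert(modules[0].forward<int>(21) == 42);
  assert(modules[1].forward<std::string>(std::string("torch")) == "hello torch");
}
```

The container only ever sees MiniAny. The concrete types survive solely inside the closure created by bind(), which is the essence of the technique.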

An AnyModule can store any nn::Module subclass that provides a forward() method. This forward() may accept any types and return any type. Once stored in an AnyModule, you can invoke the underlying module’s forward() by calling AnyModule::forward() with the arguments you would supply to the stored module (though see one important limitation below). Example:

struct GenericTrainer {
  torch::nn::AnyModule module;

  void train(torch::Tensor input) {
    module.forward(input);
  }
};

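// Wrap each module in an AnyModule explicitly before storing it.
GenericTrainer trainer1{torch::nn::AnyModule(torch::nn::Linear(3, 4))};
GenericTrainer trainer2{torch::nn::AnyModule(torch::nn::Conv2d(3, 4, 2))};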

As AnyModule erases the static type of the stored module (and its forward() method) to achieve polymorphism, type checking of arguments is moved to runtime. That is, passing an argument with an incorrect type to an AnyModule will compile, but throw an exception at runtime:

torch::nn::AnyModule module(torch::nn::Linear(3, 4));
// Linear takes a tensor as input, but we are passing an integer.
// This will compile, but throw a `torch::Error` exception at runtime.
module.forward(123);

Attention

One noteworthy limitation of AnyModule is that its forward() method does not support implicit conversion of argument types. For example, if the stored module’s forward() method accepts a float and you call any_module.forward(3.4) (where 3.4 is a double), this will throw an exception.
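The exact-match behavior can be observed with std::any from the standard library, used here only as an analogy for an erased argument. The helper holds_exactly and the function demo_exact_match are hypothetical names for this sketch and are not part of the PyTorch API.

```cpp
#include <any>
#include <cassert>

// std::any stands in for an erased argument; this is an analogy only.
// Returns true only when the erased value holds exactly type T.
template <typename T>
bool holds_exactly(const std::any& value) {
  return std::any_cast<T>(&value) != nullptr;
}

inline void demo_exact_match() {
  std::any from_double = 3.4;   // a literal without a suffix is a double
  std::any from_float = 3.4f;   // the f suffix makes it a float
  assert(!holds_exactly<float>(from_double));  // no double -> float conversion
  assert(holds_exactly<float>(from_float));
  assert(holds_exactly<double>(from_double));
}
```

For the case described above, the practical consequence is to pass an argument of exactly the expected type, for instance any_module.forward(3.4f) when the stored module's forward() accepts a float.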

The return type of the AnyModule’s forward() method is controlled via the first template argument to AnyModule::forward(). It defaults to torch::Tensor. To change it, you can write any_module.forward<int>(), for example.

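torch::nn::AnyModule module(torch::nn::Linear(3, 4));
// The return type defaults to torch::Tensor.
auto output = module.forward(torch::ones({2, 3}));

// A stored module must derive from torch::nn::Module.
struct IntModule : torch::nn::Module {
  int forward(int x) { return x; }
};
torch::nn::AnyModule int_module(IntModule{});
// Request an int return type explicitly.
int int_output = int_module.forward<int>(5);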

The only other method an AnyModule provides access to on the stored module is clone(). However, you may acquire a handle on the module via .ptr(), which returns a std::shared_ptr<torch::nn::Module>. Further, if you know the concrete type of the stored module, you can get a concrete handle to it using .get<T>(), where T is the concrete module type.

torch::nn::AnyModule module(torch::nn::Linear(3, 4));
std::shared_ptr<torch::nn::Module> ptr = module.ptr();
torch::nn::Linear linear(module.get<torch::nn::Linear>());

Public Functions

AnyModule() = default

A default-constructed AnyModule is in an empty state.

template<typename ModuleType>
AnyModule(std::shared_ptr<ModuleType> module)

Constructs an AnyModule from a shared_ptr to concrete module object.

template<typename ModuleType, typename = torch::detail::enable_if_module_t<ModuleType>>
AnyModule(ModuleType &&module)

Constructs an AnyModule from a concrete module object.

template<typename ModuleType>
AnyModule(const ModuleHolder<ModuleType> &module_holder)

Constructs an AnyModule from a module holder.

AnyModule(AnyModule&&) = default

Move construction and assignment are allowed, and follow the default move behavior of std::unique_ptr.

AnyModule &operator=(AnyModule&&) = default
AnyModule(const AnyModule &other)

Creates a shallow copy of an AnyModule.

AnyModule &operator=(const AnyModule &other)
AnyModule clone(optional<Device> device = nullopt) const

Creates a deep copy of an AnyModule if it contains a module, else an empty AnyModule if it is empty.
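The difference between a shallow copy and clone() can be illustrated with a small standard-library sketch. Counter, Handle and demo_copy_semantics are hypothetical names, and the sketch assumes that a shallow copy shares the underlying object while a deep copy owns an independent one. It models the semantics only and says nothing about how AnyModule stores its module.

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-in for a module with mutable state.
struct Counter {
  int value = 0;
};

// Hypothetical handle that shares its Counter when copied (a shallow copy).
class Handle {
 public:
  Handle() : state_(std::make_shared<Counter>()) {}

  // Deep copy: allocates a separate Counter holding the same value.
  Handle clone() const {
    Handle copy;
    *copy.state_ = *state_;
    return copy;
  }

  Counter& state() { return *state_; }

 private:
  std::shared_ptr<Counter> state_;
};

inline void demo_copy_semantics() {
  Handle original;
  Handle shallow = original;       // shares the same Counter
  Handle deep = original.clone();  // owns an independent Counter
  original.state().value = 7;
  assert(shallow.state().value == 7);  // the change is visible through the shallow copy
  assert(deep.state().value == 0);     // the clone is unaffected
}
```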

template<typename ModuleType>
AnyModule &operator=(std::shared_ptr<ModuleType> module)

Assigns a module to the AnyModule (to circumvent the explicit constructor).

template<typename ...ArgumentTypes>
AnyValue any_forward(ArgumentTypes&&... arguments)

Invokes forward() on the contained module with the given arguments, and returns the return value as an AnyValue.

Use this method when chaining AnyModules in a loop.
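The chaining pattern can be sketched with std::any: each stage receives the previous stage's erased result, and only the final value is cast to a concrete type. Stage, run_chain, example_stages and demo_chain are hypothetical names for this illustration, and std::any stands in for AnyValue. This is an analogy, not the library's code.

```cpp
#include <any>
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical erased stage: takes an erased value, returns an erased value.
using Stage = std::function<std::any(std::any)>;

// Feeds each stage's erased result into the next stage, then casts once.
template <typename Ret>
Ret run_chain(const std::vector<Stage>& stages, std::any value) {
  for (const Stage& stage : stages) {
    value = stage(std::move(value));
  }
  return std::any_cast<Ret>(value);
}

// Two example stages whose intermediate types differ (int, then std::string).
inline std::vector<Stage> example_stages() {
  return {
      [](std::any x) -> std::any { return std::any_cast<int>(x) + 1; },
      [](std::any x) -> std::any { return std::to_string(std::any_cast<int>(x)); },
  };
}

inline void demo_chain() {
  assert(run_chain<std::string>(example_stages(), std::any(41)) == "42");
}
```

Keeping the intermediate values erased means the loop body does not need to know any stage's return type. Only the caller of run_chain names the final type.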

template<typename ReturnType = torch::Tensor, typename ...ArgumentTypes>
ReturnType forward(ArgumentTypes&&... arguments)

Invokes forward() on the contained module with the given arguments, and casts the returned AnyValue to the supplied ReturnType (which defaults to torch::Tensor).

template<typename T, typename = torch::detail::enable_if_module_t<T>>
T &get()

Attempts to cast the underlying module to the given module type.

Throws an exception if the types do not match.

template<typename T, typename = torch::detail::enable_if_module_t<T>>
const T &get() const

Attempts to cast the underlying module to the given module type.

Throws an exception if the types do not match.

template<typename T, typename ContainedType = typename T::ContainedType>
T get() const

Returns the contained module in a nn::ModuleHolder subclass if possible (i.e. if T has a constructor for the underlying module type).

std::shared_ptr<Module> ptr() const

Returns a std::shared_ptr whose dynamic type is that of the underlying module.

template<typename T, typename = torch::detail::enable_if_module_t<T>>
std::shared_ptr<T> ptr() const

Like ptr(), but casts the pointer to the given type.

const std::type_info &type_info() const

Returns the type_info object of the contained value.

bool is_empty() const noexcept

Returns true if the AnyModule does not contain a module.
