Class FunctionalImpl

Class Documentation

class FunctionalImpl : public torch::nn::Cloneable<FunctionalImpl>

Wraps a function in a Module.

The Functional module wraps an arbitrary function or function object in an nn::Module. This is primarily useful inside a Sequential container.

Sequential sequential(
  Linear(3, 4),
  Functional(torch::relu),
  BatchNorm1d(4),
  Functional(torch::elu, /*alpha=*/1));

While a Functional module only accepts a single Tensor as input, the wrapped function may accept further arguments, provided they are bound at construction time. For example, to wrap torch::leaky_relu, which accepts a slope scalar as its second argument, with a particular slope value in a Functional module, you could write:

Functional(torch::leaky_relu, /*slope=*/0.5)

The value of 0.5 is then stored within the Functional object and supplied to the function call at invocation time. Note that such bound values are evaluated eagerly and stored a single time. See the documentation of std::bind for more information on the semantics of argument binding.
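The eager, store-once behaviour described above can be illustrated with std::bind alone. The following is a minimal sketch that uses plain doubles rather than tensors so that it compiles without libtorch; leaky_relu here is a hypothetical stand-in for torch::leaky_relu, and call_after_mutation is an illustrative helper, not part of any library.

```cpp
#include <functional>

// Hypothetical stand-in for torch::leaky_relu, operating on a double
// instead of a Tensor so the sketch has no libtorch dependency.
double leaky_relu(double x, double slope) {
  return x >= 0.0 ? x : slope * x;
}

// Binds `slope` once, then changes the original variable to show that the
// bound callable keeps the value it captured at bind time.
double call_after_mutation() {
  using std::placeholders::_1;
  double slope = 0.5;
  // std::bind copies `slope` here; _1 marks the argument supplied per call.
  std::function<double(double)> f = std::bind(leaky_relu, _1, slope);
  slope = 0.1;     // no effect on f, which holds its own copy of 0.5
  return f(-2.0);  // evaluates leaky_relu(-2.0, 0.5)
}
```

In this sketch the later assignment to slope does not change the result of f, which mirrors how a value bound at construction time is fixed for the lifetime of the wrapper.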

Attention

After passing any bound arguments, the function must accept a single tensor and return a single tensor.

Note that Functional overloads the call operator (operator()) such that you can invoke it with my_func(...).

Public Types

using Function = std::function<Tensor(Tensor)>

Public Functions

explicit FunctionalImpl(Function function)

Constructs a Functional from a function object.

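template<typename SomeFunction, typename ...Args, typename = std::enable_if_t<(sizeof...(Args) > 0)>>
inline explicit FunctionalImpl(SomeFunction original_function, Args&&... args)

Constructs a Functional from a function and one or more additional arguments, which are bound at construction time.

virtual void reset() override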

reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.

virtual void pretty_print(std::ostream &stream) const override

Pretty prints the Functional module into the given stream.

Tensor forward(Tensor input)

Forwards the input tensor to the underlying (bound) function object.

Tensor operator()(Tensor input)

Calls forward(input).
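To make the relationship between the two constructors, forward and operator() concrete, here is a minimal sketch of a wrapper with the same shape. It is not the library's implementation: FunctionalSketch, Value, negate_value and scale_by are hypothetical names, Value is a double standing in for Tensor, and the use of std::bind with the input as the first argument is an assumption based on the std::bind reference in the class description.

```cpp
#include <functional>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for Tensor so the sketch compiles without libtorch.
using Value = double;

class FunctionalSketch {
 public:
  using Function = std::function<Value(Value)>;

  // Wraps a callable that already takes one Value and returns one Value.
  explicit FunctionalSketch(Function function)
      : function_(std::move(function)) {}

  // Selected only when extra arguments are given. They are bound here, once,
  // and the per-call input is assumed to be passed as the first argument.
  template <typename SomeFunction, typename... Args,
            typename = std::enable_if_t<(sizeof...(Args) > 0)>>
  explicit FunctionalSketch(SomeFunction original_function, Args&&... args)
      : function_(std::bind(original_function, std::placeholders::_1,
                            std::forward<Args>(args)...)) {}

  // Forwards the input to the stored (possibly bound) callable.
  Value forward(Value input) { return function_(input); }

  // Makes the wrapper itself callable by delegating to forward().
  Value operator()(Value input) { return forward(input); }

 private:
  Function function_;
};

// Illustrative functions to wrap: one unary, one with an extra argument.
Value negate_value(Value x) { return -x; }
Value scale_by(Value x, Value factor) { return x * factor; }
```

With these definitions, FunctionalSketch(negate_value) uses the first constructor, while FunctionalSketch(scale_by, 3.0) uses the second and stores the factor 3.0 inside the wrapper. The enable_if_t condition keeps the variadic constructor out of overload resolution when no extra arguments are supplied.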

virtual bool is_serializable() const override

Returns whether the Module is serializable.