Class FunctionalImpl
Defined in File functional.h
Inheritance Relationships
Base Type
public torch::nn::Cloneable<FunctionalImpl> (Template Class Cloneable)
Class Documentation
class FunctionalImpl : public torch::nn::Cloneable<FunctionalImpl>
Wraps a function in a Module.
The Functional module allows wrapping an arbitrary function or function object in an nn::Module. This is primarily handy for usage in Sequential:

    Sequential sequential(
        Linear(3, 4),
        Functional(torch::relu),
        BatchNorm1d(4),
        Functional(torch::elu, /*alpha=*/1, /*scale=*/1, /*input_scale=*/1));
While a Functional module only accepts a single Tensor as input, the wrapped function may take further arguments. However, these have to be bound at construction time. For example, to wrap torch::leaky_relu, which takes a negative-slope scalar as its second argument, with a fixed slope in a Functional module, you could write:

    Functional(torch::leaky_relu, /*slope=*/0.5)
The value 0.5 is then stored within the Functional object and supplied to the function call at invocation time. Note that such bound values are evaluated eagerly and stored a single time. See the documentation of std::bind for more information on the semantics of argument binding.

Attention
After passing any bound arguments, the function must accept a single tensor and return a single tensor.
Note that Functional overloads the call operator (operator()) so that you can invoke it with my_func(...).

Public Types
using Function = std::function<Tensor(Tensor)>
Public Functions
explicit FunctionalImpl(Function function)
Constructs a Functional from a function object.
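template<typename SomeFunction, typename ...Args, typename = std::enable_if_t<(sizeof...(Args) > 0)>>
inline explicit FunctionalImpl(SomeFunction original_function, Args&&... args)
Constructs a Functional from a function object and one or more additional arguments, which are bound after the input tensor in the manner of std::bind.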
virtual void reset() override
reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.
virtual void pretty_print(std::ostream &stream) const override
Pretty prints the Functional module into the given stream.
Tensor forward(Tensor input)
Forwards the input tensor to the underlying (bound) function object.
Tensor operator()(Tensor input)
Calls forward(input).