// Program listing for file linear.h
// (torch/csrc/api/include/torch/nn/modules/linear.h)
#pragma once
#include <torch/nn/cloneable.h>
#include <torch/nn/functional/linear.h>
#include <torch/nn/module.h>
#include <torch/nn/options/linear.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>
#include <cstddef>
#include <utility>
#include <vector>
namespace torch {
namespace nn {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Identity ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class behind the `Identity` module.
///
/// Only declarations appear in this header; the member functions are defined
/// outside of it.
class TORCH_API IdentityImpl : public Cloneable<IdentityImpl> {
 public:
  /// Overrides the `reset()` hook inherited from the base class.
  void reset() override;

  /// Pretty-prints this module into `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Runs the module on `input` and returns the resulting tensor.
  /// NOTE(review): going by the module name this presumably hands `input`
  /// back unchanged -- confirm against the out-of-line definition.
  Tensor forward(const Tensor& input);
};

/// Declares the `Identity` module type for `IdentityImpl` through the
/// TORCH_MODULE macro (the macro itself is defined outside this file).
TORCH_MODULE(Identity);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Linear ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class behind the `Linear` module.
///
/// Configuration is carried by `LinearOptions`. Apart from the convenience
/// constructor, every member function is only declared here and defined
/// outside this header.
class TORCH_API LinearImpl : public Cloneable<LinearImpl> {
 public:
  /// Convenience constructor: wraps `in_features` and `out_features` in a
  /// `LinearOptions` object and delegates to the options-based constructor.
  LinearImpl(int64_t in_features, int64_t out_features)
      : LinearImpl(LinearOptions(in_features, out_features)) {}

  /// Builds the module from an explicit options object.
  explicit LinearImpl(const LinearOptions& options_);

  /// Overrides the `reset()` hook inherited from the base class.
  void reset() override;

  /// Re-initializes the module's parameters.
  /// NOTE(review): presumably this touches `weight` and `bias` -- confirm
  /// against the out-of-line definition.
  void reset_parameters();

  /// Pretty-prints this module into `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Runs the module on `input` and returns the resulting tensor.
  Tensor forward(const Tensor& input);

  /// Options associated with this module (presumably those passed to the
  /// constructor -- confirm against the definition).
  LinearOptions options;

  /// Weight tensor of the module.
  Tensor weight;

  /// Bias tensor of the module.
  /// NOTE(review): whether this can be left undefined depends on
  /// `LinearOptions`, which is not visible here.
  Tensor bias;
};

/// Declares the `Linear` module type for `LinearImpl` through the
/// TORCH_MODULE macro (the macro itself is defined outside this file).
TORCH_MODULE(Linear);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Flatten ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class behind the `Flatten` module.
///
/// Configuration is carried by `FlattenOptions`. All member functions are
/// only declared here and defined outside this header.
class TORCH_API FlattenImpl : public Cloneable<FlattenImpl> {
 public:
  /// Builds the module from `options_`; when the argument is omitted a
  /// default-constructed `FlattenOptions` is used.
  explicit FlattenImpl(const FlattenOptions& options_ = {});

  /// Overrides the `reset()` hook inherited from the base class.
  void reset() override;

  /// Pretty-prints this module into `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Runs the module on `input` and returns the resulting tensor.
  /// NOTE(review): which dimensions get collapsed is decided by
  /// `FlattenOptions`, which is not visible here.
  Tensor forward(const Tensor& input);

  /// Options associated with this module (presumably those passed to the
  /// constructor -- confirm against the definition).
  FlattenOptions options;
};

/// Declares the `Flatten` module type for `FlattenImpl` through the
/// TORCH_MODULE macro (the macro itself is defined outside this file).
TORCH_MODULE(Flatten);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Unflatten
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class behind the `Unflatten` module.
///
/// Configuration is carried by `UnflattenOptions`. Two convenience
/// constructors build that options object from either an integer dimension
/// plus sizes, or a dimension name plus a named shape. The remaining member
/// functions are only declared here and defined outside this header.
class TORCH_API UnflattenImpl : public Cloneable<UnflattenImpl> {
 public:
  /// Convenience constructor for the index-based form.
  /// `sizes` is taken by value and moved into the options object so the
  /// vector is not copied a second time.
  UnflattenImpl(int64_t dim, std::vector<int64_t> sizes)
      : UnflattenImpl(UnflattenOptions(dim, std::move(sizes))) {}

  /// Convenience constructor for the name-based form.
  /// Both by-value arguments are moved into the options object rather than
  /// copied again.
  UnflattenImpl(std::string dimname, UnflattenOptions::namedshape_t namedshape)
      : UnflattenImpl(
            UnflattenOptions(std::move(dimname), std::move(namedshape))) {}

  /// Builds the module from an explicit options object.
  explicit UnflattenImpl(UnflattenOptions options_);

  /// Overrides the `reset()` hook inherited from the base class.
  void reset() override;

  /// Pretty-prints this module into `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Runs the module on `input` and returns the resulting tensor.
  /// NOTE(review): the exact reshaping rule is decided by `UnflattenOptions`
  /// and the out-of-line definition, neither of which is visible here.
  Tensor forward(const Tensor& input);

  /// Options associated with this module (presumably those passed to the
  /// constructor -- confirm against the definition).
  UnflattenOptions options;
};

/// Declares the `Unflatten` module type for `UnflattenImpl` through the
/// TORCH_MODULE macro (the macro itself is defined outside this file).
TORCH_MODULE(Unflatten);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Bilinear ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class behind the `Bilinear` module.
///
/// Configuration is carried by `BilinearOptions`. Apart from the convenience
/// constructor, every member function is only declared here and defined
/// outside this header.
class TORCH_API BilinearImpl : public Cloneable<BilinearImpl> {
 public:
  /// Convenience constructor: wraps the three feature counts in a
  /// `BilinearOptions` object and delegates to the options-based constructor.
  BilinearImpl(int64_t in1_features, int64_t in2_features, int64_t out_features)
      : BilinearImpl(
            BilinearOptions(in1_features, in2_features, out_features)) {}

  /// Builds the module from an explicit options object.
  explicit BilinearImpl(const BilinearOptions& options_);

  /// Overrides the `reset()` hook inherited from the base class.
  void reset() override;

  /// Re-initializes the module's parameters.
  /// NOTE(review): presumably this touches `weight` and `bias` -- confirm
  /// against the out-of-line definition.
  void reset_parameters();

  /// Pretty-prints this module into `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Runs the module on the pair (`input1`, `input2`) and returns the
  /// resulting tensor.
  Tensor forward(const Tensor& input1, const Tensor& input2);

  /// Options associated with this module (presumably those passed to the
  /// constructor -- confirm against the definition).
  BilinearOptions options;

  /// Weight tensor of the module.
  Tensor weight;

  /// Bias tensor of the module.
  /// NOTE(review): whether this can be left undefined depends on
  /// `BilinearOptions`, which is not visible here.
  Tensor bias;
};

/// Declares the `Bilinear` module type for `BilinearImpl` through the
/// TORCH_MODULE macro (the macro itself is defined outside this file).
TORCH_MODULE(Bilinear);
} // namespace nn
} // namespace torch