Shortcuts

Program Listing for File linear.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/linear.h)

#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/functional/linear.h>
#include <torch/nn/module.h>
#include <torch/nn/options/linear.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <cstddef>
#include <vector>

namespace torch {
namespace nn {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Identity ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

class TORCH_API IdentityImpl : public Cloneable<IdentityImpl> {
 public:
  void reset() override;

  void pretty_print(std::ostream& stream) const override;

  Tensor forward(const Tensor& input);
};

TORCH_MODULE(Identity);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Linear ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

class TORCH_API LinearImpl : public Cloneable<LinearImpl> {
 public:
  LinearImpl(int64_t in_features, int64_t out_features)
      : LinearImpl(LinearOptions(in_features, out_features)) {}
  explicit LinearImpl(const LinearOptions& options_);

  void reset() override;

  void reset_parameters();

  void pretty_print(std::ostream& stream) const override;

  Tensor forward(const Tensor& input);

  LinearOptions options;

  Tensor weight;

  Tensor bias;
};

TORCH_MODULE(Linear);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Flatten ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

class TORCH_API FlattenImpl : public Cloneable<FlattenImpl> {
 public:
  explicit FlattenImpl(const FlattenOptions& options_ = {});

  void reset() override;

  void pretty_print(std::ostream& stream) const override;

  Tensor forward(const Tensor& input);

  FlattenOptions options;
};

TORCH_MODULE(Flatten);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Unflatten
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

class TORCH_API UnflattenImpl : public Cloneable<UnflattenImpl> {
 public:
  UnflattenImpl(int64_t dim, std::vector<int64_t> sizes)
      : UnflattenImpl(UnflattenOptions(dim, sizes)) {}
  UnflattenImpl(std::string dimname, UnflattenOptions::namedshape_t namedshape)
      : UnflattenImpl(UnflattenOptions(dimname, namedshape)) {}
  explicit UnflattenImpl(UnflattenOptions options_);

  void reset() override;

  void pretty_print(std::ostream& stream) const override;

  Tensor forward(const Tensor& input);

  UnflattenOptions options;
};

TORCH_MODULE(Unflatten);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Bilinear ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

class TORCH_API BilinearImpl : public Cloneable<BilinearImpl> {
 public:
  BilinearImpl(int64_t in1_features, int64_t in2_features, int64_t out_features)
      : BilinearImpl(
            BilinearOptions(in1_features, in2_features, out_features)) {}
  explicit BilinearImpl(const BilinearOptions& options_);

  void reset() override;

  void reset_parameters();

  void pretty_print(std::ostream& stream) const override;

  Tensor forward(const Tensor& input1, const Tensor& input2);

  BilinearOptions options;

  Tensor weight;

  Tensor bias;
};

TORCH_MODULE(Bilinear);

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources