Shortcuts

Program Listing for File linear.h

Return to documentation for file (torch/csrc/api/include/torch/nn/functional/linear.h)

#pragma once

#include <torch/types.h>

namespace torch {
namespace nn {
namespace functional {

/// Functional form of a bilinear layer.
///
/// Passes `input1`, `input2`, `weight` and `bias` straight through to
/// `torch::bilinear` and returns its result unchanged. No extra validation
/// or reshaping happens here.
///
/// @param input1 First input tensor.
/// @param input2 Second input tensor.
/// @param weight Bilinear weight tensor.
/// @param bias   Optional bias. A default-constructed (undefined) Tensor is
///               forwarded as-is to `torch::bilinear`.
/// @return The tensor produced by `torch::bilinear`.
inline Tensor bilinear(
    const Tensor& input1,
    const Tensor& input2,
    const Tensor& weight,
    const Tensor& bias = Tensor()) {
  Tensor result = torch::bilinear(input1, input2, weight, bias);
  return result;
}

// ============================================================================

/// Functional form of a linear (fully connected) layer.
///
/// Computes `input.matmul(weight.t())` and, when `bias` is defined, adds
/// `bias` to the product. `weight` is used transposed. Presumably it is laid
/// out as (out_features, in_features), matching `nn::Linear`. TODO confirm
/// against callers.
///
/// @param input  Input tensor. A 2-D input takes the fused `addmm` path when
///               a bias is present.
/// @param weight Weight tensor, transposed before the multiply.
/// @param bias   Optional bias. An undefined (default) Tensor means no bias.
/// @return The affine transform of `input`.
inline Tensor linear(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias = {}) {
  const bool has_bias = bias.defined();
  const Tensor weight_t = weight.t();

  // 2-D input with a bias: the fused addmm kernel is marginally faster than
  // a separate matmul followed by an add.
  if (has_bias && input.dim() == 2) {
    return torch::addmm(bias, input, weight_t);
  }

  // General path. The product is a fresh tensor, so adding the bias in place
  // does not touch any caller-owned data.
  Tensor product = input.matmul(weight_t);
  if (has_bias) {
    product += bias;
  }
  return product;
}

} // namespace functional
} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources