torch.library

torch.library is a collection of APIs for extending PyTorch’s core library of operators. It contains utilities for creating new custom operators as well as extending operators defined with PyTorch’s C++ operator registration APIs (e.g. aten operators).

For a detailed guide on effectively using these APIs, please see this gdoc

Creating new custom ops in Python

Use torch.library.custom_op() to create new custom ops.

torch.library.custom_op(name, /, *, mutates_args, device_types=None)

Wraps a function into a custom operator.

Reasons why you may want to create a custom op include:
  • Wrapping a third-party library or custom kernel to work with PyTorch subsystems like Autograd.

  • Preventing torch.compile/export/FX tracing from peeking inside your function.

This API is used as a decorator around a function (please see examples). The provided function must have type hints; these are needed to interface with PyTorch’s various subsystems.

Parameters
  • name (str) – A name for the custom op that looks like “{namespace}::{name}”, e.g. “mylib::my_linear”. The name is used as the op’s stable identifier in PyTorch subsystems (e.g. torch.export, FX graphs). To avoid name collisions, please use your project name as the namespace; e.g. all custom ops in pytorch/fbgemm use “fbgemm” as the namespace.

  • mutates_args (Iterable[str]) – The names of args that the function mutates. This MUST be accurate; otherwise, the behavior is undefined.

  • device_types (None | str | Sequence[str]) – The device type(s) the function is valid for. If no device type is provided, then the function is used as the default implementation for all device types. Examples: “cpu”, “cuda”.

Return type

Callable

Examples
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>> import numpy as np
>>>
>>> @custom_op("mylib::numpy_sin", mutates_args=())
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> x = torch.randn(3)
>>> y = numpy_sin(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example of a custom op that only works for one device type.
>>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu")
>>> def numpy_sin_cpu(x: Tensor) -> Tensor:
>>>     x_np = x.numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np)
>>>
>>> x = torch.randn(3)
>>> y = numpy_sin_cpu(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example of a custom op that mutates an input
>>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
>>> def numpy_sin_inplace(x: Tensor) -> None:
>>>     x_np = x.numpy()
>>>     np.sin(x_np, out=x_np)
>>>
>>> x = torch.randn(3)
>>> expected = x.sin()
>>> numpy_sin_inplace(x)
>>> assert torch.allclose(x, expected)

Extending custom ops (created from Python or C++)

Use the impl methods, such as torch.library.impl() and torch.library.register_fake() (formerly torch.library.impl_abstract()), to add implementations for any operators (they may have been created using torch.library.custom_op() or via PyTorch’s C++ operator registration APIs).

torch.library.impl(qualname, types, func=None, *, lib=None)
torch.library.impl(lib, name, dispatch_key='')

Register an implementation for a device type for this operator.

You may pass “default” for types to register this implementation as the default implementation for ALL device types. Please only use this if the implementation truly supports all device types; for example, this is true if it is a composition of built-in PyTorch operators.

Some valid types are: “cpu”, “cuda”, “xla”, “mps”, “ipu”, “xpu”.

Parameters
  • qualname (str) – Should be a string that looks like “namespace::operator_name”.

  • types (str | Sequence[str]) – The device types to register an impl to.

  • lib (Optional[Library]) – If provided, the lifetime of this registration will be tied to the lifetime of the Library object.

Examples

>>> import torch
>>> import numpy as np
>>>
>>> # Define the operator
>>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the cpu device
>>> @torch.library.impl("mylib::mysin", "cpu")
>>> def f(x):
>>>     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> x = torch.randn(3)
>>> y = torch.ops.mylib.mysin(x)
>>> assert torch.allclose(y, x.sin())

torch.library.register_autograd(op, setup_context_fn, backward_fn, /, *, lib=None)

Register a backward formula for this custom op.

In order for an operator to work with autograd, you need to register a backward formula:
  1. You must tell us how to compute gradients during the backward pass by providing us a “backward” function.

  2. If you need any values from the forward to compute gradients, you can use setup_context to save values for backward.

backward_fn runs during the backward pass. It accepts (ctx, *grads):
  • grads is one or more gradients. The number of gradients matches the number of outputs of the operator.

  • ctx is the same ctx object used by torch.autograd.Function.

The semantics of backward_fn are the same as torch.autograd.Function.backward().

setup_context_fn(ctx, inputs, output) runs during the forward pass. Please save quantities needed for backward onto the ctx object via either torch.autograd.function.FunctionCtx.save_for_backward() or assigning them as attributes of ctx.
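
Because ctx is the same object used by torch.autograd.Function, the two ways of saving values can be sketched with a plain torch.autograd.Function. This is not a call to register_autograd itself; the class PowInt and its arguments are hypothetical and only illustrate the ctx contract: tensors go through save_for_backward(), non-tensor values are stored as attributes, and backward returns one gradient per input.

```python
import torch


class PowInt(torch.autograd.Function):
    """Hypothetical function that raises a tensor to a Python int power."""

    @staticmethod
    def forward(x, exponent):
        return x**exponent

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Runs during the forward pass; `inputs` holds the forward arguments.
        x, exponent = inputs
        ctx.save_for_backward(x)  # tensors are saved via save_for_backward
        ctx.exponent = exponent  # non-tensor values become ctx attributes

    @staticmethod
    def backward(ctx, grad):
        (x,) = ctx.saved_tensors
        # One gradient per forward input; the int exponent gets None.
        return grad * ctx.exponent * x ** (ctx.exponent - 1), None


x = torch.randn(4, dtype=torch.double, requires_grad=True)
y = PowInt.apply(x, 3)
(grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
```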

Both setup_context_fn and backward_fn must be traceable. That is, they may not directly access torch.Tensor.data_ptr() and they must not depend on or mutate global state. If you need a non-traceable backward, you can make it a separate custom_op that you call inside backward_fn.

Examples

>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> def setup_context(ctx, inputs, output) -> None:
>>>     x, = inputs
>>>     ctx.save_for_backward(x)
>>>
>>> def backward(ctx, grad):
>>>     x, = ctx.saved_tensors
>>>     return grad * x.cos()
>>>
>>> torch.library.register_autograd("mylib::numpy_sin", setup_context, backward)
>>>
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_sin(x)
>>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, x.cos())

torch.library.register_fake(op, func=None, /, *, lib=None, _stacklevel=1)

Register a FakeTensor implementation (“fake impl”) for this operator.

Also sometimes known as a “meta kernel” or “abstract impl”.

A “FakeTensor implementation” specifies the behavior of this operator on Tensors that carry no data (“FakeTensor”). Given some input Tensors with certain properties (sizes/strides/storage_offset/device), it specifies what the properties of the output Tensors are.

The FakeTensor implementation has the same signature as the operator. It is run for both FakeTensors and meta tensors. To write a FakeTensor implementation, assume that all Tensor inputs to the operator are regular CPU/CUDA/Meta tensors, but they do not have storage, and you are trying to return regular CPU/CUDA/Meta tensor(s) as output. The FakeTensor implementation must consist of only PyTorch operations (and may not directly access the storage or data of any input or intermediate Tensors).

This API may be used as a decorator (see examples).

For a detailed guide on custom ops, please see https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ/edit

Examples

>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> # Example 1: an operator without data-dependent output shape
>>> @torch.library.custom_op("mylib::custom_linear", mutates_args=())
>>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
>>>     raise NotImplementedError("Implementation goes here")
>>>
>>> @torch.library.register_fake("mylib::custom_linear")
>>> def _(x, weight, bias):
>>>     assert x.dim() == 2
>>>     assert weight.dim() == 2
>>>     assert bias.dim() == 1
>>>     assert x.shape[1] == weight.shape[1]
>>>     assert weight.shape[0] == bias.shape[0]
>>>     assert x.device == weight.device
>>>
>>>     return (x @ weight.t()) + bias
>>>
>>> with torch._subclasses.fake_tensor.FakeTensorMode():
>>>     x = torch.randn(2, 3)
>>>     w = torch.randn(3, 3)
>>>     b = torch.randn(3)
>>>     y = torch.ops.mylib.custom_linear(x, w, b)
>>>
>>> assert y.shape == (2, 3)
>>>
>>> # Example 2: an operator with data-dependent output shape
>>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=())
>>> def custom_nonzero(x: Tensor) -> Tensor:
>>>     x_np = x.numpy(force=True)
>>>     res = np.stack(np.nonzero(x_np), axis=1)
>>>     return torch.tensor(res, device=x.device)
>>>
>>> @torch.library.register_fake("mylib::custom_nonzero")
>>> def _(x):
>>>     # Number of nonzero-elements is data-dependent.
>>>     # Since we cannot peek at the data in a fake impl,
>>>     # we use the ctx object to construct a new symint that
>>>     # represents the data-dependent size.
>>>     ctx = torch.library.get_ctx()
>>>     nnz = ctx.new_dynamic_size()
>>>     shape = [nnz, x.dim()]
>>>     result = x.new_empty(shape, dtype=torch.int64)
>>>     return result
>>>
>>> from torch.fx.experimental.proxy_tensor import make_fx
>>>
>>> x = torch.tensor([0, 1, 2, 3, 4, 0])
>>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x)
>>> trace.print_readable()
>>>
>>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))

torch.library.impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1)

This API was renamed to torch.library.register_fake() in PyTorch 2.4. Please use that instead.

torch.library.get_ctx()

get_ctx() returns the current AbstractImplCtx object.

Calling get_ctx() is only valid inside of a fake impl (see torch.library.register_fake() for more usage details).

Return type

AbstractImplCtx

Low-level APIs

The following APIs are direct bindings to PyTorch’s C++ low-level operator registration APIs.

Warning

The low-level operator registration APIs and the PyTorch Dispatcher are complicated PyTorch concepts. We recommend you use the higher-level APIs above (which do not require a torch.library.Library object) when possible. This blog post (http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/) is a good starting point to learn about the PyTorch Dispatcher.

A tutorial that walks you through some examples on how to use this API is available on Google Colab.

class torch.library.Library(ns, kind, dispatch_key='')

A class to create libraries that can be used to register new operators or override operators in existing libraries from Python. A user can optionally pass in a dispatch key name if they want to register kernels for only one specific dispatch key.

To create a library to override operators in an existing library (with name ns), set the kind to “IMPL”. To create a new library (with name ns) to register new operators, set the kind to “DEF”. To create a fragment of a possibly existing library to register operators (and bypass the limitation that there is only one library for a given namespace), set the kind to “FRAGMENT”.

Parameters
  • ns – library name

  • kind – one of “DEF”, “IMPL”, or “FRAGMENT”

  • dispatch_key – PyTorch dispatch key (default: “”)

define(schema, alias_analysis='', *, tags=())

Defines a new operator and its semantics in the ns namespace.

Parameters
  • schema – function schema to define a new operator.

  • alias_analysis (optional) – Indicates if the aliasing properties of the operator arguments can be inferred from the schema (default behavior) or not (“CONSERVATIVE”).

  • tags (Tag | Sequence[Tag]) – one or more torch.Tag to apply to this operator. Tagging an operator changes the operator’s behavior under various PyTorch subsystems; please read the docs for the torch.Tag carefully before applying it.

Returns

name of the operator as inferred from the schema.

Example
>>> from torch.library import Library
>>> my_lib = Library("mylib", "DEF")
>>> my_lib.define("sum(Tensor self) -> Tensor")

impl(op_name, fn, dispatch_key='', *, with_keyset=False)

Registers the function implementation for an operator defined in the library.

Parameters
  • op_name – operator name (along with the overload) or OpOverload object.

  • fn – function that’s the operator implementation for the input dispatch key or fallthrough_kernel() to register a fallthrough.

  • dispatch_key – dispatch key that the input function should be registered for. By default, it uses the dispatch key that the library was created with.

Example
>>> from torch.library import Library
>>> my_lib = Library("aten", "IMPL")
>>> def div_cpu(self, other):
>>>     return self * (1 / other)
>>> my_lib.impl("div.Tensor", div_cpu, "CPU")

torch.library.fallthrough_kernel()

A dummy function to pass to Library.impl in order to register a fallthrough.

torch.library.define(qualname, schema, *, lib=None, tags=())
torch.library.define(lib, schema, alias_analysis='')

Defines a new operator.

In PyTorch, defining an op (short for “operator”) is a two-step process:
  • we need to define the op (by providing an operator name and schema)

  • we need to implement behavior for how the operator interacts with various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.

This entrypoint defines the custom operator (the first step); you must then perform the second step by calling various impl_* APIs, like torch.library.impl() or torch.library.register_fake().

Parameters
  • qualname (str) – The qualified name for the operator. Should be a string that looks like “namespace::name”, e.g. “aten::sin”. Operators in PyTorch need a namespace to avoid name collisions; a given operator may only be created once. If you are writing a Python library, we recommend the namespace to be the name of your top-level module.

  • schema (str) – The schema of the operator. E.g. “(Tensor x) -> Tensor” for an op that accepts one Tensor and returns one Tensor. It does not contain the operator name (that is passed in qualname).

  • lib (Optional[Library]) – If provided, the lifetime of this operator will be tied to the lifetime of the Library object.

  • tags (Tag | Sequence[Tag]) – one or more torch.Tag to apply to this operator. Tagging an operator changes the operator’s behavior under various PyTorch subsystems; please read the docs for the torch.Tag carefully before applying it.

Example
>>> import torch
>>> import numpy as np
>>>
>>> # Define the operator
>>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the operator
>>> @torch.library.impl("mylib::sin", "cpu")
>>> def f(x):
>>>     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> # Call the new operator from torch.ops.
>>> x = torch.randn(3)
>>> y = torch.ops.mylib.sin(x)
>>> assert torch.allclose(y, x.sin())
