torch.library
torch.library is a collection of APIs for extending PyTorch’s core library of operators. It contains utilities for testing custom operators, creating new custom operators, and extending operators defined with PyTorch’s C++ operator registration APIs (e.g. aten operators).
For a detailed guide on using these APIs effectively, please see the PyTorch Custom Operators Landing Page.
Testing custom ops
Use torch.library.opcheck() to test custom ops for incorrect usage of the Python torch.library and/or C++ TORCH_LIBRARY APIs. Also, if your operator supports training, use torch.autograd.gradcheck() to test that the gradients are mathematically correct.
- torch.library.opcheck(op, args, kwargs=None, *, test_utils=('test_schema', 'test_autograd_registration', 'test_faketensor', 'test_aot_dispatch_dynamic'), raise_exception=True)
Given an operator and some sample arguments, tests if the operator is registered correctly.
That is, when you use the torch.library/TORCH_LIBRARY APIs to create a custom op, you specified metadata (e.g. mutability info) about the custom op, and these APIs require that the functions you pass them satisfy certain properties (e.g. no data pointer access in the fake/meta/abstract kernel). opcheck tests these metadata and properties.
Concretely, we test the following:
test_schema: If the schema matches the implementation of the operator. For example: if the schema specifies a Tensor is mutated, then we check the implementation mutates the Tensor. If the schema specifies that we return a new Tensor, then we check that the implementation returns a new Tensor (instead of an existing one or a view of an existing one).
test_autograd_registration: If the operator supports training (autograd): we check that its autograd formula is registered via torch.library.register_autograd or a manual registration to one or more DispatchKey::Autograd keys. Any other DispatchKey-based registrations may lead to undefined behavior.
test_faketensor: If the operator has a FakeTensor kernel (and if it is correct). The FakeTensor kernel is necessary (but not sufficient) for the operator to work with PyTorch compilation APIs (torch.compile/export/FX). We check that a FakeTensor kernel (also sometimes known as a meta kernel) was registered for the operator and that it is correct. This test takes the result of running the operator on real tensors and the result of running the operator on FakeTensors and checks that they have the same Tensor metadata (sizes/strides/dtype/device/etc).
test_aot_dispatch_dynamic: If the operator has correct behavior with PyTorch compilation APIs (torch.compile/export/FX). This checks that the outputs (and gradients, if applicable) are the same under eager-mode PyTorch and torch.compile. This test is a superset of test_faketensor and is an e2e test; other things it tests are that the operator supports functionalization and that the backward pass (if it exists) also supports FakeTensor and functionalization.
For best results, please call opcheck multiple times with a representative set of inputs. If your operator supports autograd, please use opcheck with inputs with requires_grad = True; if your operator supports multiple devices (e.g. CPU and CUDA), please use opcheck with inputs on all supported devices.
- Parameters
op (Union[OpOverload, OpOverloadPacket, CustomOpDef]) – The operator. Must either be a function decorated with torch.library.custom_op() or an OpOverload/OpOverloadPacket found in torch.ops.* (e.g. torch.ops.aten.sin, torch.ops.mylib.foo)
args (Tuple[Any, ...]) – The args to the operator
kwargs (Optional[Dict[str, Any]]) – The kwargs to the operator
test_utils (Union[str, Sequence[str]]) – Tests that we should run. Default: all of them. Example: (“test_schema”, “test_faketensor”)
raise_exception (bool) – If we should raise an exception on the first error. If False, we will return a dict with information on if each test passed or not.
- Return type
Warning
opcheck and torch.autograd.gradcheck() test different things; opcheck tests if your usage of torch.library APIs is correct while torch.autograd.gradcheck() tests if your autograd formula is mathematically correct. Use both to test custom ops that support gradient computation.
Example
>>> import torch
>>> from torch import Tensor
>>>
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, y: float) -> Tensor:
>>>     x_np = x.numpy(force=True)
>>>     z_np = x_np * y
>>>     return torch.from_numpy(z_np).to(x.device)
>>>
>>> @numpy_mul.register_fake
>>> def _(x, y):
>>>     return torch.empty_like(x)
>>>
>>> def setup_context(ctx, inputs, output):
>>>     # inputs holds every positional input to the op: (x, y)
>>>     _, y = inputs
>>>     ctx.y = y
>>>
>>> def backward(ctx, grad):
>>>     # One gradient per input: d(x * y)/dx = y; y is a float, so None
>>>     return grad * ctx.y, None
>>>
>>> numpy_mul.register_autograd(backward, setup_context=setup_context)
>>>
>>> sample_inputs = [
>>>     (torch.randn(3), 3.14),
>>>     (torch.randn(2, 3, device='cuda'), 2.718),
>>>     (torch.randn(1, 10, requires_grad=True), 1.234),
>>>     (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18),
>>> ]
>>>
>>> for args in sample_inputs:
>>>     torch.library.opcheck(numpy_mul, args)
Creating new custom ops in Python
Use torch.library.custom_op() to create new custom ops.
- torch.library.custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None)
Wraps a function into a custom operator.
Reasons why you may want to create a custom op include:
- Wrapping a third-party library or custom kernel to work with PyTorch subsystems like Autograd.
- Preventing torch.compile/export/FX tracing from peeking inside your function.
This API is used as a decorator around a function (please see examples). The provided function must have type hints; these are needed to interface with PyTorch’s various subsystems.
- Parameters
name (str) – A name for the custom op that looks like “{namespace}::{name}”, e.g. “mylib::my_linear”. The name is used as the op’s stable identifier in PyTorch subsystems (e.g. torch.export, FX graphs). To avoid name collisions, please use your project name as the namespace; e.g. all custom ops in pytorch/fbgemm use “fbgemm” as the namespace.
mutates_args (Iterable[str] or "unknown") – The names of args that the function mutates. This MUST be accurate, otherwise, the behavior is undefined. If “unknown”, it pessimistically assumes that all inputs to the operator are being mutated.
device_types (None | str | Sequence[str]) – The device type(s) the function is valid for. If no device type is provided, then the function is used as the default implementation for all device types. Examples: “cpu”, “cuda”. When registering a device-specific implementation for an operator that accepts no Tensors, we require the operator to have a “device: torch.device” argument.
schema (None | str) – A schema string for the operator. If None (recommended) we’ll infer a schema for the operator from its type annotations. We recommend letting us infer a schema unless you have a specific reason not to. Example: “(Tensor x, int y) -> (Tensor, Tensor)”.
- Return type
Note
We recommend not passing in a schema arg and instead letting us infer it from the type annotations. It is error-prone to write your own schema. You may wish to provide your own schema if our interpretation of the type annotation is not what you want. For more info on how to write a schema string, see here
- Examples::
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>> import numpy as np
>>>
>>> @custom_op("mylib::numpy_sin", mutates_args=())
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> x = torch.randn(3)
>>> y = numpy_sin(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example of a custom op that only works for one device type.
>>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu")
>>> def numpy_sin_cpu(x: Tensor) -> Tensor:
>>>     x_np = x.numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np)
>>>
>>> x = torch.randn(3)
>>> y = numpy_sin_cpu(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example of a custom op that mutates an input
>>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
>>> def numpy_sin_inplace(x: Tensor) -> None:
>>>     x_np = x.numpy()
>>>     np.sin(x_np, out=x_np)
>>>
>>> x = torch.randn(3)
>>> expected = x.sin()
>>> numpy_sin_inplace(x)
>>> assert torch.allclose(x, expected)
>>>
>>> # Example of a factory function
>>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu")
>>> def bar(device: torch.device) -> Tensor:
>>>     return torch.ones(3)
>>>
>>> bar("cpu")
Extending custom ops (created from Python or C++)
Use the register.* methods, such as torch.library.register_kernel() and torch.library.register_fake(), to add implementations for any operators (they may have been created using torch.library.custom_op() or via PyTorch’s C++ operator registration APIs).
- torch.library.register_kernel(op, device_types, func=None, /, *, lib=None)
Register an implementation for a device type for this operator.
Some valid device_types are: “cpu”, “cuda”, “xla”, “mps”, “ipu”, “xpu”. This API may be used as a decorator.
- Parameters
op (str | OpOverload) – The operator to register an impl to.
device_types (None | str | Sequence[str]) – The device_types to register an impl to. If None, we will register to all device types – please only use this option if your implementation is truly device-type-agnostic.
func (Callable) – The function to register as the implementation for the given device types.
lib (Optional[Library]) – If provided, the lifetime of this registration will be tied to the lifetime of the Library object.
- Examples::
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>> import numpy as np
>>>
>>> # Create a custom op that works on cpu
>>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np)
>>>
>>> # Add implementations for the cuda device
>>> @torch.library.register_kernel("mylib::numpy_sin", "cuda")
>>> def _(x):
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> x_cpu = torch.randn(3)
>>> x_cuda = x_cpu.cuda()
>>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
>>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
- torch.library.register_autograd(op, backward, /, *, setup_context=None, lib=None)
Register a backward formula for this custom op.
In order for an operator to work with autograd, you need to register a backward formula:
1. You must tell us how to compute gradients during the backward pass by providing us a “backward” function.
2. If you need any values from the forward to compute gradients, you can use setup_context to save values for backward.
backward runs during the backward pass. It accepts (ctx, *grads):
- grads is one or more gradients. The number of gradients matches the number of outputs of the operator.
The ctx object is the same ctx object used by torch.autograd.Function. The semantics of backward_fn are the same as torch.autograd.Function.backward().
setup_context(ctx, inputs, output) runs during the forward pass. Please save quantities needed for backward onto the ctx object via either torch.autograd.function.FunctionCtx.save_for_backward() or assigning them as attributes of ctx. If your custom op has kwarg-only arguments, we expect the signature of setup_context to be setup_context(ctx, inputs, keyword_only_inputs, output).
Both setup_context_fn and backward_fn must be traceable. That is, they may not directly access torch.Tensor.data_ptr() and they must not depend on or mutate global state. If you need a non-traceable backward, you can make it a separate custom_op that you call inside backward_fn.
Examples
>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> def setup_context(ctx, inputs, output):
>>>     x, = inputs
>>>     ctx.save_for_backward(x)
>>>
>>> def backward(ctx, grad):
>>>     x, = ctx.saved_tensors
>>>     return grad * x.cos()
>>>
>>> torch.library.register_autograd(
...     "mylib::numpy_sin", backward, setup_context=setup_context
... )
>>>
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_sin(x)
>>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, x.cos())
>>>
>>> # Example with a keyword-only arg
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = x_np * val
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> def setup_context(ctx, inputs, keyword_only_inputs, output):
>>>     ctx.val = keyword_only_inputs["val"]
>>>
>>> def backward(ctx, grad):
>>>     return grad * ctx.val
>>>
>>> torch.library.register_autograd(
...     "mylib::numpy_mul", backward, setup_context=setup_context
... )
>>>
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_mul(x, val=3.14)
>>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
- torch.library.register_fake(op, func=None, /, *, lib=None, _stacklevel=1)
Register a FakeTensor implementation (“fake impl”) for this operator.
Also sometimes known as a “meta kernel”, “abstract impl”.
A “FakeTensor implementation” specifies the behavior of this operator on Tensors that carry no data (“FakeTensor”). Given some input Tensors with certain properties (sizes/strides/storage_offset/device), it specifies what the properties of the output Tensors are.
The FakeTensor implementation has the same signature as the operator. It is run for both FakeTensors and meta tensors. To write a FakeTensor implementation, assume that all Tensor inputs to the operator are regular CPU/CUDA/Meta tensors, but they do not have storage, and you are trying to return regular CPU/CUDA/Meta tensor(s) as output. The FakeTensor implementation must consist of only PyTorch operations (and may not directly access the storage or data of any input or intermediate Tensors).
This API may be used as a decorator (see examples).
For a detailed guide on custom ops, please see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
Examples
>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> # Example 1: an operator without data-dependent output shape
>>> @torch.library.custom_op("mylib::custom_linear", mutates_args=())
>>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
>>>     raise NotImplementedError("Implementation goes here")
>>>
>>> @torch.library.register_fake("mylib::custom_linear")
>>> def _(x, weight, bias):
>>>     assert x.dim() == 2
>>>     assert weight.dim() == 2
>>>     assert bias.dim() == 1
>>>     assert x.shape[1] == weight.shape[1]
>>>     assert weight.shape[0] == bias.shape[0]
>>>     assert x.device == weight.device
>>>
>>>     return (x @ weight.t()) + bias
>>>
>>> with torch._subclasses.fake_tensor.FakeTensorMode():
>>>     x = torch.randn(2, 3)
>>>     w = torch.randn(3, 3)
>>>     b = torch.randn(3)
>>>     y = torch.ops.mylib.custom_linear(x, w, b)
>>>
>>> assert y.shape == (2, 3)
>>>
>>> # Example 2: an operator with data-dependent output shape
>>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=())
>>> def custom_nonzero(x: Tensor) -> Tensor:
>>>     x_np = x.numpy(force=True)
>>>     res = np.stack(np.nonzero(x_np), axis=1)
>>>     return torch.tensor(res, device=x.device)
>>>
>>> @torch.library.register_fake("mylib::custom_nonzero")
>>> def _(x):
>>>     # Number of nonzero-elements is data-dependent.
>>>     # Since we cannot peek at the data in a fake impl,
>>>     # we use the ctx object to construct a new symint that
>>>     # represents the data-dependent size.
>>>     ctx = torch.library.get_ctx()
>>>     nnz = ctx.new_dynamic_size()
>>>     shape = [nnz, x.dim()]
>>>     result = x.new_empty(shape, dtype=torch.int64)
>>>     return result
>>>
>>> from torch.fx.experimental.proxy_tensor import make_fx
>>>
>>> x = torch.tensor([0, 1, 2, 3, 4, 0])
>>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x)
>>> trace.print_readable()
>>>
>>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))
- torch.library.register_vmap(op, func=None, /, *, lib=None)
Register a vmap implementation to support torch.vmap() for this custom op.
This API may be used as a decorator (see examples).
In order for an operator to work with torch.vmap(), you may need to register a vmap implementation with the following signature:
vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs),
where *args and **kwargs are the arguments and kwargs for op. We do not support kwarg-only Tensor args.
It specifies how we compute the batched version of op given inputs with an additional dimension (specified by in_dims).
For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over; otherwise, it is an integer specifying what dimension of the Tensor is being vmapped over.
info is a collection of additional metadata that may be helpful: info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option that was passed to torch.vmap().
The return of the function func is a tuple of (output, out_dims). Similar to in_dims, out_dims should be of the same structure as output and contain one out_dim per output that specifies if the output has the vmapped dimension and what index it is in.
Examples
>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>> from typing import Tuple
>>>
>>> def to_numpy(tensor):
>>>     return tensor.cpu().numpy()
>>>
>>> lib = torch.library.Library("mylib", "FRAGMENT")
>>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
>>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
>>>     x_np = to_numpy(x)
>>>     dx = torch.tensor(3 * x_np ** 2, device=x.device)
>>>     return torch.tensor(x_np ** 3, device=x.device), dx
>>>
>>> def numpy_cube_vmap(info, in_dims, x):
>>>     result = numpy_cube(x)
>>>     return result, (in_dims[0], in_dims[0])
>>>
>>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap)
>>>
>>> x = torch.randn(3)
>>> torch.vmap(numpy_cube)(x)
>>>
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
>>>     return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
>>>
>>> @torch.library.register_vmap("mylib::numpy_mul")
>>> def numpy_mul_vmap(info, in_dims, x, y):
>>>     x_bdim, y_bdim = in_dims
>>>     x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
>>>     y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
>>>     result = x * y
>>>     result = result.movedim(-1, 0)
>>>     return result, 0
>>>
>>>
>>> x = torch.randn(3)
>>> y = torch.randn(3)
>>> torch.vmap(numpy_mul)(x, y)
Note
The vmap function should aim to preserve the semantics of the entire custom operator. That is, grad(vmap(op)) should be replaceable with a grad(map(op)).
If your custom operator has any custom behavior in the backward pass, please keep this in mind.
- torch.library.impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1)
This API was renamed to torch.library.register_fake() in PyTorch 2.4. Please use that instead.
- torch.library.get_ctx()
get_ctx() returns the current AbstractImplCtx object.
Calling get_ctx() is only valid inside of a fake impl (see torch.library.register_fake() for more usage details).
- Return type
FakeImplCtx
- torch.library.register_torch_dispatch(op, torch_dispatch_class, func=None, /, *, lib=None)
Registers a torch_dispatch rule for the given operator and torch_dispatch_class.
This allows for open registration to specify the behavior between the operator and the torch_dispatch_class without needing to modify the torch_dispatch_class or the operator directly.
The torch_dispatch_class is either a Tensor subclass with __torch_dispatch__ or a TorchDispatchMode.
If it is a Tensor subclass, we expect func to have the following signature: (cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
If it is a TorchDispatchMode, we expect func to have the following signature: (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
args and kwargs will have been normalized the same way they are in __torch_dispatch__ (see __torch_dispatch__ calling convention).
Examples
>>> import torch
>>>
>>> @torch.library.custom_op("mylib::foo", mutates_args={})
>>> def foo(x: torch.Tensor) -> torch.Tensor:
>>>     return x.clone()
>>>
>>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
>>>     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
>>>         return func(*args, **kwargs)
>>>
>>> @torch.library.register_torch_dispatch("mylib::foo", MyMode)
>>> def _(mode, func, types, args, kwargs):
>>>     x, = args
>>>     return x + 1
>>>
>>> x = torch.randn(3)
>>> y = foo(x)
>>> assert torch.allclose(y, x)
>>>
>>> with MyMode():
>>>     y = foo(x)
>>> assert torch.allclose(y, x + 1)
- torch.library.infer_schema(prototype_function, /, *, mutates_args, op_name=None)
Parses the schema of a given function with type hints. The schema is inferred from the function’s type hints, and can be used to define a new operator.
We make the following assumptions:
- None of the outputs alias any of the inputs or each other.
- String type annotations “device, dtype, Tensor, types” without library specification are assumed to be torch.*. Similarly, string type annotations “Optional, List, Sequence, Union” without library specification are assumed to be typing.*.
- Only the args listed in mutates_args are being mutated. If mutates_args is “unknown”, it assumes that all inputs to the operator are being mutated.
Callers (e.g. the custom ops API) are responsible for checking these assumptions.
- Parameters
prototype_function (Callable) – The function whose type annotations are used to infer the schema.
op_name (Optional[str]) – The name of the operator in the schema. If op_name is None, then the name is not included in the inferred schema. Note that the input schema to torch.library.Library.define requires an operator name.
mutates_args ("unknown" | Iterable[str]) – The arguments that are mutated in the function.
- Returns
The inferred schema.
- Return type
Example
>>> import torch
>>> from torch.library import infer_schema
>>>
>>> def foo_impl(x: torch.Tensor) -> torch.Tensor:
>>>     return x.sin()
>>>
>>> infer_schema(foo_impl, op_name="foo", mutates_args={})
foo(Tensor x) -> Tensor
>>>
>>> infer_schema(foo_impl, mutates_args={})
(Tensor x) -> Tensor
- class torch._library.custom_ops.CustomOpDef(namespace, name, schema, fn)
CustomOpDef is a wrapper around a function that turns it into a custom op.
It has various methods for registering additional behavior for this custom op.
You should not instantiate CustomOpDef directly; instead, use the torch.library.custom_op() API.
- set_kernel_enabled(device_type, enabled=True)
Disable or re-enable an already registered kernel for this custom operator.
If the kernel is already disabled/enabled, this is a no-op.
Note
If a kernel is first disabled and then registered, it is disabled until enabled again.
- Parameters
Example
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>>
>>> inp = torch.randn(1)
>>>
>>> # define custom op `f`.
>>> @custom_op("mylib::f", mutates_args=())
>>> def f(x: Tensor) -> Tensor:
>>>     return torch.zeros(1)
>>>
>>> print(f(inp))  # tensor([0.]), default kernel
>>>
>>> @f.register_kernel("cpu")
>>> def _(x):
>>>     return torch.ones(1)
>>>
>>> print(f(inp))  # tensor([1.]), CPU kernel
>>>
>>> # temporarily disable the CPU kernel
>>> with f.set_kernel_enabled("cpu", enabled = False):
>>>     print(f(inp))  # tensor([0.]) with CPU kernel disabled
Low-level APIs
The following APIs are direct bindings to PyTorch’s C++ low-level operator registration APIs.
Warning
The low-level operator registration APIs and the PyTorch Dispatcher are complicated PyTorch concepts. We recommend you use the higher level APIs above (that do not require a torch.library.Library object) when possible. This blog post (http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/) is a good starting point to learn about the PyTorch Dispatcher.
A tutorial that walks you through some examples on how to use this API is available on Google Colab.
- class torch.library.Library(ns, kind, dispatch_key='')
A class to create libraries that can be used to register new operators or override operators in existing libraries from Python. A user can optionally pass in a dispatch key name if they want to register kernels corresponding to only one specific dispatch key.
To create a library to override operators in an existing library (with name ns), set the kind to “IMPL”. To create a new library (with name ns) to register new operators, set the kind to “DEF”. To create a fragment of a possibly existing library to register operators (and bypass the limitation that there is only one library for a given namespace), set the kind to “FRAGMENT”.
- Parameters
ns – library name
kind – “DEF”, “IMPL” (default: “IMPL”), “FRAGMENT”
dispatch_key – PyTorch dispatch key (default: “”)
- define(schema, alias_analysis='', *, tags=())
Defines a new operator and its semantics in the ns namespace.
- Parameters
schema – function schema to define a new operator.
alias_analysis (optional) – Indicates if the aliasing properties of the operator arguments can be inferred from the schema (default behavior) or not (“CONSERVATIVE”).
tags (Tag | Sequence[Tag]) – one or more torch.Tag to apply to this operator. Tagging an operator changes the operator’s behavior under various PyTorch subsystems; please read the docs for the torch.Tag carefully before applying it.
- Returns
name of the operator as inferred from the schema.
- Example::
>>> from torch.library import Library
>>> my_lib = Library("mylib", "DEF")
>>> my_lib.define("sum(Tensor self) -> Tensor")
- fallback(fn, dispatch_key='', *, with_keyset=False)
Registers the function implementation as the fallback for the given key.
This function only works for a library with global namespace (“_”).
- Parameters
fn – function used as fallback for the given dispatch key or fallthrough_kernel() to register a fallthrough.
dispatch_key – dispatch key that the input function should be registered for. By default, it uses the dispatch key that the library was created with.
with_keyset – flag controlling if the current dispatcher call keyset should be passed as the first argument to fn when calling. This should be used to create the appropriate keyset for redispatch calls.
- Example::
>>> my_lib = Library("_", "IMPL")
>>> def fallback_kernel(op, *args, **kwargs):
>>>     # Handle all autocast ops generically
>>>     # ...
>>> my_lib.fallback(fallback_kernel, "Autocast")
- impl(op_name, fn, dispatch_key='', *, with_keyset=False)
Registers the function implementation for an operator defined in the library.
- Parameters
op_name – operator name (along with the overload) or OpOverload object.
fn – function that’s the operator implementation for the input dispatch key or fallthrough_kernel() to register a fallthrough.
dispatch_key – dispatch key that the input function should be registered for. By default, it uses the dispatch key that the library was created with.
with_keyset – flag controlling if the current dispatcher call keyset should be passed as the first argument to fn when calling. This should be used to create the appropriate keyset for redispatch calls.
- Example::
>>> my_lib = Library("aten", "IMPL")
>>> def div_cpu(self, other):
>>>     return self * (1 / other)
>>> my_lib.impl("div.Tensor", div_cpu, "CPU")
- torch.library.fallthrough_kernel()
A dummy function to pass to Library.impl in order to register a fallthrough.
- torch.library.define(qualname, schema, *, lib=None, tags=())
- torch.library.define(lib, schema, alias_analysis='')
Defines a new operator.
In PyTorch, defining an op (short for “operator”) is a two-step process:
- we need to define the op (by providing an operator name and schema)
- we need to implement behavior for how the operator interacts with various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.
This entrypoint defines the custom operator (the first step); you must then perform the second step by calling various impl_* APIs, like torch.library.impl() or torch.library.register_fake().
- Parameters
qualname (str) – The qualified name for the operator. Should be a string that looks like “namespace::name”, e.g. “aten::sin”. Operators in PyTorch need a namespace to avoid name collisions; a given operator may only be created once. If you are writing a Python library, we recommend the namespace to be the name of your top-level module.
schema (str) – The schema of the operator. E.g. “(Tensor x) -> Tensor” for an op that accepts one Tensor and returns one Tensor. It does not contain the operator name (that is passed in qualname).
lib (Optional[Library]) – If provided, the lifetime of this operator will be tied to the lifetime of the Library object.
tags (Tag | Sequence[Tag]) – one or more torch.Tag to apply to this operator. Tagging an operator changes the operator’s behavior under various PyTorch subsystems; please read the docs for the torch.Tag carefully before applying it.
- Example::
>>> import torch
>>> import numpy as np
>>>
>>> # Define the operator
>>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the operator
>>> @torch.library.impl("mylib::sin", "cpu")
>>> def f(x):
>>>     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> # Call the new operator from torch.ops.
>>> x = torch.randn(3)
>>> y = torch.ops.mylib.sin(x)
>>> assert torch.allclose(y, x.sin())
- torch.library.impl(qualname, types, func=None, *, lib=None)
- torch.library.impl(lib, name, dispatch_key='')
Register an implementation for a device type for this operator.
You may pass “default” for types to register this implementation as the default implementation for ALL device types. Please only use this if the implementation truly supports all device types; for example, this is true if it is a composition of built-in PyTorch operators.
Some valid types are: “cpu”, “cuda”, “xla”, “mps”, “ipu”, “xpu”.
- Parameters
Examples
>>> import torch
>>> import numpy as np
>>>
>>> # Define the operator
>>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the cpu device
>>> @torch.library.impl("mylib::mysin", "cpu")
>>> def f(x):
>>>     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> x = torch.randn(3)
>>> y = torch.ops.mylib.mysin(x)
>>> assert torch.allclose(y, x.sin())