JitScalarType

class torch.onnx.JitScalarType(value)

Scalar types defined in torch.

Use JitScalarType to convert from torch and JIT scalar types to ONNX scalar types.

Examples

>>> JitScalarType.from_value(torch.ones(1, 2)).onnx_type()
TensorProtoDataType.FLOAT
>>> JitScalarType.from_value(torch_c_value_with_type_float).onnx_type()
TensorProtoDataType.FLOAT
>>> JitScalarType.from_dtype(torch.get_default_dtype()).onnx_type()
TensorProtoDataType.FLOAT

dtype()

Convert a JitScalarType to a torch dtype.

Return type

dtype
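
As a rough illustration of the conversion described above, the sketch below round-trips a torch dtype through JitScalarType and back. It assumes a PyTorch build in which torch.onnx.JitScalarType is importable under that name; the variable names are illustrative only.

```python
import torch
from torch.onnx import JitScalarType

# Convert a torch dtype to a JitScalarType, then back again with dtype().
scalar_type = JitScalarType.from_dtype(torch.float16)
round_tripped = scalar_type.dtype()

print(scalar_type, round_tripped)
```

The value returned by dtype() should compare equal to the dtype that was passed to from_dtype.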

classmethod from_dtype(dtype)

Convert a torch dtype to JitScalarType.

Note: DO NOT USE this API when dtype comes from a torch._C.Value.type() call.

A RuntimeError (INTERNAL ASSERT FAILED at “../aten/src/ATen/core/jit_type_base.h”) can be raised in several scenarios where shape info is not present. Use the from_value API instead, which is safer.

Parameters

dtype (Optional[dtype]) – A torch.dtype to create a JitScalarType from

Returns

JitScalarType

Raises

OnnxExporterError – if dtype is not a valid torch.dtype or if it is None.

Return type

JitScalarType
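
The error behaviour documented above can be handled as in the following sketch. The helper scalar_type_or_none is a hypothetical name introduced here for illustration, and the example assumes OnnxExporterError is reachable through the torch.onnx.errors module.

```python
import torch
from torch.onnx import JitScalarType, errors


def scalar_type_or_none(dtype):
    """Hypothetical helper: return the JitScalarType for dtype, or None if from_dtype rejects it."""
    try:
        return JitScalarType.from_dtype(dtype)
    except errors.OnnxExporterError:
        # Raised when dtype is None or is not a valid torch.dtype.
        return None


known = scalar_type_or_none(torch.int32)
missing = scalar_type_or_none(None)
print(known, missing)
```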

classmethod from_value(value, default=None)

Create a JitScalarType from a value’s scalar type.

Parameters
  • value (Union[None, Value, Tensor]) – An object to fetch scalar type from.

  • default – The JitScalarType to return if a valid scalar type cannot be fetched from value.

Returns

JitScalarType.

Raises
  • OnnxExporterError – if value does not have a valid scalar type and default is None.

  • SymbolicValueError – if value.type()’s info is empty and default is None.

Return type

JitScalarType
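
A minimal sketch of the three cases described above: a tensor input, a None input with a default, and a None input without one. The fallback here is built with from_dtype purely for illustration, and OnnxExporterError is assumed to be importable from torch.onnx.errors.

```python
import torch
from torch.onnx import JitScalarType, errors

# A tensor carries its own dtype, so no default is needed.
from_tensor = JitScalarType.from_value(torch.ones(1, 2, dtype=torch.int64))

# None has no scalar type; the supplied default is returned instead.
fallback = JitScalarType.from_dtype(torch.float32)
from_none = JitScalarType.from_value(None, default=fallback)

# Without a default, the same call raises OnnxExporterError.
try:
    JitScalarType.from_value(None)
    raised = False
except errors.OnnxExporterError:
    raised = True

print(from_tensor, from_none, raised)
```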

onnx_compatible()

Return whether this JitScalarType is compatible with ONNX.

Return type

bool
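
One possible use of this check is to filter a list of dtypes before attempting an export. The sketch below is only an illustration: the candidate list is arbitrary, and which dtypes pass may depend on the installed PyTorch version.

```python
import torch
from torch.onnx import JitScalarType

# Arbitrary sample of dtypes to test for ONNX compatibility.
candidates = [torch.float32, torch.int64, torch.bool, torch.complex32]

# Map each dtype to the result of onnx_compatible() on its JitScalarType.
compatibility = {
    dtype: JitScalarType.from_dtype(dtype).onnx_compatible()
    for dtype in candidates
}

# Keep only the dtypes reported as compatible.
exportable = [dtype for dtype, ok in compatibility.items() if ok]
print(exportable)
```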

onnx_type()

Convert a JitScalarType to an ONNX data type.

Return type

TensorProtoDataType

scalar_name()

Convert a JitScalarType to a JIT scalar type name.

Return type

Literal['Byte', 'Char', 'Double', 'Float', 'Half', 'Int', 'Long', 'Short', 'Bool', 'ComplexHalf', 'ComplexFloat', 'ComplexDouble', 'QInt8', 'QUInt8', 'QInt32', 'BFloat16', 'Float8E5M2', 'Float8E4M3FN', 'Float8E5M2FNUZ', 'Float8E4M3FNUZ', 'Undefined']

torch_name()

Convert a JitScalarType to a torch type name.

Return type

Literal['bool', 'uint8_t', 'int8_t', 'double', 'float', 'half', 'int', 'int64_t', 'int16_t', 'complex32', 'complex64', 'complex128', 'qint8', 'quint8', 'qint32', 'bfloat16', 'float8_e5m2', 'float8_e4m3fn', 'float8_e5m2fnuz', 'float8_e4m3fnuz']
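
To show how the two naming schemes differ, the sketch below collects the JIT scalar type name and the torch type name for a few dtypes. The dtype selection and the names dictionary are illustrative choices, not part of the API.

```python
import torch
from torch.onnx import JitScalarType

# For each dtype, record (scalar_name(), torch_name()) side by side.
names = {}
for dtype in (torch.float32, torch.int64, torch.bool):
    scalar_type = JitScalarType.from_dtype(dtype)
    names[dtype] = (scalar_type.scalar_name(), scalar_type.torch_name())

for dtype, (jit_name, torch_type_name) in names.items():
    print(dtype, jit_name, torch_type_name)
```

Each returned string should be one of the Literal values listed for the corresponding method.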