torch.set_default_dtype

torch.set_default_dtype(d)

Sets the default floating point dtype to d. Supports torch.float32 and torch.float64 as inputs. Other dtypes may be accepted without complaint but are not supported and are unlikely to work as expected.

When PyTorch is initialized, its default floating point dtype is torch.float32. Calling set_default_dtype(torch.float64) is intended to facilitate NumPy-like type inference. The default floating point dtype is used to:

  1. Implicitly determine the default complex dtype. When the default floating point dtype is float32, the default complex dtype is complex64; when the default floating point dtype is float64, the default complex dtype is complex128.

  2. Infer the dtype for tensors constructed using Python floats or complex Python numbers. See examples below.

  3. Determine the result of type promotion when bool or integer tensors are combined with Python floats or complex Python numbers.
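The type promotion rule in item 3 is easy to miss, so a minimal sketch may help. It combines an integer tensor with a Python float and a bool tensor with a complex Python number under each supported default. The helper name promoted_dtypes and the variable names are illustrative only and are not part of PyTorch.

```python
import torch


def promoted_dtypes():
    """Return the result dtypes when int/bool tensors meet Python scalars.

    Hypothetical helper for illustration; not a PyTorch API.
    """
    ints = torch.tensor([1, 2, 3])       # integer tensor (int64)
    flags = torch.tensor([True, False])  # bool tensor
    # Integer tensor with a Python float, bool tensor with a Python complex.
    return (ints + 1.5).dtype, (flags * 2j).dtype


original = torch.get_default_dtype()  # remember the current default
try:
    torch.set_default_dtype(torch.float32)
    under_float32 = promoted_dtypes()
    torch.set_default_dtype(torch.float64)
    under_float64 = promoted_dtypes()
finally:
    # The setting is process-wide, so put it back when done.
    torch.set_default_dtype(original)

print(under_float32)
print(under_float64)
```

Following the rules listed above, the first pair should be the float32 and complex64 dtypes and the second pair float64 and complex128.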

Parameters:

d (torch.dtype) – the floating point dtype to make the default. Either torch.float32 or torch.float64.

Example

>>> # initial default for floating point is torch.float32
>>> # Python floats are interpreted as float32
>>> torch.tensor([1.2, 3]).dtype
torch.float32
>>> # initial default for complex is torch.complex64
>>> # Complex Python numbers are interpreted as complex64
>>> torch.tensor([1.2, 3j]).dtype
torch.complex64
>>> torch.set_default_dtype(torch.float64)
>>> # Python floats are now interpreted as float64
>>> torch.tensor([1.2, 3]).dtype    # a new floating point tensor
torch.float64
>>> # Complex Python numbers are now interpreted as complex128
>>> torch.tensor([1.2, 3j]).dtype   # a new complex tensor
torch.complex128
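Because set_default_dtype changes a process-wide setting, the example above leaves float64 as the default for any code that runs afterwards. One possible way to limit the change to a region of code is sketched below. The context manager default_dtype is a hypothetical helper written for this illustration, not a PyTorch API; it relies only on torch.get_default_dtype and torch.set_default_dtype.

```python
import contextlib

import torch


@contextlib.contextmanager
def default_dtype(dtype):
    """Temporarily switch the default floating point dtype.

    Hypothetical helper for illustration; not a PyTorch API.
    """
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore the earlier default even if the body raises.
        torch.set_default_dtype(previous)


before = torch.get_default_dtype()
with default_dtype(torch.float64):
    # Python floats are interpreted as float64 inside the block.
    inside = torch.tensor([1.2, 3]).dtype
after = torch.get_default_dtype()

print(inside, before == after)
```

This sketch does not attempt to be thread-safe: the default dtype is shared state, so other code running concurrently would also observe the temporary setting.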