torch.set_default_dtype
- torch.set_default_dtype(d)
Sets the default floating point dtype to d. Supports torch.float32 and torch.float64 as inputs. Other dtypes may be accepted without complaint but are not supported and are unlikely to work as expected.

When PyTorch is initialized its default floating point dtype is torch.float32, and the intent of set_default_dtype(torch.float64) is to facilitate NumPy-like type inference. The default floating point dtype is used to:
Implicitly determine the default complex dtype. When the default floating point dtype is float32 the default complex dtype is complex64, and when the default floating point dtype is float64 the default complex dtype is complex128.
Infer the dtype for tensors constructed using Python floats or complex Python numbers. See examples below.
Determine the result of type promotion between bool or integer tensors and Python floats or complex Python numbers.
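The type promotion case is not covered by the example further down, so here is a minimal sketch of it. The variable names are illustrative only, and the snippet restores the previous default at the end so that it does not leak into other code.

```python
import torch

# Remember the current default so it can be restored afterwards.
previous = torch.get_default_dtype()

int_tensor = torch.tensor([1, 2, 3])        # integer tensor (torch.int64)
bool_tensor = torch.tensor([True, False])   # bool tensor (torch.bool)

torch.set_default_dtype(torch.float32)
# An integer or bool tensor combined with a Python float is promoted
# to the default floating point dtype.
promoted_f32 = (int_tensor * 1.5).dtype
bool_promoted_f32 = (bool_tensor * 1.5).dtype
# Combined with a complex Python number, it is promoted to the
# default complex dtype.
promoted_c64 = (int_tensor + 2j).dtype

torch.set_default_dtype(torch.float64)
promoted_f64 = (int_tensor * 1.5).dtype
promoted_c128 = (int_tensor + 2j).dtype

# Put the default back the way it was.
torch.set_default_dtype(previous)

print(promoted_f32, bool_promoted_f32, promoted_c64)
print(promoted_f64, promoted_c128)
```

The same operations produce different result dtypes depending only on the default in effect when they run; the input tensors themselves are unchanged.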
- Parameters:
d (torch.dtype) – the floating point dtype to make the default. Either torch.float32 or torch.float64.
Example
>>> import torch
>>> # initial default for floating point is torch.float32
>>> # Python floats are interpreted as float32
>>> torch.tensor([1.2, 3]).dtype
torch.float32
>>> # initial default complex dtype is torch.complex64
>>> # Complex Python numbers are interpreted as complex64
>>> torch.tensor([1.2, 3j]).dtype
torch.complex64
>>> torch.set_default_dtype(torch.float64)
>>> # Python floats are now interpreted as float64
>>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
torch.float64
>>> # Complex Python numbers are now interpreted as complex128
>>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
torch.complex128
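Because the default dtype is a process-wide setting, a change made as in the example above stays in effect until it is changed again. One way to limit its scope is sketched below. The `default_dtype` context manager is a hypothetical helper written for this illustration, not part of PyTorch; it relies only on torch.get_default_dtype() and torch.set_default_dtype().

```python
import contextlib

import torch


@contextlib.contextmanager
def default_dtype(dtype):
    """Hypothetical helper: temporarily switch the default floating point dtype."""
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore the earlier default even if the body raised.
        torch.set_default_dtype(previous)


before = torch.get_default_dtype()
with default_dtype(torch.float64):
    # Tensors built from Python floats inside the block use float64.
    inside = torch.tensor([1.2, 3]).dtype
after = torch.get_default_dtype()

print(inside)
print(after == before)
```

Tensors created inside the block keep their float64 dtype after it exits; only the default used for subsequent constructions is restored.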