# Source code for torch.distributed.tensor
# Copyright (c) Meta Platforms, Inc. and affiliates
import torch
import torch.distributed.tensor._ops # force import all built-in dtensor ops
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh # noqa: F401
from torch.distributed.tensor._api import (
distribute_module,
distribute_tensor,
DTensor,
empty,
full,
ones,
rand,
randn,
zeros,
)
from torch.distributed.tensor.placement_types import (
Partial,
Placement,
Replicate,
Shard,
)
from torch.optim.optimizer import (
_foreach_supported_types as _optim_foreach_supported_types,
)
from torch.utils._foreach_utils import (
_foreach_supported_types as _util_foreach_supported_types,
)
# All public APIs from dtensor package
__all__ = [
"DTensor",
"distribute_tensor",
"distribute_module",
"Shard",
"Replicate",
"Partial",
"Placement",
"ones",
"empty",
"full",
"rand",
"randn",
"zeros",
]
# Append DTensor to the list of supported types for foreach implementation for optimizer
# and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA.
if DTensor not in _optim_foreach_supported_types:
_optim_foreach_supported_types.append(DTensor)
if DTensor not in _util_foreach_supported_types:
_util_foreach_supported_types.append(DTensor)
# Set namespace for exposed private names
DTensor.__module__ = "torch.distributed.tensor"
distribute_tensor.__module__ = "torch.distributed.tensor"
distribute_module.__module__ = "torch.distributed.tensor"
ones.__module__ = "torch.distributed.tensor"
empty.__module__ = "torch.distributed.tensor"
full.__module__ = "torch.distributed.tensor"
rand.__module__ = "torch.distributed.tensor"
randn.__module__ = "torch.distributed.tensor"
zeros.__module__ = "torch.distributed.tensor"