
Model Parallel

DistributedModelParallel is the main API for distributed training with TorchRec optimizations.

class torchrec.distributed.model_parallel.DistributedModelParallel(module: Module, env: Optional[ShardingEnv] = None, device: Optional[device] = None, plan: Optional[ShardingPlan] = None, sharders: Optional[List[ModuleSharder[Module]]] = None, init_data_parallel: bool = True, init_parameters: bool = True, data_parallel_wrapper: Optional[DataParallelWrapper] = None)

Entry point to model parallelism.

  • module (nn.Module) – module to wrap.

  • env (Optional[ShardingEnv]) – sharding environment that has the process group.

  • device (Optional[torch.device]) – compute device, defaults to cpu.

  • plan (Optional[ShardingPlan]) – plan to use when sharding, defaults to EmbeddingShardingPlanner.collective_plan().

  • sharders (Optional[List[ModuleSharder[nn.Module]]]) – ModuleSharders available to shard with, defaults to EmbeddingBagCollectionSharder().

  • init_data_parallel (bool) – data-parallel modules can be lazy, i.e. they delay parameter initialization until the first forward pass. Pass False to delay wrapping of data-parallel modules, then run the first forward pass and call DistributedModelParallel.init_data_parallel().

  • init_parameters (bool) – initialize parameters for modules still on meta device.

  • data_parallel_wrapper (Optional[DataParallelWrapper]) – custom wrapper for data parallel modules.
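To make the init_parameters bullet more concrete, the sketch below shows what "still on meta device" means using plain PyTorch only. It is an analogue, not TorchRec's internal code path: the assumption here is that allocating storage and then calling reset_parameters() is representative of initializing a meta-device module.

```python
import torch
from torch import nn

# Build the layer on the meta device: shapes and dtypes exist, storage does not.
layer = nn.Linear(4, 2, device="meta")
assert layer.weight.is_meta

# Allocate real (uninitialized) storage on the target device.
layer = layer.to_empty(device="cpu")

# Give the freshly allocated parameters well-defined values.
with torch.no_grad():
    layer.reset_parameters()

print(layer.weight.device.type, tuple(layer.weight.shape))  # cpu (2, 4)
```

Constructing large models on the meta device avoids allocating full-size tables before sharding decides where each piece lives.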


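Example:

@torch.no_grad()
def init_weights(m):
    if isinstance(m, nn.Linear):
        m.weight.fill_(1.0)
    elif isinstance(m, EmbeddingBagCollection):
        for param in m.parameters():
            nn.init.kaiming_normal_(param)

# MyModel is a placeholder for a user-defined module.
m = MyModel(device='meta')
m = DistributedModelParallel(m)
m.apply(init_weights)

copy(device: device) → DistributedModelParallel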

Recursively copies submodules to a new device by calling each module's customized copy process, since some modules need to keep using the original references (like ShardedModule for inference).

forward(*args, **kwargs) → Any

Define the computation performed at every call.

Should be overridden by all subclasses.


Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
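The difference between calling the module instance and calling forward() directly can be seen with a forward hook. This is a minimal sketch on a plain nn.Identity layer, chosen only for illustration; the same nn.Module behavior is assumed to apply to wrapped modules.

```python
import torch
from torch import nn

calls = []
layer = nn.Identity()
# The hook records each invocation that goes through __call__.
layer.register_forward_hook(lambda mod, inputs, output: calls.append("hook"))

x = torch.ones(1)
layer(x)          # goes through __call__, so the hook runs
layer.forward(x)  # bypasses __call__, so the hook is skipped

print(calls)  # ['hook']
```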

init_data_parallel() → None

See the init_data_parallel constructor argument for usage. It’s safe to call this method multiple times.
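Lazy modules only materialize their parameters during the first forward pass, which is why data-parallel wrapping may have to wait until after that pass. The sketch below uses nn.LazyLinear as a stand-in for a lazy data-parallel module; it does not use DistributedModelParallel itself.

```python
import torch
from torch import nn
from torch.nn.parameter import UninitializedParameter

lazy = nn.LazyLinear(out_features=3)
# Before any forward pass the weight has no shape or storage.
assert isinstance(lazy.weight, UninitializedParameter)

# The first forward pass infers in_features from the input and allocates the weight.
_ = lazy(torch.randn(2, 5))

print(lazy.weight.shape)  # torch.Size([3, 5])
```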

load_state_dict(state_dict: OrderedDict[str, Tensor], prefix: str = '', strict: bool = True) → _IncompatibleKeys

Copy parameters and buffers from state_dict into this module and its descendants.

If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.


If assign is True the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.

  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

  • assign (bool, optional) – When False, the properties of the tensors in the current module are preserved; when True, the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of Parameter, for which the value from the module is preserved. Default: False


Returns:

  • missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.

  • unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.

Return type:

NamedTuple with missing_keys and unexpected_keys fields


If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
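The strict flag and the returned missing_keys and unexpected_keys fields can be illustrated on a small nn.Linear. This sketch shows generic nn.Module behavior; the key name "extra" is made up for the example.

```python
import torch
from torch import nn

model = nn.Linear(2, 1)
# "bias" is absent and "extra" is not a key the module knows about.
state = {"weight": torch.zeros(1, 2), "extra": torch.zeros(1)}

result = model.load_state_dict(state, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # ['extra']

# With the default strict=True the same mismatch raises RuntimeError.
strict_failed = False
try:
    model.load_state_dict(state)
except RuntimeError:
    strict_failed = True
print(strict_failed)  # True
```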

property module: Module

Property for direct access to the sharded module, which is not wrapped in DDP, FSDP, DMP, or any other parallelism wrapper.

named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) → Iterator[Tuple[str, Tensor]]

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

  • prefix (str) – prefix to prepend to all buffer names.

  • recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.

  • remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.


Yields:

(str, torch.Tensor) – Tuple containing the name and buffer
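As a self-contained illustration of the recurse argument, the sketch below lists the buffers of a small nn.Sequential. The model is an arbitrary choice for the example and involves no sharding.

```python
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

# recurse=True (the default) walks into submodules; names carry the submodule path.
all_names = [name for name, _ in model.named_buffers()]
print(all_names)  # ['1.running_mean', '1.running_var', '1.num_batches_tracked']

# recurse=False only reports buffers registered directly on the container.
direct_names = [name for name, _ in model.named_buffers(recurse=False)]
print(direct_names)  # []
```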


>>> for name, buf in self.named_buffers():
...     if name in ['running_var']:
...         print(buf.size())

named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) → Iterator[Tuple[str, Parameter]]

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

  • prefix (str) – prefix to prepend to all parameter names.

  • recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

  • remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.


Yields:

(str, Parameter) – Tuple containing the name and parameter
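The effect of prefix and recurse on the yielded names can be checked on a tiny model. This is a sketch of generic nn.Module behavior; the prefix "net" is an arbitrary label chosen for the example.

```python
from torch import nn

model = nn.Sequential(nn.Linear(3, 2))

# The prefix is joined to each parameter path with a dot.
prefixed = [name for name, _ in model.named_parameters(prefix="net")]
print(prefixed)  # ['net.0.weight', 'net.0.bias']

# The container owns no parameters itself, so recurse=False yields nothing.
direct = [name for name, _ in model.named_parameters(recurse=False)]
print(direct)  # []
```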


>>> for name, param in self.named_parameters():
...     if name in ['bias']:
...         print(param.size())

state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) → Dict[str, Any]

Return a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.


The returned object is a shallow copy. It contains references to the module’s parameters and buffers.


Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.


Please avoid the use of argument destination as it is not designed for end-users.

  • destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

  • keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
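The shallow-copy behavior and the keep_vars flag can be observed on a plain nn.Linear. This is a minimal sketch of generic nn.Module semantics and does not involve sharded tensors.

```python
import torch
from torch import nn

layer = nn.Linear(2, 1)
sd = layer.state_dict()

# Tensors in the dict share storage with the module's parameters,
# so an in-place edit is visible through the module.
sd["weight"].zero_()
weight_sum = layer.weight.detach().abs().sum().item()
print(weight_sum)  # 0.0

# By default the returned tensors are detached from autograd.
print(sd["weight"].requires_grad)  # False

# keep_vars=True returns the Parameter objects themselves.
kept = layer.state_dict(keep_vars=True)
print(kept["weight"] is layer.weight)  # True
```

Because the entries are references, clone the tensors before modifying them if the module must stay unchanged.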


Returns:

a dictionary containing the whole state of the module

Return type:

dict

>>> module.state_dict().keys()
['bias', 'weight']


