get_primers_from_module
- torchrl.modules.utils.get_primers_from_module(module)
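Collect the tensordict primers of all submodules of a module.
This function is useful when modules that define primers, such as recurrent modules, are nested inside a parent module (for example a TensorDictSequential policy) and their primers need to be retrieved from the parent.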
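A minimal sketch of the traversal idea follows. It assumes that a primer-providing submodule exposes a `make_tensordict_primer()` method, as torchrl's recurrent modules such as GRUModule and LSTMModule do. `HypotheticalRecurrentModule` and `collect_primers` are hypothetical names used only for illustration. The sketch returns a plain list of stand-in primers, whereas the documented function returns a transform.

```python
from torch import nn


class HypotheticalRecurrentModule(nn.Module):
    """Hypothetical stand-in for a module (e.g. GRUModule) that exposes a primer."""

    def __init__(self, key: str):
        super().__init__()
        self.key = key
        self.linear = nn.Linear(4, 4)  # nested submodule without a primer

    def make_tensordict_primer(self):
        # Stand-in: a real recurrent module would return a TensorDictPrimer transform.
        return {self.key: 0.0}


def collect_primers(module: nn.Module) -> list:
    """Gather the primer of every submodule (including `module` itself) that defines one."""
    primers = []
    # module.modules() walks the module tree in pre-order, starting with `module`.
    for submodule in module.modules():
        if hasattr(submodule, "make_tensordict_primer"):
            primers.append(submodule.make_tensordict_primer())
    return primers


model = nn.Sequential(
    HypotheticalRecurrentModule("recurrent_state_0"),
    nn.ReLU(),
    HypotheticalRecurrentModule("recurrent_state_1"),
)
primers = collect_primers(model)
print(primers)
```

Submodules without the method (the `nn.Sequential` container, `nn.ReLU`, the nested `nn.Linear` layers) are simply skipped, so primers are found no matter how deeply the recurrent modules are nested.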
- Parameters:
module (torch.nn.Module) – The parent module.
- Returns:
A TensorDictPrimer Transform.
- Return type:
TensorDictPrimer
Example
>>> from torchrl.modules.utils import get_primers_from_module
>>> from torchrl.modules import GRUModule, MLP
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> # Define a GRU module
>>> gru_module = GRUModule(
...     input_size=10,
...     hidden_size=10,
...     num_layers=1,
...     in_keys=["input", "recurrent_state", "is_init"],
...     out_keys=["features", ("next", "recurrent_state")],
... )
>>> # Define a head module
>>> head = TensorDictModule(
...     MLP(
...         in_features=10,
...         out_features=10,
...         num_cells=[],
...     ),
...     in_keys=["features"],
...     out_keys=["output"],
... )
>>> # Create a sequential model
>>> model = TensorDictSequential(gru_module, head)
>>> # Retrieve primers from the model
>>> primers = get_primers_from_module(model)
>>> print(primers)
TensorDictPrimer(primers=CompositeSpec(
    recurrent_state: UnboundedContinuousTensorSpec(
        shape=torch.Size([1, 10]),
        space=None,
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    device=None,
    shape=torch.Size([])), default_value={'recurrent_state': 0.0}, random=None)
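The primed "recurrent_state" has shape torch.Size([1, 10]) and default value 0.0, which appears to correspond to (num_layers, hidden_size) of the GRUModule above. As an illustration only, using plain torch.nn.GRU rather than torchrl's GRUModule, a zero tensor of that shape is a valid initial hidden state for a single unbatched step:

```python
import torch
from torch import nn

num_layers, hidden_size = 1, 10

# Mirror the primer's default: zeros with shape (num_layers, hidden_size).
recurrent_state = torch.zeros(num_layers, hidden_size)

gru = nn.GRU(input_size=10, hidden_size=hidden_size, num_layers=num_layers)

# Unbatched input of shape (seq_len, input_size); here a single time step.
x = torch.randn(1, 10)
output, next_state = gru(x, recurrent_state)
print(output.shape, next_state.shape)
```

Because the default value is zero, the first step after a reset starts from an all-zero hidden state, which is also what torch.nn.GRU uses when no initial state is passed.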