Shortcuts

Source code for torchtune.config._instantiate

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import os
import sys
from typing import Any, Callable, Dict, Tuple

from omegaconf import DictConfig, OmegaConf
from torchtune.config._errors import InstantiationError
from torchtune.config._utils import _get_component_from_path, _has_component


def _create_component(
    _component_: Callable[..., Any],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    return _component_(*args, **kwargs)


def _instantiate_node(node: Dict[str, Any], *args: Any) -> Any:
    """
    Build the object named by ``node["_component_"]``.

    ``node`` is expected to already contain every keyword argument (config fields
    merged with call-site kwargs); every key other than ``_component_`` is passed
    as a keyword argument, and ``args`` are passed positionally.

    Raises:
        InstantiationError: if ``_has_component`` reports that ``node`` has no
            ``_component_`` field.
    """
    if not _has_component(node):
        raise InstantiationError(
            "Cannot instantiate specified object.\n"
            "Make sure you've specified a _component_ field with a valid dotpath."
        )

    # Resolve the dotpath string to the actual callable first.
    component = _get_component_from_path(node.get("_component_"))
    # Everything except the dotpath itself becomes a keyword argument.
    call_kwargs = {key: value for key, value in node.items() if key != "_component_"}
    return _create_component(component, args, call_kwargs)


def instantiate(
    config: DictConfig,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """
    Given a DictConfig with a _component_ field specifying the object to instantiate and
    additional fields for keyword arguments, create an instance of the specified object.
    You can use this function to create the exact instance of a torchtune object you want
    to use in your recipe using the specification from the config.

    This function also supports passing in positional args and keyword args within the
    function call. These are automatically merged with the provided config, with keyword
    args taking precedence.

    Based on Hydra's `instantiate` utility from Facebook Research:
    https://github.com/facebookresearch/hydra/blob/main/hydra/_internal/instantiate/_instantiate2.py#L148

    Args:
        config (DictConfig): a single field in the OmegaConf object parsed from the yaml file.
            This is expected to have a _component_ field specifying the path of the object
            to instantiate. If ``None``, nothing is instantiated and ``None`` is returned.
            The passed config is never modified; a deep copy is used internally.
        *args (Any): positional arguments to pass to the object to instantiate.
        **kwargs (Any): keyword arguments to pass to the object to instantiate. These
            overwrite any fields of the same name in ``config``.

    Examples:
        >>> config.yaml:
        >>>     model:
        >>>       _component_: torchtune.models.llama2.llama2
        >>>       num_layers: 32
        >>>       num_heads: 32
        >>>       num_kv_heads: 32

        >>> from torchtune import config
        >>> vocab_size = 32000
        >>> # Pass in vocab size as positional argument. Since it is positioned first
        >>> # in llama2(), it must be specified first. Pass in other arguments as kwargs.
        >>> # This will return an nn.Module directly for llama2 with specified args.
        >>> model = config.instantiate(parsed_yaml.model, vocab_size, max_seq_len=4096, embed_dim=4096)

    Returns:
        Any: the instantiated object, or ``None`` if ``config`` is ``None``.

    Raises:
        ValueError: if config is not a DictConfig.
        InstantiationError: if the merged config does not specify a _component_ field.
    """
    # Return None if config is None
    if config is None:
        return None
    if not OmegaConf.is_dict(config):
        raise ValueError(f"instantiate only supports DictConfigs, got {type(config)}")

    # Ensure local imports are able to be instantiated: modules in the current
    # working directory become importable via their dotpath.
    cwd = os.getcwd()
    if cwd not in sys.path:
        sys.path.append(cwd)

    # Work on a copy so the caller's config is never mutated. On the copy, allow
    # arbitrary Python objects as values and turn off struct/readonly mode so that
    # kwargs can add or overwrite fields during the merge below.
    config_copy = copy.deepcopy(config)
    config_copy._set_flag(
        flags=["allow_objects", "struct", "readonly"], values=[True, False, False]
    )
    # Keep the original parent so interpolations pointing outside this node can
    # still be resolved against the enclosing config.
    config_copy._set_parent(config._get_parent())
    config = config_copy

    if kwargs:
        # This overwrites any repeated fields in the config with kwargs
        config = OmegaConf.merge(config, kwargs)

    # Resolve all interpolations, or references to other fields within the same config
    OmegaConf.resolve(config)

    return _instantiate_node(OmegaConf.to_object(config), *args)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources