# Source code for a torchx.specs module (extracted from its rendered documentation page).

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import inspect
from typing import Any, Callable, Dict, List, Mapping, Optional, Union

from torchx.specs.api import BindMount, MountType, VolumeMount
from torchx.specs.file_linter import get_fn_docstring, TorchXArgumentHelpFormatter
from torchx.util.types import (
    decode_from_string,
    decode_optional,
    get_argparse_param_type,
    is_bool,
    is_primitive,
)

from .api import AppDef, DeviceMount

def _create_args_parser(
    cmpnt_fn: Callable[..., AppDef], cmpnt_defaults: Optional[Dict[str, str]] = None
) -> argparse.ArgumentParser:
    """
    Builds an ``argparse`` parser whose arguments mirror the parameters of ``cmpnt_fn``.

    Every parameter of ``cmpnt_fn`` becomes one parser argument:

    * a ``*args`` parameter becomes a positional argument that captures the
      remainder of the command line
    * every other parameter becomes a ``--<name>`` option (single-letter
      parameters additionally get the ``-<name>`` short form), which is
      required unless a default is available

    Args:
        cmpnt_fn: Component function whose signature and docstring drive the parser
        cmpnt_defaults: Optional default values keyed by parameter name
                        (overrides the defaults set on the fn declaration)

    Returns:
        The configured argument parser
    """
    parameters = inspect.signature(cmpnt_fn).parameters
    function_desc, args_desc = get_fn_docstring(cmpnt_fn)
    # NOTE(review): ``description``, ``formatter_class`` and ``add_help`` were
    # restored from the surrounding comments/imports - confirm against upstream.
    script_parser = argparse.ArgumentParser(
        prog=f"torchx run <run args...> {cmpnt_fn.__name__} ",
        description=function_desc,
        formatter_class=TorchXArgumentHelpFormatter,
        # enables components to have "h" as a parameter
        # otherwise argparse by default adds -h/--help as the help argument
        # we still add --help but reserve "-h" to be used as a component argument
        add_help=False,
    )
    # add help manually since we disabled auto help to allow "h" in component arg
    script_parser.add_argument(
        "--help",
        action="help",
        default=argparse.SUPPRESS,
        help="show this help message and exit",
    )

    class _reminder_action(argparse.Action):
        """Stores the captured remainder, falling back to the default when it is empty."""

        def __call__(
            self,
            parser: argparse.ArgumentParser,
            namespace: argparse.Namespace,
            values: Any,
            option_string: Optional[str] = None,
        ) -> None:
            # the default is a single string (or None), so it is split on
            # whitespace to get the same list shape as parsed remainder values
            setattr(
                namespace,
                self.dest,
                (self.default or "").split() if len(values) == 0 else values,
            )

    for param_name, parameter in parameters.items():
        param_desc = args_desc[param_name]
        args: Dict[str, Any] = {
            "help": param_desc,
            "type": get_argparse_param_type(parameter),
        }
        # set defaults specified in the component function declaration
        if parameter.default is not inspect.Parameter.empty:
            if is_bool(type(parameter.default)):
                # bool defaults are kept as str so that they look exactly like
                # values passed on the command line
                args["default"] = str(parameter.default)
            else:
                args["default"] = parameter.default

        # set defaults supplied directly to this method (overwrites the declared defaults)
        # the defaults are given as str (as option values passed from CLI) since
        # these are typically read from .torchxconfig
        if cmpnt_defaults and param_name in cmpnt_defaults:
            args["default"] = cmpnt_defaults[param_name]

        if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
            # *args swallows everything left on the command line
            args["nargs"] = argparse.REMAINDER
            args["action"] = _reminder_action
            script_parser.add_argument(param_name, **args)
        else:
            arg_names = [f"--{param_name}"]
            if len(param_name) == 1:
                arg_names = [f"-{param_name}"] + arg_names
            if "default" not in args:
                args["required"] = True
            script_parser.add_argument(*arg_names, **args)
    return script_parser

def materialize_appdef(
    cmpnt_fn: Callable[..., AppDef],
    cmpnt_args: List[str],
    cmpnt_defaults: Optional[Dict[str, str]] = None,
) -> AppDef:
    """
    Creates an application by running user defined ``app_fn``.

    ``app_fn`` has the following restrictions:
        * Name must be ``app_fn``
        * All arguments should be annotated
        * Supported argument types:
            - primitive: int, str, float
            - Dict[primitive, primitive]
            - List[primitive]
            - Optional[Dict[primitive, primitive]]
            - Optional[List[primitive]]
        * ``app_fn`` can define a vararg (*arg) at the end
        * There should be a docstring for the function that defines
            All arguments in a google-style format
        * There can be default values for the function arguments.
        * The return object must be ``AppDef``

    Args:
        cmpnt_fn: Component function
        cmpnt_args: Function args
        cmpnt_defaults: Additional default values for parameters of ``app_fn``
                          (overrides the defaults set on the fn declaration)

    Returns:
        An application spec

    Raises:
        TypeError: if the component function declares ``**kwargs``
    """

    script_parser = _create_args_parser(cmpnt_fn, cmpnt_defaults)
    parsed_args = script_parser.parse_args(cmpnt_args)

    function_args = []
    var_arg = []
    kwargs = {}

    parameters = inspect.signature(cmpnt_fn).parameters
    for param_name, parameter in parameters.items():
        arg_value = getattr(parsed_args, param_name)
        parameter_type = decode_optional(parameter.annotation)
        if is_bool(parameter_type):
            # bool values arrive as strings; only "true" (any casing) is truthy
            arg_value = arg_value and arg_value.lower() == "true"
        elif not is_primitive(parameter_type):
            arg_value = decode_from_string(arg_value, parameter_type)

        if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
            var_arg = arg_value
        elif parameter.kind == inspect.Parameter.KEYWORD_ONLY:
            kwargs[param_name] = arg_value
        elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
            raise TypeError("**kwargs are not supported for component definitions")
        else:
            # regular parameters are forwarded positionally, in declaration order
            function_args.append(arg_value)

    # a leading "--" only separates the component options from the varargs
    if len(var_arg) > 0 and var_arg[0] == "--":
        var_arg = var_arg[1:]

    return cmpnt_fn(*function_args, *var_arg, **kwargs)

def make_app_handle(scheduler_backend: str, session_name: str, app_id: str) -> str:
    """
    Assembles an app handle of the form ``<scheduler_backend>://<session_name>/<app_id>``.

    Args:
        scheduler_backend: placed before ``://``
        session_name: placed between ``://`` and ``/``
        app_id: placed after the ``/``

    Returns:
        The assembled handle string
    """
    handle_template = "{}://{}/{}"
    return handle_template.format(scheduler_backend, session_name, app_id)

# Maps every accepted spelling of a mount option key to its canonical key.
_MOUNT_OPT_MAP: Mapping[str, str] = {
    "type": "type",
    "destination": "dst",
    "dst": "dst",
    "target": "dst",
    "read_only": "readonly",
    "readonly": "readonly",
    "source": "src",
    "src": "src",
    "perm": "perm",
}

def parse_mounts(opts: List[str]) -> List[Union[BindMount, VolumeMount, DeviceMount]]:
    """
    parse_mounts parses a list of options into typed mounts following a similar
    format to Dockers bind mount.

    Multiple mounts can be specified in the same list. ``type`` must be
    specified first in each.

    Ex:
        type=bind,src=/host,dst=/container,readonly,[type=bind,src=...,dst=...]

    Supported types:
        BindMount: type=bind,src=<host path>,dst=<container path>[,readonly]
        VolumeMount: type=volume,src=<name/id>,dst=<container path>[,readonly]
        DeviceMount: type=device,src=/dev/<dev>[,dst=<container path>][,perm=rwm]

    Args:
        opts: individual ``key=value`` (or bare ``key``) mount options

    Returns:
        One typed mount per ``type`` option, in the order given

    Raises:
        KeyError: on an unknown option key, when the first option is not
            ``type``, or when a mount lacks an option it needs (e.g. ``src``)
        ValueError: on an unsupported mount type or an invalid device
            permission flag
    """
    # pass 1: group the flat option list into one dict per mount; every
    # ``type`` option starts a new group
    mount_opts = []
    cur = {}
    for opt in opts:
        key, _, val = opt.partition("=")
        if key not in _MOUNT_OPT_MAP:
            raise KeyError(
                f"unknown mount option {key}, must be one of {list(_MOUNT_OPT_MAP.keys())}"
            )
        key = _MOUNT_OPT_MAP[key]
        if key == "type":
            cur = {}
            mount_opts.append(cur)
        elif len(mount_opts) == 0:
            raise KeyError("type must be specified first")
        cur[key] = val

    # pass 2: turn each option group into its typed mount
    mounts = []
    for mount_opt in mount_opts:
        typ = mount_opt.get("type")
        if typ == MountType.BIND:
            mounts.append(
                BindMount(
                    src_path=mount_opt["src"],
                    dst_path=mount_opt["dst"],
                    # presence of the key is what matters, not its value
                    read_only="readonly" in mount_opt,
                )
            )
        elif typ == MountType.VOLUME:
            mounts.append(
                VolumeMount(
                    src=mount_opt["src"],
                    dst_path=mount_opt["dst"],
                    read_only="readonly" in mount_opt,
                )
            )
        elif typ == MountType.DEVICE:
            src = mount_opt["src"]
            # the destination defaults to the source device path
            dst = mount_opt.get("dst", src)
            perm = mount_opt.get("perm", "rwm")
            for c in perm:
                if c not in "rwm":
                    raise ValueError(
                        f"{c} is not a valid permission flags must one of r,w,m"
                    )
            mounts.append(DeviceMount(src_path=src, dst_path=dst, permissions=perm))
        else:
            valid = [str(item.value) for item in MountType]
            raise ValueError(f"invalid mount type {repr(typ)}, must be one of {valid}")
    return mounts


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources