HPO

Overview & Usage

The torchx.runtime.hpo module contains modules and functions that you can use to build a Hyperparameter Optimization (HPO) application. Note that an HPO application is the entity that coordinates the HPO search; it is not to be confused with the application that runs for each “trial” of the search. Typically a “trial” in an HPO job is the trainer app that trains an ML model given the set of parameters dictated by the HPO job.

For grid search, the HPO job may be as simple as a for-parallel loop that exhaustively runs through all the combinations of parameters in the user-defined search space. Bayesian optimization, on the other hand, requires the optimizer state to be preserved between trials, which leads to a more involved implementation of an HPO app.
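For illustration, the grid-search case can be sketched in plain Python. The objective function, parameter names, and search space below are hypothetical stand-ins; a real HPO app would launch a trial app (e.g. a trainer) for each combination and read back its reported metric:

```python
import itertools

# Hypothetical objective for illustration; minimized at lr=0.1, batch_size=32.
def evaluate(lr: float, batch_size: int) -> float:
    return (lr - 0.1) ** 2 + (batch_size - 32) ** 2

# Exhaustive grid search: evaluate every combination in a discrete search space.
search_space = {"lr": [0.01, 0.1, 1.0], "batch_size": [16, 32, 64]}
trials = [
    dict(zip(search_space, values))
    for values in itertools.product(*search_space.values())
]
best = min(trials, key=lambda params: evaluate(**params))
```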

Currently this module uses Ax as the underlying HPO engine and offers a few extension points for integrating Ax with TorchX runners.

Quickstart Example

The following example demonstrates running an HPO job on a TorchX component. We use the builtin utils.booth component, which simply runs an application that evaluates the booth function at (x1, x2). The objective is to find the x1 and x2 that minimize the booth function.
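For reference, the booth function is a standard optimization test function with a global minimum of 0 at (x1, x2) = (1, 3); it can be written as:

```python
def booth(x1: float, x2: float) -> float:
    # Booth test function; global minimum is 0 at (x1, x2) = (1, 3).
    return (x1 + 2 * x2 - 7) ** 2 + (2 * x1 + x2 - 5) ** 2
```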

import tempfile

from ax.core import (
    Experiment,
    Objective,
    OptimizationConfig,
    ParameterType,
    RangeParameter,
    SearchSpace,
)
from ax.modelbridge.dispatch_utils import choose_generation_strategy
from ax.service.scheduler import SchedulerOptions
from ax.service.utils.report_utils import exp_to_df
from ax.utils.common.constants import Keys
from torchx.components import utils
from torchx.runtime.hpo.ax import AppMetric, TorchXRunner, TorchXScheduler

# scratch dir used by the tracker and the local scheduler
tmpdir = tempfile.mkdtemp()

# Run HPO on the booth function (https://en.wikipedia.org/wiki/Test_functions_for_optimization)

parameters = [
    RangeParameter(
        name="x1",
        lower=-10.0,
        upper=10.0,
        parameter_type=ParameterType.FLOAT,
    ),
    RangeParameter(
        name="x2",
        lower=-10.0,
        upper=10.0,
        parameter_type=ParameterType.FLOAT,
    ),
]

objective = Objective(metric=AppMetric(name="booth_eval"), minimize=True)

runner = TorchXRunner(
    tracker_base=tmpdir,
    component=utils.booth,
    component_const_params={
        "image": "ghcr.io/pytorch/torchx:0.1.0rc0",
    },
    scheduler="local",  # can also be kubernetes, slurm, etc.
    scheduler_args={"log_dir": tmpdir, "image_type": "docker"},
)

experiment = Experiment(
    name="torchx_booth_sequential_demo",
    search_space=SearchSpace(parameters=parameters),
    optimization_config=OptimizationConfig(objective=objective),
    runner=runner,
    is_test=True,
    properties={Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF: True},
)

scheduler = TorchXScheduler(
    experiment=experiment,
    generation_strategy=(
        choose_generation_strategy(
            search_space=experiment.search_space,
        )
    ),
    options=SchedulerOptions(),
)


for _ in range(3):
    scheduler.run_n_trials(max_trials=2)

print(exp_to_df(experiment))

Ax (Adaptive Experimentation)

class torchx.runtime.hpo.ax.TorchXRunner(tracker_base: str, component: Callable[[...], torchx.specs.api.AppDef], component_const_params: Optional[Dict[str, Any]] = None, scheduler: str = 'local', cfg: Optional[Mapping[str, Optional[Union[str, int, float, bool, List[str]]]]] = None)[source]

An implementation of ax.core.runner.Runner that delegates job submission to the TorchX Runner. This runner is coupled with the torchx component since Ax runners run trials of a single component with different parameters.

It is expected that the experiment parameter names and types match EXACTLY with the component’s function args. Component function args that are NOT part of the search space can be passed as component_const_params. The following args are passed automatically, provided that the component function declares them in its signature:

  1. trial_idx (int): current trial’s index

  2. tracker_base (str): torchx tracker’s base (typically a URL indicating the base dir of the tracker)

Example:

def trainer_component(
    x1: int,
    x2: float,
    trial_idx: int,
    tracker_base: str,
    x3: float,
    x4: str,
) -> spec.AppDef:
    # ... implementation omitted for brevity ...
    pass

The experiment should be set up as:

parameters=[
  {
    "name": "x1",
    "value_type": "int",
    # ... other options...
  },
  {
    "name": "x2",
    "value_type": "float",
    # ... other options...
  }
]

And the rest of the arguments can be set as:

TorchXRunner(
    tracker_base="s3://foo/bar",
    component=trainer_component,
    # trial_idx and tracker_base args passed automatically
    # if the function signature declares those args
    component_const_params={"x3": 1.2, "x4": "barbaz"},
)

Running the experiment as set up above results in each trial running:

appdef = trainer_component(
    x1=trial.params["x1"],
    x2=trial.params["x2"],
    trial_idx=trial.index,
    tracker_base="s3://foo/bar",
    x3=1.2,
    x4="barbaz",
)

torchx.runner.get_runner().run(appdef, ...)

class torchx.runtime.hpo.ax.TorchXScheduler(experiment: ax.core.experiment.Experiment, generation_strategy: ax.modelbridge.generation_strategy.GenerationStrategy, options: ax.service.scheduler.SchedulerOptions, db_settings: Optional[ax.storage.sqa_store.structs.DBSettings] = None, _skip_experiment_save: bool = False)[source]

An implementation of an Ax Scheduler that works with Experiments hooked up with the TorchXRunner.

This scheduler is not a real scheduler but rather a facade that delegates to scheduler clients for the various remote/local schedulers. For a list of supported schedulers, please refer to the TorchX scheduler docs.

class torchx.runtime.hpo.ax.AppMetric(name: str, lower_is_better: Optional[bool] = None, properties: Optional[Dict[str, Any]] = None)[source]

Fetches the metric (the observation returned by the trial job/app) via the torchx.runtime.tracking module. Assumes that the app used the tracker in the following manner:

tracker = torchx.runtime.tracking.FsspecResultTracker(tracker_base)
tracker[str(trial_index)] = {metric_name: value}

# -- or --
tracker[str(trial_index)] = {
    "metric_name/mean": mean_value,
    "metric_name/sem": sem_value,
}
