
Multi-Objective NAS with Ax

Authors: David Eriksson, Max Balandat, and the Adaptive Experimentation team at Meta.

In this tutorial, we show how to use Ax to run multi-objective neural architecture search (NAS) for a simple neural network model on the popular MNIST dataset. While the underlying methodology would typically be used for more complicated models and larger datasets, we opt for a tutorial that is easily runnable end-to-end on a laptop in less than 20 minutes.

In many NAS applications, there is a natural tradeoff between multiple objectives of interest. For instance, when deploying models on-device we may want to maximize model performance (for example, accuracy), while simultaneously minimizing competing metrics like power consumption, inference latency, or model size in order to satisfy deployment constraints. Often, we may be able to reduce computational requirements or latency of predictions substantially by accepting minimally lower model performance. Principled methods for exploring such tradeoffs efficiently are key enablers of scalable and sustainable AI, and have many successful applications at Meta - see for instance our case study on a Natural Language Understanding model.

In our example here, we will tune the widths of two hidden layers, the learning rate, the dropout probability, the batch size, and the number of training epochs. The goal is to trade off performance (accuracy on the validation set) and model size (the number of model parameters).

This tutorial makes use of the following PyTorch libraries:

  • PyTorch Lightning (specifying the model and training loop)

  • TorchX (for running training jobs remotely / asynchronously)

  • BoTorch (the Bayesian Optimization library powering Ax’s algorithms)

Defining the TorchX App

Our goal is to optimize the PyTorch Lightning training job defined in mnist_train_nas.py. To do this using TorchX, we write a helper function that takes in the values of the architecture and hyperparameters of the training job and creates a TorchX AppDef with the appropriate settings.

from pathlib import Path

import torchx

from torchx import specs
from torchx.components import utils


def trainer(
    log_path: str,
    hidden_size_1: int,
    hidden_size_2: int,
    learning_rate: float,
    epochs: int,
    dropout: float,
    batch_size: int,
    trial_idx: int = -1,
) -> specs.AppDef:

    # define the log path so we can pass it to the TorchX AppDef
    if trial_idx >= 0:
        log_path = Path(log_path).joinpath(str(trial_idx)).absolute().as_posix()

    return utils.python(
        # command line args to the training script
        "--log_path",
        log_path,
        "--hidden_size_1",
        str(hidden_size_1),
        "--hidden_size_2",
        str(hidden_size_2),
        "--learning_rate",
        str(learning_rate),
        "--epochs",
        str(epochs),
        "--dropout",
        str(dropout),
        "--batch_size",
        str(batch_size),
        # other config options
        name="trainer",
        script="mnist_train_nas.py",
        image=torchx.version.TORCHX_IMAGE,
    )
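
To see the essence of what this helper does, here is a minimal stand-alone sketch (not the TorchX API) of how a dict of hyperparameter values gets flattened into the command-line arguments that are ultimately passed through to mnist_train_nas.py; the helper name to_cli_args is purely illustrative.

```python
def to_cli_args(params):
    # Flatten {"name": value, ...} into ["--name", "value", ...],
    # mirroring how trainer() forwards each hyperparameter as a CLI flag.
    args = []
    for name, value in params.items():
        args.append(f"--{name}")
        args.append(str(value))
    return args

print(to_cli_args({"hidden_size_1": 64, "learning_rate": 1e-3}))
# ['--hidden_size_1', '64', '--learning_rate', '0.001']
```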

Setting up the Runner

Ax’s Runner abstraction allows writing interfaces to various backends. Ax already comes with a Runner for TorchX, so we just need to configure it. For the purposes of this tutorial, we run jobs locally in a fully asynchronous fashion.

In order to launch them on a cluster, you can instead specify a different TorchX scheduler and adjust the configuration appropriately. For example, if you have a Kubernetes cluster, you just need to change the scheduler from local_cwd to kubernetes.

import tempfile
from ax.runners.torchx import TorchXRunner

# Make a temporary dir to log our results into
log_dir = tempfile.mkdtemp()

ax_runner = TorchXRunner(
    tracker_base="/tmp/",
    component=trainer,
    # NOTE: To launch this job on a cluster instead of locally you can
    # specify a different scheduler and adjust args appropriately.
    scheduler="local_cwd",
    component_const_params={"log_path": log_dir},
    cfg={},
)

Setting up the SearchSpace

First, we define our search space. Ax supports both range parameters of type integer and float as well as choice parameters which can have non-numerical types such as strings. We will tune the hidden sizes, learning rate, dropout, and number of epochs as range parameters and tune the batch size as an ordered choice parameter to enforce it to be a power of 2.
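
To build intuition for the log_scale flag on the learning-rate parameter, here is a small stdlib sketch (not Ax's sampler) of what sampling uniformly in log space means: small and large values within the range are explored with equal frequency, rather than the samples piling up near the upper end.

```python
import math
import random

random.seed(0)

def sample_log_uniform(lower, upper):
    # Draw uniformly in log space, then map back to the original scale.
    u = random.uniform(math.log(lower), math.log(upper))
    return math.exp(u)

samples = [sample_log_uniform(1e-4, 1e-2) for _ in range(10_000)]

# Roughly half the samples fall below the geometric mean
# sqrt(1e-4 * 1e-2) = 1e-3, unlike plain uniform sampling,
# where ~90% of samples would exceed 1e-3.
below = sum(s < 1e-3 for s in samples) / len(samples)
print(round(below, 2))
```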

from ax.core import (
    ChoiceParameter,
    ParameterType,
    RangeParameter,
    SearchSpace,
)

parameters = [
    # NOTE: In a real-world setting, hidden_size_1 and hidden_size_2
    # should probably be powers of 2, but in our simple example this
    # would mean that num_params can't take on that many values, which
    # in turn makes the Pareto frontier look pretty weird.
    RangeParameter(
        name="hidden_size_1",
        lower=16,
        upper=128,
        parameter_type=ParameterType.INT,
        log_scale=True,
    ),
    RangeParameter(
        name="hidden_size_2",
        lower=16,
        upper=128,
        parameter_type=ParameterType.INT,
        log_scale=True,
    ),
    RangeParameter(
        name="learning_rate",
        lower=1e-4,
        upper=1e-2,
        parameter_type=ParameterType.FLOAT,
        log_scale=True,
    ),
    RangeParameter(
        name="epochs",
        lower=1,
        upper=4,
        parameter_type=ParameterType.INT,
    ),
    RangeParameter(
        name="dropout",
        lower=0.0,
        upper=0.5,
        parameter_type=ParameterType.FLOAT,
    ),
    ChoiceParameter(  # NOTE: ChoiceParameters don't require log-scale
        name="batch_size",
        values=[32, 64, 128, 256],
        parameter_type=ParameterType.INT,
        is_ordered=True,
        sort_values=True,
    ),
]

search_space = SearchSpace(
    parameters=parameters,
    # NOTE: In practice, it may make sense to add a constraint
    # hidden_size_2 <= hidden_size_1
    parameter_constraints=[],
)

Setting up Metrics

Ax has the concept of a Metric that defines properties of outcomes and how observations are obtained for these outcomes. This allows, for example, encoding how data is fetched from some distributed execution backend and post-processed before being passed as input to Ax.

In this tutorial we will use multi-objective optimization with the goal of maximizing the validation accuracy and minimizing the number of model parameters. The latter represents a simple proxy of model latency, which is hard to estimate accurately for small ML models (in an actual application we would benchmark the latency while running the model on-device).

In our example TorchX will run the training jobs in a fully asynchronous fashion locally and write the results to the log_dir based on the trial index (see the trainer() function above). We will define a metric class that is aware of that logging directory. By subclassing TensorboardCurveMetric we get the logic to read and parse the Tensorboard logs for free.

from ax.metrics.tensorboard import TensorboardCurveMetric


class MyTensorboardMetric(TensorboardCurveMetric):

    # NOTE: We need to tell the new Tensorboard metric how to get the id /
    # file handle for the tensorboard logs from a trial. In this case
    # our convention is to just save a separate file per trial in
    # the pre-specified log dir.
    @classmethod
    def get_ids_from_trials(cls, trials):
        return {
            trial.index: Path(log_dir).joinpath(str(trial.index)).as_posix()
            for trial in trials
        }

    # This indicates whether the metric is queryable while the trial is
    # still running. We don't use this in the current tutorial, but Ax
    # utilizes this to implement trial-level early-stopping functionality.
    @classmethod
    def is_available_while_running(cls):
        return False

Now we can instantiate the metrics for accuracy and the number of model parameters. Here curve_name is the name of the metric in the Tensorboard logs, while name is the metric name used internally by Ax. We also specify lower_is_better to indicate the favorable direction of the two metrics.

val_acc = MyTensorboardMetric(
    name="val_acc",
    curve_name="val_acc",
    lower_is_better=False,
)
model_num_params = MyTensorboardMetric(
    name="num_params",
    curve_name="num_params",
    lower_is_better=True,
)

Setting up the OptimizationConfig

The way to tell Ax what it should optimize is by means of an OptimizationConfig. Here we use a MultiObjectiveOptimizationConfig as we will be performing multi-objective optimization.

Additionally, Ax supports placing constraints on the different metrics by specifying objective thresholds, which bound the region of interest in the outcome space that we want to explore. For this example, we will constrain the validation accuracy to be at least 0.94 (94%) and the number of model parameters to be at most 80,000.

from ax.core import MultiObjective, Objective, ObjectiveThreshold
from ax.core.optimization_config import MultiObjectiveOptimizationConfig


opt_config = MultiObjectiveOptimizationConfig(
    objective=MultiObjective(
        objectives=[
            Objective(metric=val_acc, minimize=False),
            Objective(metric=model_num_params, minimize=True),
        ],
    ),
    objective_thresholds=[
        ObjectiveThreshold(metric=val_acc, bound=0.94, relative=False),
        ObjectiveThreshold(metric=model_num_params, bound=80_000, relative=False),
    ],
)
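
To make the roles of the two objectives and the thresholds concrete, here is a toy pure-Python sketch (not Ax's implementation) of how one would filter candidate outcomes by the thresholds above and then keep only the non-dominated (Pareto-optimal) points; the candidate values are made up for illustration.

```python
def is_pareto_optimal(point, points):
    # (val_acc, num_params): maximize accuracy, minimize parameter count.
    acc, size = point
    return not any(
        (a >= acc and s <= size) and (a > acc or s < size)
        for a, s in points
    )

candidates = [(0.95, 60_000), (0.96, 90_000), (0.93, 30_000), (0.95, 70_000)]

# Apply the objective thresholds first: val_acc >= 0.94, num_params <= 80_000.
feasible = [(a, s) for a, s in candidates if a >= 0.94 and s <= 80_000]
frontier = [p for p in feasible if is_pareto_optimal(p, feasible)]
print(frontier)  # [(0.95, 60000)]
```

Here (0.96, 90_000) and (0.93, 30_000) violate a threshold, and (0.95, 70_000) is dominated by (0.95, 60_000), which achieves the same accuracy with fewer parameters.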

Creating the Ax Experiment

In Ax, the Experiment object is the object that stores all the information about the problem setup.

from ax.core import Experiment

experiment = Experiment(
    name="torchx_mnist",
    search_space=search_space,
    optimization_config=opt_config,
    runner=ax_runner,
)

Choosing the GenerationStrategy

A GenerationStrategy is the abstract representation of how we would like to perform the optimization. While this can be customized (if you’d like to do so, see this tutorial), in most cases Ax can automatically determine an appropriate strategy based on the search space, optimization config, and the total number of trials we want to run.

Typically, Ax chooses to evaluate a number of random configurations before starting a model-based Bayesian Optimization strategy.

total_trials = 48  # total evaluation budget

from ax.modelbridge.dispatch_utils import choose_generation_strategy

gs = choose_generation_strategy(
    search_space=experiment.search_space,
    optimization_config=experiment.optimization_config,
    num_trials=total_trials,
)
[INFO 09-28 15:25:08] ax.modelbridge.dispatch_utils: Using Bayesian optimization since there are more ordered parameters than there are categories for the unordered categorical parameters.
[INFO 09-28 15:25:08] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+MOO', steps=[Sobol for 9 trials, MOO for subsequent trials]). Iterations after 9 will take longer to generate due to  model-fitting.

Configuring the Scheduler

The Scheduler acts as the loop control for the optimization. It communicates with the backend to launch trials, check their status, and retrieve results. In the case of this tutorial, it is simply reading and parsing the locally saved logs. In a remote execution setting, it would call APIs. The following illustration from the Ax Scheduler tutorial summarizes how the Scheduler interacts with external systems used to run trial evaluations:

(Illustration: the Ax Scheduler interacting with the external systems that run trial evaluations.)

The Scheduler requires the Experiment and the GenerationStrategy. A set of options can be passed in via SchedulerOptions. Here, we configure the number of total evaluations as well as max_pending_trials, the maximum number of trials that should run concurrently. In our local setting, this is the number of training jobs running as individual processes, while in a remote execution setting, this would be the number of machines you want to use in parallel.

from ax.service.scheduler import Scheduler, SchedulerOptions

scheduler = Scheduler(
    experiment=experiment,
    generation_strategy=gs,
    options=SchedulerOptions(
        total_trials=total_trials, max_pending_trials=4
    ),
)
[INFO 09-28 15:25:08] Scheduler: `Scheduler` requires experiment to have immutable search space and optimization config. Setting property immutable_search_space_and_opt_config to `True` on experiment.
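
The control loop the Scheduler runs can be pictured with a short stdlib sketch (not the Ax implementation): at most max_pending_trials jobs run at once, and as each one completes, its results are fetched and a new trial is launched. Here a thread pool caps the concurrency, and run_trial is a stand-in returning fake metrics.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def run_trial(idx):
    # Stand-in for a training job; returns fake (val_acc, num_params) metrics.
    return idx, (0.9 + idx * 0.001, 50_000 + idx)

total_trials = 8
max_pending_trials = 4  # mirrors SchedulerOptions(max_pending_trials=4)
results = {}

# The worker cap plays the role of max_pending_trials: no more than
# four "trials" run concurrently at any point in time.
with ThreadPoolExecutor(max_workers=max_pending_trials) as pool:
    futures = [pool.submit(run_trial, i) for i in range(total_trials)]
    for fut in as_completed(futures):
        idx, metrics = fut.result()
        results[idx] = metrics  # "fetch data" as each trial completes

print(len(results))  # 8
```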

Running the optimization

Now that everything is configured, we can let Ax run the optimization in a fully automated fashion. The Scheduler will periodically check the logs for the status of all currently running trials, and if a trial completes, the scheduler will update its status on the experiment and fetch the observations needed for the Bayesian optimization algorithm.

scheduler.run_all_trials()
[INFO 09-28 15:25:08] Scheduler: Running trials [0]...
[INFO 09-28 15:25:09] Scheduler: Running trials [1]...
[INFO 09-28 15:25:10] Scheduler: Running trials [2]...
[INFO 09-28 15:25:11] Scheduler: Running trials [3]...
[INFO 09-28 15:25:12] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 4).
[INFO 09-28 15:25:14] Scheduler: Retrieved FAILED trials: [0].
[INFO 09-28 15:25:14] Scheduler: Running trials [4]...
[INFO 09-28 15:25:15] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 4).
[INFO 09-28 15:25:16] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 4).
[INFO 09-28 15:25:17] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 4).
[INFO 09-28 15:25:19] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 4).
[INFO 09-28 15:25:23] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 4).
[INFO 09-28 15:25:28] Scheduler: Waiting for completed trials (for 7 sec, currently running trials: 4).
[INFO 09-28 15:25:35] Scheduler: Waiting for completed trials (for 11 sec, currently running trials: 4).
[INFO 09-28 15:25:47] Scheduler: Waiting for completed trials (for 17 sec, currently running trials: 4).
[INFO 09-28 15:26:04] Scheduler: Retrieved COMPLETED trials: [2].
[INFO 09-28 15:26:04] Scheduler: Fetching data for trials: [2].
[INFO 09-28 15:26:04] Scheduler: Running trials [5]...
[INFO 09-28 15:26:05] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 4).
[INFO 09-28 15:26:06] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 4).
[INFO 09-28 15:26:08] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 4).
[INFO 09-28 15:26:10] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 4).
[INFO 09-28 15:26:13] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 4).
[INFO 09-28 15:26:18] Scheduler: Waiting for completed trials (for 7 sec, currently running trials: 4).
[INFO 09-28 15:26:26] Scheduler: Waiting for completed trials (for 11 sec, currently running trials: 4).
[INFO 09-28 15:26:37] Scheduler: Waiting for completed trials (for 17 sec, currently running trials: 4).
[INFO 09-28 15:26:55] Scheduler: Retrieved COMPLETED trials: 4 - 5.
[INFO 09-28 15:26:55] Scheduler: Fetching data for trials: 4 - 5.
[INFO 09-28 15:26:55] Scheduler: Running trials [6]...
[INFO 09-28 15:26:55] Scheduler: Running trials [7]...
[INFO 09-28 15:26:55] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 4).
[INFO 09-28 15:26:56] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 4).
[INFO 09-28 15:26:58] Scheduler: Retrieved COMPLETED trials: [1].
[INFO 09-28 15:26:58] Scheduler: Fetching data for trials: [1].
[INFO 09-28 15:26:58] Scheduler: Running trials [8]...
[INFO 09-28 15:26:59] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 4).
[INFO 09-28 15:27:00] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 4).
[INFO 09-28 15:27:01] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 4).
[INFO 09-28 15:27:03] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 4).
[INFO 09-28 15:27:07] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 4).
[INFO 09-28 15:27:12] Scheduler: Waiting for completed trials (for 7 sec, currently running trials: 4).
[INFO 09-28 15:27:20] Scheduler: Retrieved COMPLETED trials: [3].
[INFO 09-28 15:27:20] Scheduler: Fetching data for trials: [3].
[INFO 09-28 15:27:20] Scheduler: Running trials [9]...
[INFO 09-28 15:27:21] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 4).
[INFO 09-28 15:27:22] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 4).
[INFO 09-28 15:27:23] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 4).
[INFO 09-28 15:27:26] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 4).
[INFO 09-28 15:27:29] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 4).
[INFO 09-28 15:27:34] Scheduler: Retrieved COMPLETED trials: [7].
[INFO 09-28 15:27:34] Scheduler: Fetching data for trials: [7].
[INFO 09-28 15:27:36] Scheduler: Running trials [10]...
[INFO 09-28 15:27:37] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 4).
[INFO 09-28 15:27:38] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 4).
[INFO 09-28 15:27:39] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 4).
[INFO 09-28 15:27:42] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 4).
[INFO 09-28 15:27:45] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 4).
[INFO 09-28 15:27:50] Scheduler: Waiting for completed trials (for 7 sec, currently running trials: 4).
[INFO 09-28 15:27:58] Scheduler: Waiting for completed trials (for 11 sec, currently running trials: 4).
[INFO 09-28 15:28:09] Scheduler: Waiting for completed trials (for 17 sec, currently running trials: 4).
[INFO 09-28 15:28:26] Scheduler: Retrieved COMPLETED trials: [6, 8].
[INFO 09-28 15:28:26] Scheduler: Fetching data for trials: [6, 8].
[INFO 09-28 15:28:28] Scheduler: Running trials [11]...
[INFO 09-28 15:28:32] Scheduler: Running trials [12]...
[INFO 09-28 15:28:33] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 4).
[INFO 09-28 15:28:34] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 4).
[INFO 09-28 15:28:36] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 4).
[INFO 09-28 15:28:38] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 4).
[INFO 09-28 15:28:41] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 4).
[INFO 09-28 15:28:46] Scheduler: Waiting for completed trials (for 7 sec, currently running trials: 4).
[INFO 09-28 15:28:54] Scheduler: Waiting for completed trials (for 11 sec, currently running trials: 4).
[INFO 09-28 15:29:05] Scheduler: Waiting for completed trials (for 17 sec, currently running trials: 4).
[INFO 09-28 15:29:22] Scheduler: Waiting for completed trials (for 25 sec, currently running trials: 4).
[INFO 09-28 15:29:48] Scheduler: Retrieved COMPLETED trials: 9 - 10.
[INFO 09-28 15:29:48] Scheduler: Fetching data for trials: 9 - 10.
[INFO 09-28 15:29:50] Scheduler: Running trials [13]...
[INFO 09-28 15:29:52] Scheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 09-28 15:29:52] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 09-28 15:29:53] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 09-28 15:29:55] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 09-28 15:29:57] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 3).
[INFO 09-28 15:30:00] Scheduler: Retrieved COMPLETED trials: 11 - 12.
[INFO 09-28 15:30:00] Scheduler: Fetching data for trials: 11 - 12.
[INFO 09-28 15:30:03] Scheduler: Running trials [14]...
[INFO 09-28 15:30:07] Scheduler: Running trials [15]...
[INFO 09-28 15:30:08] Scheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 09-28 15:30:08] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 09-28 15:30:09] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 09-28 15:30:10] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 09-28 15:30:12] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 3).
[INFO 09-28 15:30:16] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 3).
[INFO 09-28 15:30:21] Scheduler: Waiting for completed trials (for 7 sec, currently running trials: 3).
[INFO 09-28 15:30:29] Scheduler: Waiting for completed trials (for 11 sec, currently running trials: 3).
[INFO 09-28 15:30:40] Scheduler: Waiting for completed trials (for 17 sec, currently running trials: 3).
[INFO 09-28 15:30:57] Scheduler: Waiting for completed trials (for 25 sec, currently running trials: 3).
[INFO 09-28 15:31:23] Scheduler: Retrieved COMPLETED trials: [13].
[INFO 09-28 15:31:23] Scheduler: Fetching data for trials: [13].
[INFO 09-28 15:31:26] Scheduler: Running trials [16]...
[INFO 09-28 15:31:28] Scheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 09-28 15:31:28] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 09-28 15:31:29] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 09-28 15:31:30] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 09-28 15:31:33] Scheduler: Retrieved COMPLETED trials: [15].
[INFO 09-28 15:31:33] Scheduler: Fetching data for trials: [15].
[INFO 09-28 15:31:35] Scheduler: Running trials [17]...
[INFO 09-28 15:31:36] Scheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 09-28 15:31:37] Scheduler: Retrieved COMPLETED trials: [14].
[INFO 09-28 15:31:37] Scheduler: Fetching data for trials: [14].
[INFO 09-28 15:31:39] Scheduler: Running trials [18]...
[INFO 09-28 15:31:41] Scheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 09-28 15:31:41] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 09-28 15:31:43] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 09-28 15:31:44] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 09-28 15:31:46] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 3).
[INFO 09-28 15:31:50] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 3).
[INFO 09-28 15:31:55] Scheduler: Waiting for completed trials (for 7 sec, currently running trials: 3).
[INFO 09-28 15:32:02] Scheduler: Waiting for completed trials (for 11 sec, currently running trials: 3).
[INFO 09-28 15:32:14] Scheduler: Waiting for completed trials (for 17 sec, currently running trials: 3).
[INFO 09-28 15:32:31] Scheduler: Waiting for completed trials (for 25 sec, currently running trials: 3).
[INFO 09-28 15:32:56] Scheduler: Waiting for completed trials (for 38 sec, currently running trials: 3).
[INFO 09-28 15:33:35] Scheduler: Retrieved COMPLETED trials: 16 - 18.
[INFO 09-28 15:33:35] Scheduler: Fetching data for trials: 16 - 18.
[INFO 09-28 15:33:38] Scheduler: Running trials [19]...
[INFO 09-28 15:33:42] Scheduler: Running trials [20]...
[INFO 09-28 15:33:56] Scheduler: Running trials [21]...
[INFO 09-28 15:33:58] Scheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 09-28 15:33:58] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 09-28 15:33:59] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 09-28 15:34:00] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 09-28 15:34:03] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 3).
[INFO 09-28 15:34:06] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 3).
[INFO 09-28 15:34:11] Scheduler: Waiting for completed trials (for 7 sec, currently running trials: 3).
[INFO 09-28 15:34:19] Scheduler: Waiting for completed trials (for 11 sec, currently running trials: 3).
[INFO 09-28 15:34:30] Scheduler: Waiting for completed trials (for 17 sec, currently running trials: 3).
[INFO 09-28 15:34:47] Scheduler: Waiting for completed trials (for 25 sec, currently running trials: 3).
[INFO 09-28 15:35:13] Scheduler: Waiting for completed trials (for 38 sec, currently running trials: 3).
[INFO 09-28 15:35:51] Scheduler: Retrieved COMPLETED trials: 19 - 21.
[INFO 09-28 15:35:51] Scheduler: Fetching data for trials: 19 - 21.
[INFO 09-28 15:35:55] Scheduler: Running trials [22]...
[INFO 09-28 15:36:01] Scheduler: Running trials [23]...
[INFO 09-28 15:36:06] Scheduler: Running trials [24]...
[INFO 09-28 15:36:08] Scheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 09-28 15:36:08] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 09-28 15:36:09] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 09-28 15:36:11] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 09-28 15:36:13] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 3).
[INFO 09-28 15:36:16] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 3).
[INFO 09-28 15:36:22] Scheduler: Waiting for completed trials (for 7 sec, currently running trials: 3).
[INFO 09-28 15:36:29] Scheduler: Waiting for completed trials (for 11 sec, currently running trials: 3).
[INFO 09-28 15:36:41] Scheduler: Waiting for completed trials (for 17 sec, currently running trials: 3).
[INFO 09-28 15:36:58] Scheduler: Waiting for completed trials (for 25 sec, currently running trials: 3).
[INFO 09-28 15:37:23] Scheduler: Waiting for completed trials (for 38 sec, currently running trials: 3).
[INFO 09-28 15:38:02] Scheduler: Retrieved COMPLETED trials: 22 - 24.
[INFO 09-28 15:38:02] Scheduler: Fetching data for trials: 22 - 24.
[INFO 09-28 15:38:04] Scheduler: Running trials [25]...
[INFO 09-28 15:38:08] Scheduler: Running trials [26]...
[INFO 09-28 15:38:13] Scheduler: Running trials [27]...
[INFO 09-28 15:38:15] Scheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 09-28 15:38:15] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 09-28 15:38:16] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 09-28 15:38:17] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 09-28 15:38:20] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 3).
[INFO 09-28 15:38:23] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 3).
[INFO 09-28 15:38:28] Scheduler: Waiting for completed trials (for 7 sec, currently running trials: 3).
[INFO 09-28 15:38:36] Scheduler: Waiting for completed trials (for 11 sec, currently running trials: 3).
[INFO 09-28 15:38:47] Scheduler: Waiting for completed trials (for 17 sec, currently running trials: 3).
[INFO 09-28 15:39:04] Scheduler: Waiting for completed trials (for 25 sec, currently running trials: 3).
[INFO 09-28 15:39:30] Scheduler: Retrieved COMPLETED trials: [25].
[INFO 09-28 15:39:30] Scheduler: Fetching data for trials: [25].
[INFO 09-28 15:39:33] Scheduler: Running trials [28]...
[INFO 09-28 15:39:35] Scheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 09-28 15:39:35] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 09-28 15:39:36] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 09-28 15:39:37] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 09-28 15:39:40] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 3).
[INFO 09-28 15:39:43] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 3).
[INFO 09-28 15:39:48] Scheduler: Waiting for completed trials (for 7 sec, currently running trials: 3).
[INFO 09-28 15:39:56] Scheduler: Retrieved COMPLETED trials: 26 - 27.
[INFO 09-28 15:39:56] Scheduler: Fetching data for trials: 26 - 27.
[INFO 09-28 15:39:58] Scheduler: Running trials [29]...
[INFO 09-28 15:40:05] Scheduler: Running trials [30]...
[INFO 09-28 15:40:07] Scheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 09-28 15:40:07] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 09-28 15:40:08] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 09-28 15:40:10] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 09-28 15:40:12] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 3).
[INFO 09-28 15:40:16] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 3).
[INFO 09-28 15:40:21] Scheduler: Waiting for completed trials (for 7 sec, currently running trials: 3).
[INFO 09-28 15:40:28] Scheduler: Waiting for completed trials (for 11 sec, currently running trials: 3).
[INFO 09-28 15:40:40] Scheduler: Waiting for completed trials (for 17 sec, currently running trials: 3).
[INFO 09-28 15:40:57] Scheduler: Waiting for completed trials (for 25 sec, currently running trials: 3).
[INFO 09-28 15:41:23] Scheduler: Retrieved COMPLETED trials: [28].
[INFO 09-28 15:41:23] Scheduler: Fetching data for trials: [28].
[INFO 09-28 15:41:26] Scheduler: Running trials [31]...
[INFO 09-28 15:41:28] Scheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 09-28 15:41:28] Scheduler: Retrieved COMPLETED trials: [30].
[INFO 09-28 15:41:28] Scheduler: Fetching data for trials: [30].
[INFO 09-28 15:41:32] Scheduler: Running trials [32]...
[INFO 09-28 15:41:33] Scheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 09-28 15:41:33] Scheduler: Retrieved COMPLETED trials: [29].
[INFO 09-28 15:41:33] Scheduler: Fetching data for trials: [29].
[INFO 09-28 15:41:36] Scheduler: Running trials [33]...
[INFO 09-28 15:41:38] Scheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 09-28 15:41:38] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 09-28 15:41:39] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 09-28 15:41:41] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 09-28 15:41:43] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 3).
[INFO 09-28 15:41:46] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 3).
[INFO 09-28 15:41:51] Scheduler: Waiting for completed trials (for 7 sec, currently running trials: 3).
[INFO 09-28 15:41:59] Scheduler: Waiting for completed trials (for 11 sec, currently running trials: 3).
[INFO 09-28 15:42:10] Scheduler: Waiting for completed trials (for 17 sec, currently running trials: 3).
[INFO 09-28 15:42:27] Scheduler: Waiting for completed trials (for 25 sec, currently running trials: 3).
[INFO 09-28 15:42:53] Scheduler: Waiting for completed trials (for 38 sec, currently running trials: 3).
[INFO 09-28 15:43:32] Scheduler: Retrieved COMPLETED trials: 31 - 33.
[INFO 09-28 15:43:32] Scheduler: Fetching data for trials: 31 - 33.
[INFO 09-28 15:43:34] Scheduler: Running trials [34]...
[INFO 09-28 15:43:39] Scheduler: Running trials [35]...
[INFO 09-28 15:43:43] Scheduler: Running trials [36]...
[INFO 09-28 15:43:45] Scheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 09-28 15:43:45] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 09-28 15:43:46] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 09-28 15:43:48] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 09-28 15:43:50] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 3).
[INFO 09-28 15:43:53] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 3).
[INFO 09-28 15:43:58] Scheduler: Waiting for completed trials (for 7 sec, currently running trials: 3).
[INFO 09-28 15:44:06] Scheduler: Waiting for completed trials (for 11 sec, currently running trials: 3).
[INFO 09-28 15:44:17] Scheduler: Waiting for completed trials (for 17 sec, currently running trials: 3).
[INFO 09-28 15:44:34] Scheduler: Waiting for completed trials (for 25 sec, currently running trials: 3).
[INFO 09-28 15:45:00] Scheduler: Waiting for completed trials (for 38 sec, currently running trials: 3).
[INFO 09-28 15:45:39] Scheduler: Retrieved COMPLETED trials: 34 - 35.
[INFO 09-28 15:45:39] Scheduler: Fetching data for trials: 34 - 35.
[INFO 09-28 15:45:41] Scheduler: Running trials [37]...
[INFO 09-28 15:45:46] Scheduler: Running trials [38]...
[INFO 09-28 15:45:48] Scheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 09-28 15:45:48] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 09-28 15:45:49] Scheduler: Retrieved COMPLETED trials: [36].
[INFO 09-28 15:45:49] Scheduler: Fetching data for trials: [36].
[INFO 09-28 15:45:55] Scheduler: Running trials [39]...
[INFO 09-28 15:45:57] Scheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 09-28 15:45:57] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 09-28 15:45:58] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 09-28 15:45:59] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 09-28 15:46:01] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 3).
[INFO 09-28 15:46:05] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 3).
[INFO 09-28 15:46:10] Scheduler: Waiting for completed trials (for 7 sec, currently running trials: 3).
[INFO 09-28 15:46:17] Scheduler: Waiting for completed trials (for 11 sec, currently running trials: 3).
[INFO 09-28 15:46:29] Scheduler: Waiting for completed trials (for 17 sec, currently running trials: 3).
[INFO 09-28 15:46:46] Scheduler: Waiting for completed trials (for 25 sec, currently running trials: 3).
[INFO 09-28 15:47:12] Scheduler: Waiting for completed trials (for 38 sec, currently running trials: 3).
[INFO 09-28 15:47:50] Scheduler: Retrieved COMPLETED trials: 37 - 39.
[INFO 09-28 15:47:50] Scheduler: Fetching data for trials: 37 - 39.
[INFO 09-28 15:47:56] Scheduler: Running trials [40]...
[INFO 09-28 15:48:00] Scheduler: Running trials [41]...
[INFO 09-28 15:48:02] Scheduler: Running trials [42]...
[INFO 09-28 15:48:04] Scheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 09-28 15:48:04] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 09-28 15:48:05] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 09-28 15:48:06] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 09-28 15:48:08] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 3).
[INFO 09-28 15:48:12] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 3).
[INFO 09-28 15:48:17] Scheduler: Waiting for completed trials (for 7 sec, currently running trials: 3).
[INFO 09-28 15:48:24] Scheduler: Waiting for completed trials (for 11 sec, currently running trials: 3).
[INFO 09-28 15:48:36] Scheduler: Waiting for completed trials (for 17 sec, currently running trials: 3).
[INFO 09-28 15:48:53] Scheduler: Waiting for completed trials (for 25 sec, currently running trials: 3).
[INFO 09-28 15:49:19] Scheduler: Waiting for completed trials (for 38 sec, currently running trials: 3).
[INFO 09-28 15:49:57] Scheduler: Retrieved COMPLETED trials: 40 - 42.
[INFO 09-28 15:49:57] Scheduler: Fetching data for trials: 40 - 42.
[INFO 09-28 15:49:59] Scheduler: Running trials [43]...
[INFO 09-28 15:50:05] Scheduler: Running trials [44]...
[INFO 09-28 15:50:14] Scheduler: Running trials [45]...
[INFO 09-28 15:50:15] Scheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 09-28 15:50:15] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 09-28 15:50:16] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 09-28 15:50:18] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 09-28 15:50:20] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 3).
[INFO 09-28 15:50:23] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 3).
[INFO 09-28 15:50:28] Scheduler: Waiting for completed trials (for 7 sec, currently running trials: 3).
[INFO 09-28 15:50:36] Scheduler: Waiting for completed trials (for 11 sec, currently running trials: 3).
[INFO 09-28 15:50:47] Scheduler: Waiting for completed trials (for 17 sec, currently running trials: 3).
[INFO 09-28 15:51:05] Scheduler: Retrieved COMPLETED trials: [43].
[INFO 09-28 15:51:05] Scheduler: Fetching data for trials: [43].
[INFO 09-28 15:51:09] Scheduler: Running trials [46]...
[INFO 09-28 15:51:10] Scheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 09-28 15:51:10] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 09-28 15:51:11] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 09-28 15:51:13] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 09-28 15:51:15] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 3).
[INFO 09-28 15:51:18] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 3).
[INFO 09-28 15:51:23] Scheduler: Waiting for completed trials (for 7 sec, currently running trials: 3).
[INFO 09-28 15:51:31] Scheduler: Waiting for completed trials (for 11 sec, currently running trials: 3).
[INFO 09-28 15:51:43] Scheduler: Retrieved COMPLETED trials: 44 - 45.
[INFO 09-28 15:51:43] Scheduler: Fetching data for trials: 44 - 45.
[INFO 09-28 15:51:48] Scheduler: Running trials [47]...
[INFO 09-28 15:51:50] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 2).
[INFO 09-28 15:51:51] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 2).
[INFO 09-28 15:51:52] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 2).
[INFO 09-28 15:51:54] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 2).
[INFO 09-28 15:51:58] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 2).
[INFO 09-28 15:52:03] Scheduler: Waiting for completed trials (for 7 sec, currently running trials: 2).
[INFO 09-28 15:52:10] Scheduler: Waiting for completed trials (for 11 sec, currently running trials: 2).
[INFO 09-28 15:52:22] Scheduler: Waiting for completed trials (for 17 sec, currently running trials: 2).
[INFO 09-28 15:52:39] Scheduler: Retrieved COMPLETED trials: [46].
[INFO 09-28 15:52:39] Scheduler: Fetching data for trials: [46].
[INFO 09-28 15:52:39] Scheduler: Done submitting trials, waiting for remaining 1 running trials...
[INFO 09-28 15:52:39] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 1).
[INFO 09-28 15:52:40] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 1).
[INFO 09-28 15:52:41] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 1).
[INFO 09-28 15:52:44] Scheduler: Waiting for completed trials (for 3 sec, currently running trials: 1).
[INFO 09-28 15:52:47] Scheduler: Waiting for completed trials (for 5 sec, currently running trials: 1).
[INFO 09-28 15:52:52] Scheduler: Waiting for completed trials (for 7 sec, currently running trials: 1).
[INFO 09-28 15:53:00] Scheduler: Retrieved COMPLETED trials: [47].
[INFO 09-28 15:53:00] Scheduler: Fetching data for trials: [47].

OptimizationResult()

Evaluating the results

We can now inspect the result of the optimization using helper functions and visualizations included with Ax.

First, we generate a dataframe with a summary of the results of the experiment. Each row in this dataframe corresponds to a trial (that is, a training job that was run), and contains information on the status of the trial, the parameter configuration that was evaluated, and the metric values that were observed. This provides an easy way to sanity check the optimization.

from ax.service.utils.report_utils import exp_to_df

df = exp_to_df(experiment)
df.head(10)
    num_params  val_acc  trial_index arm_name  hidden_size_1  hidden_size_2  learning_rate  epochs   dropout  batch_size trial_status generation_method
47         NaN      NaN            0      0_0             98             16       0.000682       4  0.050487          32       FAILED             Sobol
2      73472.0   0.9650            1      1_0             86             63       0.001096       4  0.170217         256    COMPLETED             Sobol
14     81560.0   0.9310            2      2_0             90            110       0.000258       1  0.301337          32    COMPLETED             Sobol
26     74715.0   0.9564            3      3_0             85             85       0.000174       4  0.132252          64    COMPLETED             Sobol
38     98948.0   0.9196            4      4_0            122             25       0.000132       3  0.270602         128    COMPLETED             Sobol
42     51071.0   0.9428            5      5_0             63             23       0.004674       1  0.147705         128    COMPLETED             Sobol
43     37608.0   0.9323            6      6_0             42             90       0.000200       3  0.129751         128    COMPLETED             Sobol
44     57365.0   0.9399            7      7_0             71             21       0.002609       1  0.204063         256    COMPLETED             Sobol
45     30656.0   0.9315            8      8_0             38             18       0.005043       3  0.293701         256    COMPLETED             Sobol
46     25698.0   0.9530            9      9_0             31             34       0.001497       4  0.209055          64    COMPLETED             Sobol
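Because exp_to_df returns a regular pandas DataFrame, standard pandas operations can be used for quick sanity checks. As a minimal sketch (using a hand-built sample of the rows shown above rather than the live experiment), we can filter out failed trials and look for the most accurate model under an arbitrary parameter budget:

```python
import pandas as pd

# Sample rows mirroring the exp_to_df output above (values copied from the table).
df = pd.DataFrame(
    {
        "trial_index": [0, 1, 2, 9],
        "num_params": [float("nan"), 73472.0, 81560.0, 25698.0],
        "val_acc": [float("nan"), 0.9650, 0.9310, 0.9530],
        "trial_status": ["FAILED", "COMPLETED", "COMPLETED", "COMPLETED"],
    }
)

# Keep only completed trials, then find the most accurate model
# under a parameter budget (50k parameters, chosen arbitrarily here).
completed = df[df["trial_status"] == "COMPLETED"]
small = completed[completed["num_params"] <= 50_000]
best_small = small.loc[small["val_acc"].idxmax()]
print(best_small["trial_index"], best_small["val_acc"])
```

On the real experiment, the same filtering would be applied to the full exp_to_df(experiment) result.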


We can also visualize the Pareto frontier of tradeoffs between the validation accuracy and the number of model parameters.

Tip

Ax uses Plotly to produce interactive plots, which allow you to do things like zoom, crop, or hover in order to view details of components of the plot. Try it out, and take a look at the visualization tutorial if you'd like to learn more.

The final optimization results are shown in the figure below, where the color corresponds to the iteration number of each trial. We see that our method was able to successfully explore the trade-offs, finding both large models with high validation accuracy and small models with comparatively lower validation accuracy.

from ax.service.utils.report_utils import _pareto_frontier_scatter_2d_plotly

_pareto_frontier_scatter_2d_plotly(experiment)
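To build intuition for what the plot shows, the Pareto frontier can also be computed by hand: a trial is on the frontier if no other trial is at least as accurate while using no more parameters. The sketch below (a naive O(n²) check, not Ax's implementation) applies this to the completed Sobol trials from the table above:

```python
def pareto_front(points):
    """Return the non-dominated (val_acc, num_params) points,
    maximizing accuracy while minimizing parameter count."""
    front = []
    for acc, size in points:
        dominated = any(
            (a >= acc and s <= size) and (a > acc or s < size)
            for a, s in points
        )
        if not dominated:
            front.append((acc, size))
    return sorted(front)

# (val_acc, num_params) pairs from the completed Sobol trials above.
trials = [
    (0.9650, 73472), (0.9310, 81560), (0.9564, 74715), (0.9196, 98948),
    (0.9428, 51071), (0.9323, 37608), (0.9399, 57365), (0.9315, 30656),
    (0.9530, 25698),
]
print(pareto_front(trials))
```

Only two of the nine Sobol trials are non-dominated at this stage; the Bayesian optimization iterations that follow push this frontier outward.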


To better understand what our surrogate models have learned about the black box objectives, we can take a look at the leave-one-out cross validation results. Since our models are Gaussian processes, they provide not only point predictions but also uncertainty estimates about those predictions. A good model means that the predicted means (the points in the figure) lie close to the 45-degree line and that the confidence intervals cover the 45-degree line with the expected frequency (here we use 95% confidence intervals, so we would expect them to contain the true observation 95% of the time).

As the figures below show, the model size (num_params) metric is much easier to model than the validation accuracy (val_acc) metric.

from ax.modelbridge.cross_validation import compute_diagnostics, cross_validate
from ax.plot.diagnostic import interact_cross_validation_plotly
from ax.utils.notebook.plotting import init_notebook_plotting, render

cv = cross_validate(model=gs.model)  # The surrogate model is stored on the GenerationStrategy
compute_diagnostics(cv)

interact_cross_validation_plotly(cv)
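The calibration criterion described above can be checked numerically as well: for each held-out point, test whether the observation falls inside the model's 95% predictive interval, and compare the empirical coverage to the nominal 95%. A minimal sketch with toy values (illustrative only, not taken from the run above):

```python
def ci_coverage(observed, pred_mean, pred_sd, z=1.96):
    """Fraction of observations inside the central 95% predictive interval."""
    inside = sum(
        abs(y - m) <= z * s
        for y, m, s in zip(observed, pred_mean, pred_sd)
    )
    return inside / len(observed)

# Toy leave-one-out predictions; the last point is badly mispredicted.
observed = [0.93, 0.95, 0.96, 0.94]
pred_mean = [0.935, 0.948, 0.955, 0.90]
pred_sd = [0.01, 0.01, 0.01, 0.01]
print(ci_coverage(observed, pred_mean, pred_sd))
```

Coverage well below the nominal level indicates an overconfident model; well above indicates an underconfident one. Ax's compute_diagnostics reports related goodness-of-fit statistics for the actual surrogate.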


We can also make contour plots to better understand how the different objectives depend on two of the input parameters. In the figure below, we show the validation accuracy predicted by the model as a function of the two hidden sizes. The validation accuracy clearly increases as the hidden sizes increase.

from ax.plot.contour import interact_contour_plotly

interact_contour_plotly(model=gs.model, metric_name="val_acc")