Using Optuna to Optimize PyTorch Lightning Hyperparameters

From: https://medium.com/optuna/using-optuna-to-optimize-pytorch-lightning-hyperparameters-d9e04a481585

This post uses pytorch-lightning v0.6.0 (PyTorch v1.3.1) and optuna v1.1.0.

PyTorch Lightning + Optuna!

Optuna is a hyperparameter optimization framework applicable to machine learning frameworks and black-box optimization solvers. PyTorch Lightning provides a lightweight PyTorch wrapper for better scaling with less code. Combining the two of them allows for automatic tuning of hyperparameters to find the best performing models.

Creating the Objective Function

Optuna is a black-box optimizer, which means it needs an objective function: a function that returns a numerical value measuring how well a set of hyperparameters performed. Optuna uses that value to decide where to sample in upcoming trials.
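
To make the black-box idea concrete, here is a minimal sketch in plain Python. It is not Optuna's algorithm: it samples naively at random, and the names toy_objective and random_search are invented for illustration. The only point is that the optimizer sees nothing but the number the objective returns.

```python
import random


def toy_objective(x):
    # Hypothetical objective: the optimizer never looks inside this function,
    # it only receives the returned score. The maximum is at x = 2.
    return -(x - 2.0) ** 2


def random_search(objective, low, high, n_trials, seed=0):
    # Naive stand-in for a sampler: draw candidates uniformly at random and
    # remember the best score seen so far.
    rng = random.Random(seed)
    best_x, best_value = None, float("-inf")
    for _ in range(n_trials):
        x = rng.uniform(low, high)
        value = objective(x)
        if value > best_value:
            best_x, best_value = x, value
    return best_x, best_value


best_x, best_value = random_search(toy_objective, -10.0, 10.0, n_trials=200)
print("best x: {:.3f}, best value: {:.5f}".format(best_x, best_value))
```

A smarter sampler would replace the uniform draw with one informed by earlier scores, but the contract with the objective function stays the same.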

Our example is the MNIST digit classifier from the examples folder of the Optuna GitHub repository. There, the objective function starts like this:

def objective(trial):
    # PyTorch Lightning will try to restore model parameters from previous trials if checkpoint
    # filenames match. Therefore, the filenames for each trial must be made unique.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        os.path.join(MODEL_DIR, "trial_{}".format(trial.number)), monitor="accuracy"
    )

    # The default logger in PyTorch Lightning writes to event files to be consumed by
    # TensorBoard. We create a simple logger instead that holds the log in memory so that the
    # final accuracy can be obtained after optimization. When using the default logger, the
    # final accuracy could be stored in an attribute of the `Trainer` instead.
    logger = DictLogger(trial.number)

    trainer = pl.Trainer(
        logger=logger,
        val_percent_check=PERCENT_TEST_EXAMPLES,
        checkpoint_callback=checkpoint_callback,
        max_epochs=EPOCHS,
        # An integer here is the number of GPUs to use; 0 would mean CPU only.
        gpus=1 if torch.cuda.is_available() else None,
        early_stop_callback=PyTorchLightningPruningCallback(trial, monitor="accuracy"),
    )

    model = LightningNet(trial)
    trainer.fit(model)

    return logger.metrics[-1]["accuracy"]

Notice that the objective function receives an Optuna-specific argument, trial. The trial object is how the code specifies which hyperparameters should be tuned. The function returns the model's final accuracy, logger.metrics[-1]["accuracy"], which Optuna uses as feedback on the performance of the trial.

Defining the hyperparameters to be tuned

Similar to how PyTorch uses eager execution, Optuna lets you define the kinds and ranges of hyperparameters you want to tune directly within your code, using the trial object. This saves the effort of learning a specialized syntax for hyperparameters, and it also means you can use normal Python code to loop through or define your hyperparameters.
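
As a rough illustration of defining the search space while the code runs, the sketch below uses a toy stand-in for the trial object. ToyTrial and build_layer_sizes are hypothetical names, and the random draws are not how Optuna chooses values; the sketch only shows that the set of tuned parameters can depend on values suggested earlier in the same trial.

```python
import random


class ToyTrial:
    # Hypothetical stand-in for Optuna's trial object; records every suggestion.
    def __init__(self, seed):
        self._rng = random.Random(seed)
        self.params = {}

    def suggest_int(self, name, low, high):
        value = self._rng.randint(low, high)  # both bounds inclusive
        self.params[name] = value
        return value


def build_layer_sizes(trial):
    # The search space is defined as the code executes: how many
    # "n_units_l{i}" parameters exist depends on the sampled n_layers.
    n_layers = trial.suggest_int("n_layers", 1, 3)
    return [
        trial.suggest_int("n_units_l{}".format(i), 4, 128)
        for i in range(n_layers)
    ]


results = []
for seed in range(5):
    trial = ToyTrial(seed)
    sizes = build_layer_sizes(trial)
    results.append((trial.params, sizes))
    print(sorted(trial.params), sizes)
```

Each run records one n_layers entry plus one n_units_l{i} entry per layer, so different trials can end up with different parameter sets.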

Optuna supports a variety of hyperparameter settings, which can be used to optimize floats, integers, or discrete categorical values. Numerical values can be suggested from a logarithmic continuum as well. In our MNIST example, we optimize the hyperparameters here:

class Net(nn.Module):
    def __init__(self, trial):
        super(Net, self).__init__()
        self.layers = []
        self.dropouts = []
        
        # We optimize the number of layers, hidden units in each layer and dropouts.
        n_layers = trial.suggest_int("n_layers", 1, 3)
        dropout = trial.suggest_uniform("dropout", 0.2, 0.5)
        input_dim = 28 * 28
        for i in range(n_layers):
            output_dim = int(trial.suggest_loguniform("n_units_l{}".format(i), 4, 128))
            self.layers.append(nn.Linear(input_dim, output_dim))
            self.dropouts.append(nn.Dropout(dropout))
            input_dim = output_dim

The number of layers is suggested by trial.suggest_int("n_layers", 1, 3), which returns an integer from one to three and is recorded in Optuna under the name n_layers.

The dropout rate is defined by trial.suggest_uniform("dropout", 0.2, 0.5), which gives a float value between 0.2 and 0.5.

For hyperparameters that should vary by orders of magnitude, such as learning rates, use something like trial.suggest_loguniform("learning_rate", 1e-5, 1e-1), which will vary the values from 0.00001 to 0.1.
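
To see why a logarithmic scale suits such ranges, here is a small sketch of log-uniform sampling in plain Python. The helper sample_loguniform is a hypothetical name, not part of Optuna; it draws uniformly in log space so that every order of magnitude between 0.00001 and 0.1 is about equally likely.

```python
import math
import random


def sample_loguniform(rng, low, high):
    # Sample uniformly in log space, then map back, so that each order of
    # magnitude between low and high is equally likely.
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)
samples = [sample_loguniform(rng, 1e-5, 1e-1) for _ in range(10000)]

# Count samples per decade: [1e-5, 1e-4), [1e-4, 1e-3), [1e-3, 1e-2), [1e-2, 1e-1].
counts = [0, 0, 0, 0]
for s in samples:
    decade = int(math.floor(math.log10(s))) + 5
    counts[min(max(decade, 0), 3)] += 1  # clamp to guard against float rounding
print(counts)
```

A plain uniform draw over the same range would put roughly 90% of the samples between 0.01 and 0.1, leaving the small learning rates almost unexplored.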

Categorical selection from a list is possible with trial.suggest_categorical("optimizer", ["SGD", "Adam"]).

Running the Trials

The default sampler in Optuna is the Tree-structured Parzen Estimator (TPE), which is a form of Bayesian optimization. Optuna uses TPE to search more efficiently than a random search by choosing points closer to previous good results.

To run the trials, create a study object, which sets the direction of optimization ("maximize" or "minimize") along with other settings. Then run the study with optimize(objective, n_trials=100, timeout=600), which stops after one hundred trials or after 600 seconds (ten minutes), whichever comes first.

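# Assumption: the post does not show how `pruner` is defined; a median pruner is used here.
pruner = optuna.pruners.MedianPruner()
study = optuna.create_study(direction="maximize", pruner=pruner)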
study.optimize(objective, n_trials=100, timeout=600)

print("Number of finished trials: {}".format(len(study.trials)))

print("Best trial:")
trial = study.best_trial

print("  Value: {}".format(trial.value))

print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))
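
The n_trials and timeout arguments act as two budgets for the whole study. The following plain-Python sketch shows that stopping logic in isolation; run_trials is a hypothetical helper, and it glosses over details such as what happens to a trial that is still running when the time budget runs out.

```python
import time


def run_trials(objective, n_trials, timeout):
    # Start a new trial only while both budgets allow it: fewer than n_trials
    # trials have run and fewer than `timeout` seconds have elapsed.
    start = time.monotonic()
    values = []
    while len(values) < n_trials and time.monotonic() - start < timeout:
        values.append(objective(len(values)))
    return values


# The trial budget is reached long before the time budget.
by_count = run_trials(lambda i: i * i, n_trials=5, timeout=600)
# With no time budget at all, no trial is started.
by_time = run_trials(lambda i: i * i, n_trials=5, timeout=0)
print(by_count, by_time)
```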

Each new trial is chosen after evaluating all the trials that have been done previously: the sampler makes an informed guess about where the best hyperparameter values can be found. Optuna's default for this is the TPE sampler described above.

The best values from the trials can be accessed through study.best_trial, and other methods of viewing the trials, such as formatting in a dataframe, are available.

Pruning — Early Stopping of Poor Trials

Pruning is a form of early stopping that terminates unpromising trials so that computing time can be spent on trials that show more potential. To prune, it is necessary to open up the black box of the objective function a little. The objective has to report intermediate results to Optuna, so that Optuna can compare the trial's progress with that of other trials and decide whether to stop it early. It also has to receive the signal from Optuna when the trial should be terminated, and let the trial in progress terminate cleanly after recording its results. Fortunately, Optuna provides an integration for PyTorch Lightning, PyTorchLightningPruningCallback, that handles all of these steps.
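
The comparison step can be pictured with a simplified median rule, the idea behind median-based pruners such as optuna.pruners.MedianPruner. The sketch below is plain Python with a hypothetical should_prune helper and made-up accuracy numbers; it is not Optuna's implementation and leaves out refinements such as warm-up periods.

```python
import statistics


def should_prune(current_value, previous_values_at_step, maximize=True):
    # Median rule: stop the running trial if its intermediate value is worse
    # than the median of what earlier trials reported at the same step.
    if not previous_values_at_step:
        return False  # nothing to compare against yet
    median = statistics.median(previous_values_at_step)
    return current_value < median if maximize else current_value > median


# Hypothetical validation accuracies of three earlier trials at the same epoch.
earlier = [0.71, 0.80, 0.86]
print(should_prune(0.65, earlier))  # below the median of 0.80
print(should_prune(0.83, earlier))  # above the median of 0.80
```

In the real integration, the callback reports the monitored metric after each validation run and asks Optuna whether to stop, so none of this logic has to be written by hand.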

from optuna.integration import PyTorchLightningPruningCallback
...

def objective(trial):
    ...

    trainer = pl.Trainer(
        logger=logger,
        val_percent_check=PERCENT_TEST_EXAMPLES,
        checkpoint_callback=checkpoint_callback,
        max_epochs=EPOCHS,
        # An integer here is the number of GPUs to use; 0 would mean CPU only.
        gpus=1 if torch.cuda.is_available() else None,
        early_stop_callback=PyTorchLightningPruningCallback(trial, monitor="accuracy"),
    )

After importing PyTorchLightningPruningCallback, passing it to the trainer as early_stop_callback lets Lightning do the pruning. The monitor argument of PyTorchLightningPruningCallback names an entry in the metrics dictionary reported by the LightningModule, so it could also be set to other entries, such as val_loss or val_acc.

To the Future, and Beyond!

Plot Contour Visualization

For those interested, Optuna has many other features, including visualizations, alternative samplers, optimizers, and pruning algorithms, as well as the ability to create user-defined versions of them. If you have more computing resources available, Optuna provides an easy interface for running trials in parallel to increase tuning speed.

Give Optuna a try!

Installation

Optuna Github
