Logging to Weights & Biases¶
This deep-dive will guide you through how to set up logging to Weights & Biases (W&B) in torchtune.
How to get started with W&B
How to use the
WandBLogger
How to log configs, metrics, and model checkpoints to W&B
torchtune supports logging your training runs to Weights & Biases. An example W&B workspace from a torchtune fine-tuning run can be seen in the screenshot below.
Note
You will need to install the wandb
package to use this feature.
You can install it via pip:
pip install wandb
Then you need to login with your API key using the W&B CLI:
wandb login
Metric Logger¶
The only change you need to make is to add the metric logger to your config. Weights & Biases will log the metrics and model checkpoints for you.
# enable logging to the built-in WandBLogger
metric_logger:
_component_: torchtune.training.metric_logging.WandBLogger
# the W&B project to log to
project: torchtune
We automatically grab the config from the recipe you are running and log it to W&B. You can find it in the W&B overview tab and the actual file in the Files
tab.
As a tip, you may see straggler wandb processes running in the background if your job crashes or otherwise exits without cleaning up resources. To kill these straggler processes, a command like ps
-aux | grep wandb | awk '{ print $2 }' | xargs kill
can be used.
Note
Click on this sample project to see the W&B workspace. The config used to train the models can be found here.
Logging Model Checkpoints to W&B¶
You can also log the model checkpoints to W&B by modifying the desired script save_checkpoint
method.
A suggested approach would be something like this:
def save_checkpoint(self, epoch: int) -> None:
...
## Let's save the checkpoint to W&B
## depending on the Checkpointer Class the file will be named differently
## Here is an example for the full_finetune case
checkpoint_file = Path.joinpath(
self._checkpointer._output_dir, f"torchtune_model_{epoch}"
).with_suffix(".pt")
wandb_at = wandb.Artifact(
name=f"torchtune_model_{epoch}",
type="model",
# description of the model checkpoint
description="Model checkpoint",
# you can add whatever metadata you want as a dict
metadata={
training.SEED_KEY: self.seed,
training.EPOCHS_KEY: self.epochs_run,
training.TOTAL_EPOCHS_KEY: self.total_epochs,
training.MAX_STEPS_KEY: self.max_steps_per_epoch,
}
)
wandb_at.add_file(checkpoint_file)
wandb.log_artifact(wandb_at)