# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib.util
import os
import warnings
from typing import Dict, Optional, Sequence, Union

from torch import Tensor

from .common import Logger

_has_wandb = importlib.util.find_spec("wandb") is not None
_has_omegaconf = importlib.util.find_spec("omegaconf") is not None


class WandbLogger(Logger):
    """Wrapper for the wandb logger.

    The keyword arguments are mainly based on the :func:`wandb.init` kwargs.
    See the doc `here <https://docs.wandb.ai/ref/python/init>`__.

    Args:
exp_name (str): The name of the experiment.
offline (bool, optional): if ``True``, the logs will be stored locally
only. Defaults to ``False``.
save_dir (path, optional): the directory where to save data. Exclusive with
``log_dir``.
log_dir (path, optional): the directory where to save data. Exclusive with
``save_dir``.
id (str, optional): A unique ID for this run, used for resuming.
It must be unique in the project, and if you delete a run you can't reuse the ID.
project (str, optional): The name of the project where you're sending
the new run. If the project is not specified, the run is put in
an ``"Uncategorized"`` project.
**kwargs: Extra keyword arguments for ``wandb.init``. See relevant page for
more info.
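
    Examples:
        A minimal usage sketch (assumes the ``wandb`` package is installed;
        ``offline=True`` keeps the logs local instead of contacting the
        wandb servers):

        >>> logger = WandbLogger(exp_name="my_exp", offline=True, project="demo")
        >>> logger.log_scalar("reward", 1.0, step=0)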
"""

    @classmethod
    def __new__(cls, *args, **kwargs):
        # tracks the step of the most recently logged video (see log_video)
        cls._prev_video_step = -1
        return super().__new__(cls)

    def __init__(
        self,
        exp_name: str,
        offline: bool = False,
        save_dir: Optional[str] = None,
        id: Optional[str] = None,
        project: Optional[str] = None,
        **kwargs,
    ) -> None:
if not _has_wandb:
raise ImportError("wandb could not be imported")
log_dir = kwargs.pop("log_dir", None)
self.offline = offline
        if save_dir and log_dir:
            raise ValueError(
                "`save_dir` and `log_dir` are aliases for the same setting "
                "in WandbLogger; only one of them can be specified."
            )
        save_dir = save_dir or log_dir
self.save_dir = save_dir
self.id = id
self.project = project
self._wandb_kwargs = {
"name": exp_name,
"dir": save_dir,
"id": id,
"project": project,
"resume": "allow",
**kwargs,
}
self._has_imported_wandb = False
super().__init__(exp_name=exp_name, log_dir=save_dir)
if self.offline:
os.environ["WANDB_MODE"] = "dryrun"
self._has_imported_moviepy = False
        self._has_imported_omegaconf = False
self.video_log_counter = 0

    def _create_experiment(self) -> "wandb.sdk.wandb_run.Run":  # noqa: F821
        """Creates a wandb experiment.

        Returns:
            The run object returned by :func:`wandb.init`.
        """
if not _has_wandb:
raise ImportError("Wandb is not installed")
import wandb
if self.offline:
os.environ["WANDB_MODE"] = "dryrun"
return wandb.init(**self._wandb_kwargs)

    def log_scalar(self, name: str, value: float, step: Optional[int] = None) -> None:
        """Logs a scalar value to wandb.

        Args:
name (str): The name of the scalar.
value (:obj:`float`): The value of the scalar.
step (int, optional): The step at which the scalar is logged.
Defaults to None.
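
        Examples:
            A minimal sketch (assumes a ``WandbLogger`` instance named ``logger``):

            >>> logger.log_scalar("loss", 0.5, step=10)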
"""
if step is not None:
self.experiment.log({name: value, "trainer/step": step})
else:
self.experiment.log({name: value})

    def log_video(self, name: str, video: Tensor, **kwargs) -> None:
        """Log video inputs to wandb.

        Args:
name (str): The name of the video.
video (Tensor): The video to be logged.
**kwargs: Other keyword arguments. By construction, log_video
supports 'step' (integer indicating the step index), 'format'
(default is 'mp4') and 'fps' (default: 6). Other kwargs are
passed as-is to the :obj:`experiment.log` method.
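
        Examples:
            A minimal sketch (a random batch of one 16-frame RGB video, in the
            ``(N, T, C, H, W)`` format checked below; requires ``moviepy``):

            >>> import torch
            >>> frames = (torch.rand(1, 16, 3, 64, 64) * 255).to(torch.uint8)
            >>> logger.log_video("rollout", frames, fps=6)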
"""
import wandb
        # check the format of the video tensor: (N, T, C, H, W) is expected,
        # with a color channel (C) of either 1 or 3
        if video.dim() != 5 or video.size(dim=2) not in {1, 3}:
            raise Exception(
                "Wrong format of the video tensor. Should be (N, T, C, H, W) "
                "with C equal to 1 or 3."
            )
if not self._has_imported_moviepy:
try:
import moviepy # noqa
self._has_imported_moviepy = True
except ImportError:
                raise Exception(
                    "moviepy not found, videos cannot be logged with WandbLogger"
                )
self.video_log_counter += 1
fps = kwargs.pop("fps", 6)
step = kwargs.pop("step", None)
format = kwargs.pop("format", "mp4")
if step not in (None, self._prev_video_step, self._prev_video_step + 1):
            warnings.warn(
                "when using step with wandb_logger.log_video, the step is expected "
                "to equal the previous step or that value incremented by one. "
                f"Got step={step} but the previous value was {self._prev_video_step}. "
                f"The step value will be set to {self._prev_video_step + 1}. This warning will "
                "be silenced from now on but the values will keep being incremented."
            )
step = self._prev_video_step + 1
self._prev_video_step = step if step is not None else self._prev_video_step + 1
self.experiment.log(
{name: wandb.Video(video, fps=fps, format=format)},
# step=step,
**kwargs,
)

    def log_hparams(self, cfg: Union["DictConfig", Dict]) -> None:  # noqa: F821
        """Logs the hyperparameters of the experiment.

        Args:
            cfg (DictConfig or dict): The configuration of the experiment.
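
        Examples:
            A minimal sketch with a plain dict:

            >>> logger.log_hparams({"lr": 1e-3, "batch_size": 256})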
"""
        if not isinstance(cfg, dict):
            # non-dict configs (e.g. hydra's DictConfig) are converted
            # through OmegaConf before being sent to wandb
            if not _has_omegaconf:
                raise ImportError(
                    "OmegaConf could not be imported. "
                    "Cannot log hydra configs without OmegaConf."
                )
            from omegaconf import OmegaConf

            cfg = OmegaConf.to_container(cfg, resolve=True)
self.experiment.config.update(cfg, allow_val_change=True)

    def __repr__(self) -> str:
        return f"WandbLogger(experiment={self.experiment!r})"

    def log_histogram(self, name: str, data: Sequence, **kwargs):
        """Add a histogram to the log.

        Args:
            name (str): Data identifier.
            data (torch.Tensor, numpy.ndarray, or sequence): Values to build
                the histogram.

        Keyword Args:
            step (int): Global step value to record.
            bins (int): Number of bins, forwarded to ``wandb.Histogram`` as
                ``num_bins``.
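
        Examples:
            A minimal sketch (hypothetical values):

            >>> import torch
            >>> logger.log_histogram("returns", torch.randn(1000), step=0, bins=32)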
"""
        import wandb

        num_bins = kwargs.pop("bins", None)
        step = kwargs.pop("step", None)
        extra_kwargs = {}
        if step is not None:
            extra_kwargs["trainer/step"] = step
        # only forward `num_bins` when it was explicitly provided, so that
        # wandb.Histogram can fall back to its own default otherwise
        if num_bins is not None:
            histogram = wandb.Histogram(data, num_bins=num_bins)
        else:
            histogram = wandb.Histogram(data)
        self.experiment.log({name: histogram, **extra_kwargs})