Shortcuts

Source code for torchrl.record.loggers.wandb

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import importlib.util

import os
import warnings
from typing import Sequence

from torch import Tensor

from .common import Logger

# Availability flags for the optional third-party backends used below. Each
# flag is probed with ``importlib.util.find_spec`` (a spec object is truthy,
# a missing package yields ``None``).
_has_wandb = bool(importlib.util.find_spec("wandb"))
_has_omegaconf = bool(importlib.util.find_spec("omegaconf"))


class WandbLogger(Logger):
    """Wrapper for the wandb logger.

    The keyword arguments are mainly based on the :func:`wandb.init` kwargs.
    See the doc `here <https://docs.wandb.ai/ref/python/init>`__.

    Args:
        exp_name (str): The name of the experiment.
        offline (bool, optional): if ``True``, the logs will be stored locally
            only. Defaults to ``False``.
        save_dir (path, optional): the directory where to save data. Exclusive with
            ``log_dir``.
        log_dir (path, optional): the directory where to save data. Exclusive with
            ``save_dir``.
        id (str, optional): A unique ID for this run, used for resuming.
            It must be unique in the project, and if you delete a run you can't
            reuse the ID.
        project (str, optional): The name of the project where you're sending
            the new run. If the project is not specified, the run is put in an
            ``"Uncategorized"`` project.

    Keyword Args:
        video_fps (int, optional): Number of frames per second when recording
            videos. Defaults to ``32``.
        **kwargs: Extra keyword arguments for ``wandb.init``. See relevant page for
            more info.

    Raises:
        ImportError: if wandb is not installed.
        ValueError: if both ``save_dir`` and ``log_dir`` are given.
    """

    def __new__(cls, *args, **kwargs):
        # Video-step bookkeeping lives on the *instance* so that separate
        # loggers never share it. It is set here (rather than in __init__) so the
        # attributes exist even when __init__ is bypassed.
        self = super().__new__(cls)
        self._prev_video_step = -1
        # Tracks whether the "unexpected step" warning of log_video was emitted.
        self._video_step_warned = False
        return self

    def __init__(
        self,
        exp_name: str,
        offline: bool = False,
        save_dir: str | None = None,
        id: str | None = None,
        project: str | None = None,
        *,
        video_fps: int = 32,
        **kwargs,
    ) -> None:
        if not _has_wandb:
            raise ImportError("wandb could not be imported")

        log_dir = kwargs.pop("log_dir", None)
        self.offline = offline
        if save_dir and log_dir:
            raise ValueError(
                "log_dir and save_dir point to the same value in "
                "WandbLogger. Both cannot be specified."
            )
        # At most one of the two is truthy here; keep whichever was given.
        save_dir = save_dir or log_dir
        self.save_dir = save_dir
        self.id = id
        self.project = project
        self.video_fps = video_fps
        # Everything wandb.init will receive; must be set before the base-class
        # __init__, which may create the experiment.
        self._wandb_kwargs = {
            "name": exp_name,
            "dir": save_dir,
            "id": id,
            "project": project,
            "resume": "allow",
            **kwargs,
        }
        self._has_imported_wandb = False
        super().__init__(exp_name=exp_name, log_dir=save_dir)
        if self.offline:
            os.environ["WANDB_MODE"] = "dryrun"

        self._has_imported_moviepy = False
        self._has_imported_omgaconf = False
        self.video_log_counter = 0

    def _create_experiment(self) -> wandb.sdk.wandb_run.Run:  # noqa: F821
        """Creates a wandb experiment.

        Calls ``wandb.init`` with the keyword arguments collected in ``__init__``
        (switching ``WANDB_MODE`` to ``"dryrun"`` first when ``offline`` is set).

        Returns:
            The run object returned by ``wandb.init``.

        Raises:
            ImportError: if wandb is not installed.
        """
        if not _has_wandb:
            raise ImportError("Wandb is not installed")
        import wandb

        if self.offline:
            os.environ["WANDB_MODE"] = "dryrun"

        return wandb.init(**self._wandb_kwargs)

    def log_scalar(self, name: str, value: float, step: int | None = None) -> None:
        """Logs a scalar value to wandb.

        Args:
            name (str): The name of the scalar.
            value (:obj:`float`): The value of the scalar.
            step (int, optional): The step at which the scalar is logged; when
                given it is logged under the ``"trainer/step"`` key.
                Defaults to None.
        """
        if step is not None:
            self.experiment.log({name: value, "trainer/step": step})
        else:
            self.experiment.log({name: value})

    def log_video(self, name: str, video: Tensor, **kwargs) -> None:
        """Log videos inputs to wandb.

        Args:
            name (str): The name of the video.
            video (Tensor): The video to be logged, shaped ``(T, C, H, W)`` or
                ``(N, T, C, H, W)`` with ``C`` equal to 1 or 3.
            **kwargs: Other keyword arguments. By construction, log_video
                supports 'step' (integer indicating the step index), 'format'
                (default is 'mp4') and 'fps' (defaults to ``self.video_fps``).
                Other kwargs are passed as-is to the :obj:`experiment.log`
                method.

        Raises:
            ValueError: if the tensor does not have the expected layout.
            ImportError: if moviepy is not installed.
        """
        import wandb

        # Accept an optional leading batch dim; the colour channel is always
        # the third-from-last dimension and must be either 1 or 3.
        if video.dim() not in (4, 5) or video.size(dim=-3) not in {1, 3}:
            raise ValueError(
                "Wrong format of the video tensor. Should be ((N), T, C, H, W)"
            )
        if not self._has_imported_moviepy:
            try:
                import moviepy  # noqa: F401
            except ImportError as err:
                raise ImportError(
                    "moviepy not found, videos cannot be logged with WandbLogger"
                ) from err
            self._has_imported_moviepy = True

        self.video_log_counter += 1
        fps = kwargs.pop("fps", self.video_fps)
        step = kwargs.pop("step", None)
        video_format = kwargs.pop("format", "mp4")
        expected_step = self._prev_video_step + 1
        if step not in (None, self._prev_video_step, expected_step):
            # Warn only once per logger, as the message promises.
            if not self._video_step_warned:
                warnings.warn(
                    "when using step with wandb_logger.log_video, it is expected "
                    "that the step is equal to the previous step or that value incremented "
                    f"by one. Got step={step} but previous value was {self._prev_video_step}. "
                    f"The step value will be set to {expected_step}. This warning will "
                    f"be silenced from now on but the values will keep being incremented."
                )
                self._video_step_warned = True
            step = expected_step
        self._prev_video_step = expected_step if step is None else step
        # NOTE(review): the step is only tracked locally and is not forwarded
        # to experiment.log.
        self.experiment.log(
            {name: wandb.Video(video, fps=fps, format=video_format)},
            **kwargs,
        )

    def log_hparams(self, cfg: DictConfig | dict) -> None:  # noqa: F821
        """Logs the hyperparameters of the experiment.

        Args:
            cfg (DictConfig or dict): The configuration of the experiment.
                OmegaConf configs (e.g. from Hydra) are converted to plain
                containers with interpolations resolved; any other object is
                passed to ``config.update`` unchanged.
        """
        # Only genuine OmegaConf objects are converted: dict subclasses and
        # other mappings are handed to wandb as-is.
        if not isinstance(cfg, dict) and _has_omegaconf:
            from omegaconf import OmegaConf

            if OmegaConf.is_config(cfg):
                cfg = OmegaConf.to_container(cfg, resolve=True)
        self.experiment.config.update(cfg, allow_val_change=True)

    def __repr__(self) -> str:
        return f"WandbLogger(experiment={self.experiment.__repr__()})"

    def log_histogram(self, name: str, data: Sequence, **kwargs):
        """Add histogram to log.

        Args:
            name (str): Data identifier
            data (torch.Tensor, numpy.ndarray or sequence of numbers): Values to
                build histogram

        Keyword Args:
            step (int): Global step value to record; logged under the
                ``"trainer/step"`` key.
            bins (int or str): forwarded to ``wandb.Histogram`` as ``num_bins``,
                which wandb hands to :func:`numpy.histogram` as ``bins`` (so an
                int or a numpy estimator name such as ``'auto'`` or ``'fd'``).
                When omitted, wandb's own default number of bins is used.
                See https://numpy.org/doc/stable/reference/generated/numpy.histogram.html
        """
        import wandb

        histogram_kwargs = {}
        bins = kwargs.pop("bins", None)
        # Passing num_bins=None would reach numpy.histogram as bins=None, which
        # numpy rejects, so only forward an explicit value.
        if bins is not None:
            histogram_kwargs["num_bins"] = bins
        step = kwargs.pop("step", None)
        extra_kwargs = {} if step is None else {"trainer/step": step}
        self.experiment.log(
            {name: wandb.Histogram(data, **histogram_kwargs), **extra_kwargs}
        )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources