Skip to content

[BugFix] Minor fixes to wandb logger #2999

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions test/test_loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def test_log_scalar(self, steps, wandb_logger):
value=scalar_value,
name=scalar_name,
step=steps[i] if steps else None,
commit=True,
)

assert wandb_logger.experiment.summary["foo"] == values[-1].item()
Expand Down Expand Up @@ -315,9 +316,8 @@ def test_log_video(self, wandb_logger):
assert video_4fps_size > video_16fps_size, (video_4fps_size, video_16fps_size)

# check that we catch the error in case the format of the tensor is wrong
video_wrong_format = torch.zeros(64, 2, 32, 32)
video_wrong_format = video_wrong_format[None, :]
with pytest.raises(Exception):
video_wrong_format = torch.zeros(2, 32, 32)
with pytest.raises(ValueError, match="Video must be at least"):
wandb_logger.log_video(
name="foo",
video=video_wrong_format,
Expand Down
53 changes: 9 additions & 44 deletions torchrl/record/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import importlib.util

import os
import warnings
from typing import Sequence

from torch import Tensor
Expand Down Expand Up @@ -47,7 +46,6 @@ class WandbLogger(Logger):

@classmethod
def __new__(cls, *args, **kwargs):
cls._prev_video_step = -1
return super().__new__(cls)

def __init__(
Expand Down Expand Up @@ -84,25 +82,19 @@ def __init__(
"resume": "allow",
**kwargs,
}
self._has_imported_wandb = False

super().__init__(exp_name=exp_name, log_dir=save_dir)
if self.offline:
os.environ["WANDB_MODE"] = "dryrun"

self._has_imported_moviepy = False

self._has_imported_omgaconf = False

self.video_log_counter = 0

def _create_experiment(self) -> WandbLogger:
def _create_experiment(self):
"""Creates a wandb experiment.
Args:
exp_name (str): The name of the experiment.
Returns:
WandbLogger: The wandb experiment logger.
A wandb.Experiment object.
"""
if not _has_wandb:
raise ImportError("Wandb is not installed")
Expand All @@ -113,19 +105,20 @@ def _create_experiment(self) -> WandbLogger:

return wandb.init(**self._wandb_kwargs)

def log_scalar(self, name: str, value: float, step: int | None = None) -> None:
def log_scalar(
self, name: str, value: float, step: int | None = None, commit: bool = False
) -> None:
"""Logs a scalar value to wandb.
Args:
name (str): The name of the scalar.
value (float): The value of the scalar.
step (int, optional): The step at which the scalar is logged.
Defaults to None.
commit: If true, data for current step is assumed to be final (and
no further data for this step should be logged).
"""
if step is not None:
self.experiment.log({name: value, "trainer/step": step})
else:
self.experiment.log({name: value})
self.experiment.log({name: value}, step=step, commit=commit)

def log_video(self, name: str, video: Tensor, **kwargs) -> None:
"""Log videos inputs to wandb.
Expand All @@ -140,38 +133,10 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
"""
import wandb

# check for correct format of the video tensor ((N), T, C, H, W)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like this was mistakenly copied from tensorbaord logger. I dont see why its needed here.

# check that the color channel (C) is either 1 or 3
if video.dim() != 5 or video.size(dim=2) not in {1, 3}:
raise Exception(
"Wrong format of the video tensor. Should be ((N), T, C, H, W)"
)
if not self._has_imported_moviepy:
try:
import moviepy # noqa

self._has_imported_moviepy = True
except ImportError:
raise Exception(
"moviepy not found, videos cannot be logged with TensorboardLogger"
)
self.video_log_counter += 1
fps = kwargs.pop("fps", self.video_fps)
step = kwargs.pop("step", None)
format = kwargs.pop("format", "mp4")
if step not in (None, self._prev_video_step, self._prev_video_step + 1):
warnings.warn(
"when using step with wandb_logger.log_video, it is expected "
"that the step is equal to the previous step or that value incremented "
f"by one. Got step={step} but previous value was {self._prev_video_step}. "
f"The step value will be set to {self._prev_video_step+1}. This warning will "
f"be silenced from now on but the values will keep being incremented."
)
step = self._prev_video_step + 1
self._prev_video_step = step if step is not None else self._prev_video_step + 1
self.experiment.log(
{name: wandb.Video(video, fps=fps, format=format)},
# step=step,
**kwargs,
)

Expand Down
Loading