Shortcuts

Source code for torchcodec.samplers._time_based

from typing import Literal, Optional

import torch

from torchcodec import FrameBatch
from torchcodec.samplers._common import (
    _FRAMEBATCH_RETURN_DOCS,
    _POLICY_FUNCTION_TYPE,
    _POLICY_FUNCTIONS,
    _reshape_4d_framebatch_into_5d,
    _validate_common_params,
)


def _validate_params_time_based(
    *,
    decoder,
    num_clips,
    seconds_between_clip_starts,
    seconds_between_frames,
):

    if (num_clips is None and seconds_between_clip_starts is None) or (
        num_clips is not None and seconds_between_clip_starts is not None
    ):
        raise ValueError("This is internal only and should never happen.")

    if seconds_between_clip_starts is not None and seconds_between_clip_starts <= 0:
        raise ValueError(
            f"seconds_between_clip_starts ({seconds_between_clip_starts}) must be > 0"
        )

    if num_clips is not None and num_clips <= 0:
        raise ValueError(f"num_clips ({num_clips}) must be > 0")

    if decoder.metadata.average_fps is None:
        raise ValueError(
            "Could not infer average fps from video metadata. "
            "Try using an index-based sampler instead."
        )
    if (
        decoder.metadata.end_stream_seconds is None
        or decoder.metadata.begin_stream_seconds is None
    ):
        raise ValueError(
            "Could not infer stream end and start from video metadata. "
            "Try using an index-based sampler instead."
        )

    average_frame_duration_seconds = 1 / decoder.metadata.average_fps
    if seconds_between_frames is None:
        seconds_between_frames = average_frame_duration_seconds
    elif seconds_between_frames <= 0:
        raise ValueError(
            f"seconds_between_clip_starts ({seconds_between_clip_starts}) must be > 0, got"
        )

    return seconds_between_frames


def _validate_sampling_range_time_based(
    *,
    num_frames_per_clip,
    seconds_between_frames,
    sampling_range_start,
    sampling_range_end,
    begin_stream_seconds,
    end_stream_seconds,
):

    if sampling_range_start is None:
        sampling_range_start = begin_stream_seconds
    else:
        if sampling_range_start < begin_stream_seconds:
            raise ValueError(
                f"sampling_range_start ({sampling_range_start}) must be at least {begin_stream_seconds}"
            )
        if sampling_range_start >= end_stream_seconds:
            raise ValueError(
                f"sampling_range_start ({sampling_range_start}) must be smaller than {end_stream_seconds}"
            )

    if sampling_range_end is None:
        # We allow a clip to start anywhere within
        # [sampling_range_start, sampling_range_end)
        # When sampling_range_end is None, we want to automatically set it to
        # the largest possible value such that the sampled frames in any clip
        # are within the bounds of the video duration (in other words, we don't
        # want to have to resort to the `policy`).
        # I.e. we want to guarantee that for all frames in any clip we have
        # pts < end_stream_seconds.
        #
        # The frames of a clip will be sampled at the following pts:
        # clip_timestamps = [
        #  clip_start + 0 * seconds_between_frames,
        #  clip_start + 1 * seconds_between_frames,
        #  clip_start + 2 * seconds_between_frames,
        #  ...
        #  clip_start + (num_frames_per_clip - 1) * seconds_between_frames,
        # ]
        # To guarantee that any such value is < end_stream_seconds, we only need
        # to guarantee that
        # clip_start < end_stream_seconds - (num_frames_per_clip - 1) * seconds_between_frames
        #
        # So that's the value of sampling_range_end we want to use.
        sampling_range_end = (
            end_stream_seconds - (num_frames_per_clip - 1) * seconds_between_frames
        )
    elif sampling_range_end <= begin_stream_seconds:
        raise ValueError(
            f"sampling_range_end ({sampling_range_end}) must be at least {begin_stream_seconds}"
        )

    if sampling_range_start >= sampling_range_end:
        raise ValueError(
            f"sampling_range_start ({sampling_range_start}) must be smaller than sampling_range_end ({sampling_range_end})"
        )

    sampling_range_end = min(sampling_range_end, end_stream_seconds)

    return sampling_range_start, sampling_range_end


def _build_all_clips_timestamps(
    *,
    clip_start_seconds: torch.Tensor,  # 1D float tensor
    num_frames_per_clip: int,
    seconds_between_frames: float,
    end_stream_seconds: float,
    policy_fun: _POLICY_FUNCTION_TYPE,
) -> list[float]:

    all_clips_timestamps: list[float] = []
    for start_seconds in clip_start_seconds:
        clip_timestamps = [
            timestamp
            for i in range(num_frames_per_clip)
            if (timestamp := start_seconds + i * seconds_between_frames)
            < end_stream_seconds
        ]

        if len(clip_timestamps) < num_frames_per_clip:
            clip_timestamps = policy_fun(clip_timestamps, num_frames_per_clip)
        all_clips_timestamps += clip_timestamps

    return all_clips_timestamps


def _generic_time_based_sampler(
    kind: Literal["random", "regular"],
    decoder,
    *,
    num_clips: Optional[int],  # mutually exclusive with seconds_between_clip_starts
    seconds_between_clip_starts: Optional[float],
    num_frames_per_clip: int,
    seconds_between_frames: Optional[float],
    # None means "begining", which may not always be 0
    sampling_range_start: Optional[float],
    sampling_range_end: Optional[float],  # interval is [start, end).
    policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
) -> FrameBatch:
    # Note: *everywhere*, sampling_range_end denotes the upper bound of where a
    # clip can start. This is an *open* upper bound, i.e. we will make sure no
    # clip starts exactly at (or above) sampling_range_end.

    _validate_common_params(
        decoder=decoder,
        num_frames_per_clip=num_frames_per_clip,
        policy=policy,
    )

    seconds_between_frames = _validate_params_time_based(
        decoder=decoder,
        num_clips=num_clips,
        seconds_between_clip_starts=seconds_between_clip_starts,
        seconds_between_frames=seconds_between_frames,
    )

    sampling_range_start, sampling_range_end = _validate_sampling_range_time_based(
        num_frames_per_clip=num_frames_per_clip,
        seconds_between_frames=seconds_between_frames,
        sampling_range_start=sampling_range_start,
        sampling_range_end=sampling_range_end,
        begin_stream_seconds=decoder.metadata.begin_stream_seconds,
        end_stream_seconds=decoder.metadata.end_stream_seconds,
    )

    if kind == "random":
        assert num_clips is not None  # appease type-checker
        sampling_range_width = sampling_range_end - sampling_range_start
        # torch.rand() returns in [0, 1)
        # which ensures all clip starts are < sampling_range_end
        clip_start_seconds = (
            torch.rand(num_clips) * sampling_range_width + sampling_range_start
        )
    else:
        assert seconds_between_clip_starts is not None  # appease type-checker
        clip_start_seconds = torch.arange(
            sampling_range_start,
            sampling_range_end,  # excluded
            seconds_between_clip_starts,
        )
        num_clips = len(clip_start_seconds)

    all_clips_timestamps = _build_all_clips_timestamps(
        clip_start_seconds=clip_start_seconds,
        num_frames_per_clip=num_frames_per_clip,
        seconds_between_frames=seconds_between_frames,
        end_stream_seconds=decoder.metadata.end_stream_seconds,
        policy_fun=_POLICY_FUNCTIONS[policy],
    )

    frames = decoder.get_frames_played_at(seconds=all_clips_timestamps)
    return _reshape_4d_framebatch_into_5d(
        frames=frames,
        num_clips=num_clips,
        num_frames_per_clip=num_frames_per_clip,
    )


[docs]def clips_at_random_timestamps( decoder, *, num_clips: int = 1, num_frames_per_clip: int = 1, seconds_between_frames: Optional[float] = None, # None means "begining", which may not always be 0 sampling_range_start: Optional[float] = None, sampling_range_end: Optional[float] = None, # interval is [start, end). policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> FrameBatch: # See docstring below return _generic_time_based_sampler( kind="random", decoder=decoder, num_clips=num_clips, seconds_between_clip_starts=None, num_frames_per_clip=num_frames_per_clip, seconds_between_frames=seconds_between_frames, sampling_range_start=sampling_range_start, sampling_range_end=sampling_range_end, policy=policy, )
[docs]def clips_at_regular_timestamps( decoder, *, seconds_between_clip_starts: float, num_frames_per_clip: int = 1, seconds_between_frames: Optional[float] = None, # None means "begining", which may not always be 0 sampling_range_start: Optional[float] = None, sampling_range_end: Optional[float] = None, # interval is [start, end). policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> FrameBatch: # See docstring below return _generic_time_based_sampler( kind="regular", decoder=decoder, num_clips=None, seconds_between_clip_starts=seconds_between_clip_starts, num_frames_per_clip=num_frames_per_clip, seconds_between_frames=seconds_between_frames, sampling_range_start=sampling_range_start, sampling_range_end=sampling_range_end, policy=policy, )
_COMMON_DOCS = """ {maybe_note} Args: decoder (VideoDecoder): The :class:`~torchcodec.decoders.VideoDecoder` instance to sample clips from. {num_clips_or_seconds_between_clip_starts} num_frames_per_clip (int, optional): The number of frames per clips. Default: 1. seconds_between_frames (float or None, optional): The time (in seconds) between each frame within a clip. More accurately, this defines the time between the *frame sampling point*, i.e. the timestamps at which we sample the frames. Because frames span intervals in time , the resulting start of frames within a clip may not be exactly spaced by ``seconds_between_frames`` - but on average, they will be. Default is None, which is set to the average frame duration (``1/average_fps``). sampling_range_start (float or None, optional): The start of the sampling range, which defines the first timestamp (in seconds) that a clip may *start* at. Default: None, which corresponds to the start of the video. (Note: some videos start at negative values, which is why the default is not 0). sampling_range_end (float or None, optional): The end of the sampling range, which defines the last timestamp (in seconds) that a clip may *start* at. This value is exclusive, i.e. a clip may only start within [``sampling_range_start``, ``sampling_range_end``). If None (default), the value is set automatically such that the clips never span beyond the end of the video, i.e. it is set to ``end_video_seconds - (num_frames_per_clip - 1) * seconds_between_frames``. When a clip spans beyond the end of the video, the ``policy`` parameter defines how to construct such clip. policy (str, optional): Defines how to construct clips that span beyond the end of the video. This is best described with an example: assuming the last valid (seekable) timestamp in a video is 10.9, and a clip was sampled to start at timestamp 10.5, with ``num_frames_per_clip=5`` and ``seconds_between_frames=0.2``, the sampling timestamps of the frames in the clip are supposed to be [10.5, 10.7, 10.9, 11.1, 11.2]. But 11.1 and 11.2 are invalid timestamps, so the ``policy`` parameter defines how to replace those frames, with valid sampling timestamps: - "repeat_last": repeats the last valid frame of the clip. We would get frames sampled at timestamps [10.5, 10.7, 10.9, 10.9, 10.9]. - "wrap": wraps around to the beginning of the clip. We would get frames sampled at timestamps [10.5, 10.7, 10.9, 10.5, 10.7]. - "error": raises an error. Default is "repeat_last". Note that when ``sampling_range_end=None`` (default), this policy parameter is unlikely to be relevant. {return_docs} """ _NUM_CLIPS_DOCS = """ num_clips (int, optional): The number of clips to return. Default: 1. """ clips_at_random_timestamps.__doc__ = f"""Sample :term:`clips` at random timestamps. {_COMMON_DOCS.format(maybe_note="", num_clips_or_seconds_between_clip_starts=_NUM_CLIPS_DOCS, return_docs=_FRAMEBATCH_RETURN_DOCS)} """ _SECONDS_BETWEEN_CLIP_STARTS = """ seconds_between_clip_starts (float): The space (in seconds) between each clip start. """ _NOTE_DOCS = """ .. note:: For consistency with existing sampling APIs (such as torchvision), this sampler takes a ``seconds_between_clip_starts`` parameter instead of ``num_clips``. If you find that supporting ``num_clips`` would be useful, please let us know by `opening a feature request <https://github.com/pytorch/torchcodec/issues?q=is:open+is:issue>`_. """ clips_at_regular_timestamps.__doc__ = f"""Sample :term:`clips` at regular (equally-spaced) timestamps. {_COMMON_DOCS.format(maybe_note=_NOTE_DOCS, num_clips_or_seconds_between_clip_starts=_SECONDS_BETWEEN_CLIP_STARTS, return_docs=_FRAMEBATCH_RETURN_DOCS)} """

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources