
Decoding a video with SimpleVideoDecoder

In this example, we’ll learn how to decode a video using the SimpleVideoDecoder class.

First, a bit of boilerplate: we’ll download a video from the web and define a plotting utility. You can ignore that part and jump straight to Creating a decoder.

from typing import Optional
import torch
import requests


# Video source: https://www.pexels.com/video/dog-eating-854132/
# License: CC0. Author: Coverr.
url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4"
response = requests.get(url)
if response.status_code != 200:
    raise RuntimeError(f"Failed to download video. {response.status_code = }.")

raw_video_bytes = response.content


def plot(frames: torch.Tensor, title: Optional[str] = None):
    try:
        from torchvision.utils import make_grid
        from torchvision.transforms.v2.functional import to_pil_image
        import matplotlib.pyplot as plt
    except ImportError:
        print("Cannot plot, please run `pip install torchvision matplotlib`")
        return

    plt.rcParams["savefig.bbox"] = 'tight'
    fig, ax = plt.subplots()
    ax.imshow(to_pil_image(make_grid(frames)))
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    if title is not None:
        ax.set_title(title)
    plt.tight_layout()
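The plot helper relies on torchvision's make_grid to tile a batch of frames into a single image. The sketch below computes the shape of that tiled image in plain Python. It assumes make_grid's defaults of nrow=8 and padding=2, and that a single image is returned unpadded (our reading of the torchvision documentation). grid_shape is a hypothetical helper and is not part of torchvision.

```python
import math
from typing import Tuple


def grid_shape(num_frames: int, channels: int, height: int, width: int,
               nrow: int = 8, padding: int = 2) -> Tuple[int, int, int]:
    """Return the (C, H, W) shape of a make_grid-style tiling (assumed defaults)."""
    if num_frames == 1:
        # A single image is returned as-is, without any padding.
        return (channels, height, width)
    cols = min(nrow, num_frames)          # at most `nrow` frames per row
    rows = math.ceil(num_frames / cols)   # enough rows to hold every frame
    cell_h, cell_w = height + padding, width + padding
    # One extra `padding` closes the border on the bottom and right edges.
    return (channels, cell_h * rows + padding, cell_w * cols + padding)


print(grid_shape(18, 3, 360, 640))  # (3, 1088, 5138)
print(grid_shape(1, 3, 360, 640))   # (3, 360, 640)
```

With 18 frames, the tiling uses 3 rows of 8 cells, so the figure is much wider than it is tall.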

Creating a decoder

We can now create a decoder from the raw (encoded) video bytes. You can of course also pass the path to a local video file instead of downloading one.

from torchcodec.decoders import SimpleVideoDecoder

# You can also pass a path to a local file!
decoder = SimpleVideoDecoder(raw_video_bytes)

The video has not been decoded yet, but we already have access to some metadata via the metadata attribute, which is a VideoStreamMetadata object.

print(decoder.metadata)
VideoStreamMetadata:
  num_frames: 345
  duration_seconds: 13.8
  average_fps: 25.0
  duration_seconds_from_header: 13.8
  bit_rate: 505790.0
  num_frames_from_header: 345
  num_frames_from_content: 345
  begin_stream_seconds: 0.0
  end_stream_seconds: 13.8
  codec: h264
  width: 640
  height: 360
  average_fps_from_header: 25.0
  stream_index: 0
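The metadata fields above are mutually consistent. For a constant-frame-rate stream, the frame count is roughly the duration times the frame rate, and each frame lasts 1 / fps seconds. The sketch below checks this with the printed values. It assumes a constant frame rate, which holds for this clip but not for every video; expected_num_frames is a hypothetical helper, not a torchcodec function.

```python
# Values copied from the VideoStreamMetadata printout above.
num_frames = 345
duration_seconds = 13.8
average_fps = 25.0
begin_stream_seconds = 0.0
end_stream_seconds = 13.8


def expected_num_frames(duration: float, fps: float) -> int:
    """Frame count implied by a duration and a frame rate (constant frame rate assumed)."""
    return round(duration * fps)


# The stream's time span matches its reported duration.
assert abs((end_stream_seconds - begin_stream_seconds) - duration_seconds) < 1e-9

print(expected_num_frames(duration_seconds, average_fps))  # 345
print(1 / average_fps)  # 0.04, the duration of each frame in seconds
```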

Decoding frames by indexing the decoder

first_frame = decoder[0]  # using a single int index
every_twenty_frame = decoder[0 : -1 : 20]  # using slices

print(f"{first_frame.shape = }")
print(f"{first_frame.dtype = }")
print(f"{every_twenty_frame.shape = }")
print(f"{every_twenty_frame.dtype = }")
first_frame.shape = torch.Size([3, 360, 640])
first_frame.dtype = torch.uint8
every_twenty_frame.shape = torch.Size([18, 3, 360, 640])
every_twenty_frame.dtype = torch.uint8
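Assuming the decoder follows standard Python slice semantics (which matches the 18 frames decoded above), we can see exactly which frame indices decoder[0 : -1 : 20] selects by applying the same slice to a range. The -1 stop is normalized to 344, so the last frame is excluded from the slice.

```python
num_frames = 345  # decoder.metadata.num_frames, as printed above

# Apply the same slice to a range of frame indices.
selected = range(num_frames)[0:-1:20]
print(list(selected))  # [0, 20, 40, ..., 340]
print(len(selected))   # 18

# slice.indices() shows the normalized (start, stop, step) triple.
print(slice(0, -1, 20).indices(num_frames))  # (0, 344, 20)
```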

Indexing the decoder returns the frames as torch.Tensor objects. By default, the shape of the frames is (N, C, H, W), where N is the batch size, C is the number of channels, H is the height, and W is the width of the frames. The batch dimension N is only present when we’re decoding more than one frame. The dimension order can be changed to (N, H, W, C) using the dimension_order parameter of SimpleVideoDecoder. Frames are always of torch.uint8 dtype.
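Switching from (N, C, H, W) to (N, H, W, C) is a reordering of dimensions. On an already decoded tensor, the same change can be made with torch.Tensor.permute(0, 2, 3, 1). The sketch below shows the shape bookkeeping in plain Python so it runs without PyTorch; permuted_shape and NCHW_TO_NHWC are hypothetical names used only for illustration.

```python
from typing import Sequence, Tuple

# Permutation that turns an (N, C, H, W) layout into (N, H, W, C).
NCHW_TO_NHWC = (0, 2, 3, 1)


def permuted_shape(shape: Sequence[int], order: Sequence[int]) -> Tuple[int, ...]:
    """Return the shape obtained by reordering dimensions, as Tensor.permute would."""
    if sorted(order) != list(range(len(shape))):
        raise ValueError(f"{order} is not a permutation of {len(shape)} dimensions")
    return tuple(shape[i] for i in order)


print(permuted_shape((18, 3, 360, 640), NCHW_TO_NHWC))  # (18, 360, 640, 3)
# A single frame has no batch dimension: (C, H, W) -> (H, W, C).
print(permuted_shape((3, 360, 640), (1, 2, 0)))  # (360, 640, 3)
```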

plot(first_frame, "First frame")
First frame
plot(every_twenty_frame, "Every 20th frame")
Every 20th frame

Iterating over frames

The decoder is a normal iterable object and can be iterated over like so:

for frame in decoder:
    assert (
        isinstance(frame, torch.Tensor)
        and frame.shape == (3, decoder.metadata.height, decoder.metadata.width)
    )
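Indexing and iteration like this come from Python's standard protocols: __len__, __getitem__ (for ints and slices), and __iter__. The toy class below is a hypothetical stand-in, not torchcodec's implementation; each "frame" is just its index, where a real decoder would return a decoded tensor.

```python
from typing import Iterator, List, Union


class ToyDecoder:
    """Hypothetical decoder stand-in: frame i is represented by the int i."""

    def __init__(self, num_frames: int):
        self._num_frames = num_frames

    def __len__(self) -> int:
        return self._num_frames

    def __getitem__(self, key: Union[int, slice]) -> Union[int, List[int]]:
        if isinstance(key, slice):
            # Normalize negative/None bounds, then fetch each selected index.
            return [self[i] for i in range(*key.indices(len(self)))]
        if key < 0:
            key += len(self)  # support negative indices like a list
        if not 0 <= key < len(self):
            raise IndexError(f"frame index {key} out of range")
        return key  # a real decoder would decode and return frame `key` here

    def __iter__(self) -> Iterator[int]:
        # Yield frames one at a time, in presentation order.
        for i in range(len(self)):
            yield self[i]


toy = ToyDecoder(345)
print(toy[0], toy[-1], len(toy[0:-1:20]))  # 0 344 18
print(sum(1 for _ in toy))                 # 345
```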

Retrieving pts and duration of frames

Indexing the decoder returns pure torch.Tensor objects. Sometimes it can be useful to retrieve additional information about the frames, such as their pts (Presentation Time Stamp) and their duration. This can be achieved using the get_frame_at() and get_frames_at() methods, which return Frame and FrameBatch objects, respectively.

last_frame = decoder.get_frame_at(len(decoder) - 1)
print(f"{type(last_frame) = }")
print(last_frame)
type(last_frame) = <class 'torchcodec.decoders._simple_video_decoder.Frame'>
Frame:
  data (shape): torch.Size([3, 360, 640])
  pts_seconds: 13.76
  duration_seconds: 0.04
middle_frames = decoder.get_frames_at(start=10, stop=20, step=2)
print(f"{type(middle_frames) = }")
print(middle_frames)
type(middle_frames) = <class 'torchcodec.decoders._simple_video_decoder.FrameBatch'>
FrameBatch:
  data (shape): torch.Size([5, 3, 360, 640])
  pts_seconds: tensor([0.4000, 0.4800, 0.5600, 0.6400, 0.7200], dtype=torch.float64)
  duration_seconds: tensor([0.0400, 0.0400, 0.0400, 0.0400, 0.0400], dtype=torch.float64)
plot(last_frame.data, "Last frame")
plot(middle_frames.data, "Middle frames")
  • Last frame
  • Middle frames

Both Frame and FrameBatch have a data field, which contains the decoded tensor data. They also have the pts_seconds and duration_seconds fields, which are single floats for Frame and 1-D torch.Tensor objects for FrameBatch (one value per frame in the batch).
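To make the difference between the two containers concrete, here is a sketch with hypothetical dataclasses (ToyFrame, ToyFrameBatch), not torchcodec's own Frame and FrameBatch. It assumes a constant frame rate and a stream starting at 0 seconds, so that pts = index / fps; that holds for this clip (average_fps 25.0, begin_stream_seconds 0.0), but in general the real pts come from the stream itself. Strings stand in for decoded tensors.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyFrame:
    """Hypothetical analogue of Frame: one frame plus scalar timing."""
    data: object
    pts_seconds: float
    duration_seconds: float


@dataclass
class ToyFrameBatch:
    """Hypothetical analogue of FrameBatch: one timing value per frame."""
    data: List[object]
    pts_seconds: List[float]
    duration_seconds: List[float]


def toy_frames_at(start: int, stop: int, step: int, fps: float) -> ToyFrameBatch:
    """Build a batch for indices range(start, stop, step), assuming a constant frame rate."""
    indices = range(start, stop, step)
    return ToyFrameBatch(
        data=[f"frame {i}" for i in indices],  # placeholders for decoded tensors
        pts_seconds=[i / fps for i in indices],
        duration_seconds=[1 / fps for _ in indices],
    )


last = ToyFrame(data="frame 344", pts_seconds=344 / 25.0, duration_seconds=1 / 25.0)
batch = toy_frames_at(10, 20, 2, fps=25.0)
print([round(p, 2) for p in batch.pts_seconds])  # [0.4, 0.48, 0.56, 0.64, 0.72]
```

The computed pts match the values printed for middle_frames above, and 344 / 25 gives the 13.76 seconds reported for the last frame.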

Using time-based indexing

So far, we have retrieved frames based on their index. We can also retrieve frames based on when they are displayed with get_frame_displayed_at() and get_frames_displayed_at(), which also return Frame and FrameBatch objects, respectively.

frame_at_2_seconds = decoder.get_frame_displayed_at(seconds=2)
print(f"{type(frame_at_2_seconds) = }")
print(frame_at_2_seconds)
type(frame_at_2_seconds) = <class 'torchcodec.decoders._simple_video_decoder.Frame'>
Frame:
  data (shape): torch.Size([3, 360, 640])
  pts_seconds: 2.0
  duration_seconds: 0.04
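The frame displayed at a given time t is the one whose display interval [pts, pts + duration) contains t, that is, the last frame whose pts is less than or equal to t. The sketch below expresses that rule with a binary search over sorted pts values. It is a plain-Python illustration under a constant-frame-rate assumption, not how torchcodec implements the lookup; frame_index_displayed_at is a hypothetical helper.

```python
import bisect
from typing import List


def frame_index_displayed_at(pts: List[float], end_seconds: float, seconds: float) -> int:
    """Index of the frame on screen at `seconds`: the last frame whose pts <= seconds.

    `pts` must be sorted in presentation order; `end_seconds` is when the last frame ends.
    """
    if not pts or seconds < pts[0] or seconds >= end_seconds:
        raise IndexError(f"no frame is displayed at {seconds} s")
    return bisect.bisect_right(pts, seconds) - 1


fps, num_frames = 25.0, 345
pts = [i / fps for i in range(num_frames)]  # constant frame rate assumed
end_seconds = num_frames / fps              # 13.8

print(frame_index_displayed_at(pts, end_seconds, 2.0))    # 50  (pts 2.0)
print(frame_index_displayed_at(pts, end_seconds, 2.03))   # 50  (still on screen)
print(frame_index_displayed_at(pts, end_seconds, 13.79))  # 344 (last frame)
```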
first_two_seconds = decoder.get_frames_displayed_at(
    start_seconds=0,
    stop_seconds=2,
)
print(f"{type(first_two_seconds) = }")
print(first_two_seconds)
type(first_two_seconds) = <class 'torchcodec.decoders._simple_video_decoder.FrameBatch'>
FrameBatch:
  data (shape): torch.Size([50, 3, 360, 640])
  pts_seconds: tensor([0.0000, 0.0400, 0.0800, 0.1200, 0.1600, 0.2000, 0.2400, 0.2800, 0.3200,
        0.3600, 0.4000, 0.4400, 0.4800, 0.5200, 0.5600, 0.6000, 0.6400, 0.6800,
        0.7200, 0.7600, 0.8000, 0.8400, 0.8800, 0.9200, 0.9600, 1.0000, 1.0400,
        1.0800, 1.1200, 1.1600, 1.2000, 1.2400, 1.2800, 1.3200, 1.3600, 1.4000,
        1.4400, 1.4800, 1.5200, 1.5600, 1.6000, 1.6400, 1.6800, 1.7200, 1.7600,
        1.8000, 1.8400, 1.8800, 1.9200, 1.9600], dtype=torch.float64)
  duration_seconds: tensor([0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400,
        0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400,
        0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400,
        0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400,
        0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400,
        0.0400, 0.0400, 0.0400, 0.0400, 0.0400], dtype=torch.float64)
plot(frame_at_2_seconds.data, "Frame displayed at 2 seconds")
plot(first_two_seconds.data, "Frames displayed during [0, 2) seconds")
  • Frame displayed at 2 seconds
  • Frames displayed during [0, 2) seconds
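The range query above covers the half-open interval [0, 2) seconds, and the output contains the 50 frames with pts from 0.00 to 1.96. One selection rule consistent with that output is "keep the frames whose pts falls in [start_seconds, stop_seconds)". The sketch below applies it with two binary searches over sorted pts values. The rule is our reading of the printed result, the torchcodec reference is authoritative, and frame_indices_displayed_in is a hypothetical helper.

```python
import bisect
from typing import List


def frame_indices_displayed_in(pts: List[float], start_seconds: float, stop_seconds: float) -> range:
    """Indices of frames whose pts lies in [start_seconds, stop_seconds); `pts` must be sorted."""
    lo = bisect.bisect_left(pts, start_seconds)  # first frame with pts >= start
    hi = bisect.bisect_left(pts, stop_seconds)   # first frame with pts >= stop (excluded)
    return range(lo, hi)


fps, num_frames = 25.0, 345
pts = [i / fps for i in range(num_frames)]  # constant frame rate assumed

indices = frame_indices_displayed_in(pts, 0.0, 2.0)
print(len(indices), indices[0], indices[-1])  # 50 0 49
print(round(pts[indices[-1]], 2))             # 1.96
```

Binary search keeps each lookup logarithmic in the number of frames; a linear scan with the same comparison would return the same indices.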

Total running time of the script: (0 minutes 2.720 seconds)
