Decoding a video with SimpleVideoDecoder
In this example, we’ll learn how to decode a video using the SimpleVideoDecoder class.
First, a bit of boilerplate: we’ll download a video from the web, and define a plotting utility. You can ignore that part and jump right below to Creating a decoder.
from typing import Optional
import torch
import requests
# Video source: https://www.pexels.com/video/dog-eating-854132/
# License: CC0. Author: Coverr.
url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4"
response = requests.get(url)
if response.status_code != 200:
    raise RuntimeError(f"Failed to download video. {response.status_code = }.")
raw_video_bytes = response.content
def plot(frames: torch.Tensor, title: Optional[str] = None):
    try:
        from torchvision.utils import make_grid
        from torchvision.transforms.v2.functional import to_pil_image
        import matplotlib.pyplot as plt
    except ImportError:
        print("Cannot plot, please run `pip install torchvision matplotlib`")
        return

    plt.rcParams["savefig.bbox"] = 'tight'
    fig, ax = plt.subplots()
    ax.imshow(to_pil_image(make_grid(frames)))
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    if title is not None:
        ax.set_title(title)
    plt.tight_layout()
Creating a decoder
We can now create a decoder from the raw (encoded) video bytes. You can of course use a local video file and pass the path as input, rather than download a video.
from torchcodec.decoders import SimpleVideoDecoder
# You can also pass a path to a local file!
decoder = SimpleVideoDecoder(raw_video_bytes)
The video has not yet been decoded by the decoder, but we already have access to some metadata via the metadata attribute, which is a VideoStreamMetadata object.
print(decoder.metadata)
VideoStreamMetadata:
num_frames: 345
duration_seconds: 13.8
average_fps: 25.0
duration_seconds_from_header: 13.8
bit_rate: 505790.0
num_frames_from_header: 345
num_frames_from_content: 345
begin_stream_seconds: 0.0
end_stream_seconds: 13.8
codec: h264
width: 640
height: 360
average_fps_from_header: 25.0
stream_index: 0
Decoding frames by indexing the decoder
first_frame = decoder[0] # using a single int index
every_twenty_frame = decoder[0 : -1 : 20] # using slices
print(f"{first_frame.shape = }")
print(f"{first_frame.dtype = }")
print(f"{every_twenty_frame.shape = }")
print(f"{every_twenty_frame.dtype = }")
first_frame.shape = torch.Size([3, 360, 640])
first_frame.dtype = torch.uint8
every_twenty_frame.shape = torch.Size([18, 3, 360, 640])
every_twenty_frame.dtype = torch.uint8
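The 18 frames in every_twenty_frame follow from ordinary Python slice arithmetic. The sketch below assumes the decoder resolves slices the way built-in sequences do. That assumption matches the shapes printed above, but it is not a statement about the decoder's internals. The frame count is taken from the metadata output.

```python
# Assumption: slices are resolved as they are for built-in sequences.
num_frames = 345  # from decoder.metadata.num_frames above

# slice.indices() normalizes negative bounds against the sequence length.
start, stop, step = slice(0, -1, 20).indices(num_frames)
selected = list(range(start, stop, step))

print((start, stop, step))  # normalized slice bounds
print(len(selected))        # number of frames the slice selects
print(selected[-1])         # index of the last selected frame
```

The stop bound -1 is exclusive, so index 344 is left out of the candidate range. Here that makes no difference, because 344 is not a multiple of 20.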
Indexing the decoder returns the frames as torch.Tensor objects. By default, the shape of the frames is (N, C, H, W), where N is the batch size, C is the number of channels, H is the height, and W is the width of the frames. The batch dimension N is only present when we’re decoding more than one frame. The dimension order can be changed to (N, H, W, C) using the dimension_order parameter of SimpleVideoDecoder. Frames are always of torch.uint8 dtype.
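To make the two layouts concrete, here is a small pure-Python sketch that reorders a shape tuple from (N, C, H, W) to (N, H, W, C). The helper name nchw_to_nhwc_shape is hypothetical and only illustrates the axis permutation. On an actual tensor, the same reordering corresponds to `frames.permute(0, 2, 3, 1)`.

```python
def nchw_to_nhwc_shape(shape):
    """Reorder a (N, C, H, W) or (C, H, W) shape so that channels come last.

    Hypothetical helper for illustration; it only manipulates shape tuples.
    """
    if len(shape) == 4:
        n, c, h, w = shape
        return (n, h, w, c)
    if len(shape) == 3:  # single frame: no batch dimension
        c, h, w = shape
        return (h, w, c)
    raise ValueError(f"expected 3 or 4 dimensions, got {len(shape)}")


print(nchw_to_nhwc_shape((18, 3, 360, 640)))  # batch of frames
print(nchw_to_nhwc_shape((3, 360, 640)))      # single frame
```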
plot(first_frame, "First frame")
plot(every_twenty_frame, "Every 20th frame")
Iterating over frames
The decoder is a normal iterable object and can be iterated over like so:
for frame in decoder:
    assert (
        isinstance(frame, torch.Tensor)
        and frame.shape == (3, decoder.metadata.height, decoder.metadata.width)
    )
Retrieving pts and duration of frames
Indexing the decoder returns pure torch.Tensor objects. Sometimes, it can be useful to retrieve additional information about the frames, such as their pts (Presentation Time Stamp) and their duration. This can be achieved using the get_frame_at() and get_frames_at() methods, which return a Frame and a FrameBatch object, respectively.
last_frame = decoder.get_frame_at(len(decoder) - 1)
print(f"{type(last_frame) = }")
print(last_frame)
type(last_frame) = <class 'torchcodec.decoders._simple_video_decoder.Frame'>
Frame:
data (shape): torch.Size([3, 360, 640])
pts_seconds: 13.76
duration_seconds: 0.04
middle_frames = decoder.get_frames_at(start=10, stop=20, step=2)
print(f"{type(middle_frames) = }")
print(middle_frames)
type(middle_frames) = <class 'torchcodec.decoders._simple_video_decoder.FrameBatch'>
FrameBatch:
data (shape): torch.Size([5, 3, 360, 640])
pts_seconds: tensor([0.4000, 0.4800, 0.5600, 0.6400, 0.7200], dtype=torch.float64)
duration_seconds: tensor([0.0400, 0.0400, 0.0400, 0.0400, 0.0400], dtype=torch.float64)
plot(last_frame.data, "Last frame")
plot(middle_frames.data, "Middle frames")
Both Frame and FrameBatch have a data field, which contains the decoded tensor data. They also have the pts_seconds and duration_seconds fields, which are single floats for Frame, and 1-D torch.Tensor objects for FrameBatch (one value per frame in the batch).
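The values printed above are consistent with a simple relationship for this particular video. As a sketch, assuming a constant frame rate and a stream that starts at 0 seconds (both hold for this clip according to its metadata, but not for every video), the pts of frame i is i / fps and every frame lasts 1 / fps seconds. The helper expected_pts is hypothetical.

```python
AVERAGE_FPS = 25.0  # from decoder.metadata.average_fps
NUM_FRAMES = 345    # from decoder.metadata.num_frames


def expected_pts(index, fps=AVERAGE_FPS):
    """Expected pts in seconds, assuming a constant frame rate and a 0.0 start."""
    return index / fps


last_pts = expected_pts(NUM_FRAMES - 1)
middle_pts = [expected_pts(i) for i in range(10, 20, 2)]
frame_duration = 1 / AVERAGE_FPS

print(last_pts)        # compare with last_frame.pts_seconds
print(middle_pts)      # compare with middle_frames.pts_seconds
print(frame_duration)  # compare with duration_seconds
```

For variable-frame-rate content this arithmetic breaks down, which is why reading pts_seconds and duration_seconds from the returned objects is the more reliable option.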
Using time-based indexing
So far, we have retrieved frames based on their index. We can also retrieve frames based on when they are displayed with get_frame_displayed_at() and get_frames_displayed_at(), which also return a Frame and a FrameBatch, respectively.
frame_at_2_seconds = decoder.get_frame_displayed_at(seconds=2)
print(f"{type(frame_at_2_seconds) = }")
print(frame_at_2_seconds)
type(frame_at_2_seconds) = <class 'torchcodec.decoders._simple_video_decoder.Frame'>
Frame:
data (shape): torch.Size([3, 360, 640])
pts_seconds: 2.0
duration_seconds: 0.04
first_two_seconds = decoder.get_frames_displayed_at(
start_seconds=0,
stop_seconds=2,
)
print(f"{type(first_two_seconds) = }")
print(first_two_seconds)
type(first_two_seconds) = <class 'torchcodec.decoders._simple_video_decoder.FrameBatch'>
FrameBatch:
data (shape): torch.Size([50, 3, 360, 640])
pts_seconds: tensor([0.0000, 0.0400, 0.0800, 0.1200, 0.1600, 0.2000, 0.2400, 0.2800, 0.3200,
0.3600, 0.4000, 0.4400, 0.4800, 0.5200, 0.5600, 0.6000, 0.6400, 0.6800,
0.7200, 0.7600, 0.8000, 0.8400, 0.8800, 0.9200, 0.9600, 1.0000, 1.0400,
1.0800, 1.1200, 1.1600, 1.2000, 1.2400, 1.2800, 1.3200, 1.3600, 1.4000,
1.4400, 1.4800, 1.5200, 1.5600, 1.6000, 1.6400, 1.6800, 1.7200, 1.7600,
1.8000, 1.8400, 1.8800, 1.9200, 1.9600], dtype=torch.float64)
duration_seconds: tensor([0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400,
0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400,
0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400,
0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400,
0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400,
0.0400, 0.0400, 0.0400, 0.0400, 0.0400], dtype=torch.float64)
plot(frame_at_2_seconds.data, "Frame displayed at 2 seconds")
plot(first_two_seconds.data, "Frames displayed during [0, 2) seconds")
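One way to picture time-based indexing: frame i is on screen during the half-open interval [pts_i, pts_i + duration_i). The sketch below reproduces that lookup with the standard bisect module over a list of pts values built under a constant-frame-rate assumption for this clip. The function names are hypothetical, and this is not necessarily how the decoder implements the lookup.

```python
from bisect import bisect_left, bisect_right

FPS = 25.0
NUM_FRAMES = 345
# Assumption: constant frame rate, first frame at 0.0 seconds.
pts = [i / FPS for i in range(NUM_FRAMES)]
end_of_stream = NUM_FRAMES / FPS


def index_displayed_at(seconds):
    """Index of the frame that is on screen at `seconds`."""
    if not (pts[0] <= seconds < end_of_stream):
        raise ValueError(f"{seconds} is outside [{pts[0]}, {end_of_stream})")
    # The displayed frame is the last one whose pts is <= seconds.
    return bisect_right(pts, seconds) - 1


def indices_displayed_in(start_seconds, stop_seconds):
    """Indices of frames whose pts lies in [start_seconds, stop_seconds)."""
    return range(bisect_left(pts, start_seconds), bisect_left(pts, stop_seconds))


print(index_displayed_at(2))            # frame whose pts is exactly 2.0
print(index_displayed_at(2.03))         # still the same frame, mid-display
print(len(indices_displayed_in(0, 2)))  # number of frames in [0, 2)
```

Note that the stop bound is exclusive: the frame whose pts is exactly 2.0 is returned by the single-timestamp lookup but is not part of the [0, 2) range.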