Shortcuts

Source code for torchvision.io.video

import gc
import math
import re
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from . import _video_opt


# Import PyAV without failing at module-import time: on success `av` is the
# module; if PyAV is missing or too old, `av` is rebound to an ImportError
# instance, which _check_av_available() raises once a PyAV-backed function
# is actually called.
try:
    import av

    # Lower the PyAV/FFmpeg log verbosity to ERROR level.
    av.logging.set_level(av.logging.ERROR)
    # write_video assigns `VideoFrame.pict_type`; a PyAV build without that
    # attribute is treated as too old to be usable.
    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
        av = ImportError(
            """\
Your version of PyAV is too old for the necessary video operations in torchvision.
If you are on Python 3.5, you will have to build from source (the conda-forge
packages are not up-to-date).  See
https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
        )
except ImportError:
    # PyAV is not installed at all.
    av = ImportError(
        """\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
    )


def _check_av_available() -> None:
    """Raise the stored import error when PyAV could not be loaded.

    The module-level ``av`` name holds either the PyAV module or the
    exception explaining why PyAV is unusable; in the latter case that
    exception is raised here.
    """
    if not isinstance(av, Exception):
        return
    raise av


def _av_available() -> bool:
    """Return True when PyAV was imported successfully (``av`` is not an exception)."""
    import_failed = isinstance(av, Exception)
    return not import_failed


# PyAV objects can form reference cycles, which reference counting alone
# never frees. _read_from_stream therefore counts its calls in _CALLED_TIMES
# and forces a gc.collect() once every _GC_COLLECTION_INTERVAL calls.
_CALLED_TIMES, _GC_COLLECTION_INTERVAL = 0, 10


def write_video(
    filename: str,
    video_array: torch.Tensor,
    fps: float,
    video_codec: str = "libx264",
    options: Optional[Dict[str, Any]] = None,
    audio_array: Optional[torch.Tensor] = None,
    audio_fps: Optional[float] = None,
    audio_codec: Optional[str] = None,
    audio_options: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Encode a ``[T, H, W, C]`` frame tensor, and optionally an audio track,
    into a video file.

    Args:
        filename (str): destination path of the video file.
        video_array (Tensor[T, H, W, C]): frames to encode. They are converted
            to uint8 and handed to PyAV as ``rgb24`` images.
        fps (Number): video frame rate. A float is passed through
            ``np.round`` first, because PyAV rejects fractional rates.
        video_codec (str): codec name, e.g. "libx264" or "h264". With
            "libx264rgb" the stream pixel format is "rgb24", otherwise
            "yuv420p".
        options (Dict): options forwarded to the PyAV video stream.
        audio_array (Tensor[C, N]): optional audio with C channels and N
            samples. More than one channel is written with a "stereo" layout,
            a single channel with "mono".
        audio_fps (Number): audio sample rate, typically 44100 or 48000.
        audio_codec (str): audio codec name, e.g. "mp3" or "aac".
        audio_options (Dict): options forwarded to the PyAV audio stream.
    """
    _check_av_available()
    frames = torch.as_tensor(video_array, dtype=torch.uint8).numpy()

    # PyAV does not support floating point numbers with decimal point
    # and will throw OverflowException in case this is not the case
    if isinstance(fps, float):
        fps = np.round(fps)

    with av.open(filename, mode="w") as container:

        def mux_all(packets):
            # Hand every encoded packet to the output container.
            for packet in packets:
                container.mux(packet)

        video_stream = container.add_stream(video_codec, rate=fps)
        video_stream.width = frames.shape[2]
        video_stream.height = frames.shape[1]
        video_stream.pix_fmt = "rgb24" if video_codec == "libx264rgb" else "yuv420p"
        video_stream.options = options or {}

        if audio_array is not None:
            # numpy dtype for each sample format; planar ("...p") and packed
            # variants share the same element type.
            dtype_for_sample_fmt = {
                'dbl': '<f8',
                'dblp': '<f8',
                'flt': '<f4',
                'fltp': '<f4',
                's16': '<i2',
                's16p': '<i2',
                's32': '<i4',
                's32p': '<i4',
                'u8': 'u1',
                'u8p': 'u1',
            }
            audio_stream = container.add_stream(audio_codec, rate=audio_fps)
            audio_stream.options = audio_options or {}

            layout = "mono" if audio_array.shape[0] <= 1 else "stereo"
            sample_fmt = container.streams.audio[0].format.name
            sample_dtype = np.dtype(dtype_for_sample_fmt[sample_fmt])
            samples = torch.as_tensor(audio_array).numpy().astype(sample_dtype)

            audio_frame = av.AudioFrame.from_ndarray(
                samples, format=sample_fmt, layout=layout
            )
            audio_frame.sample_rate = audio_fps
            mux_all(audio_stream.encode(audio_frame))
            # Flush the audio encoder.
            mux_all(audio_stream.encode())

        for image in frames:
            video_frame = av.VideoFrame.from_ndarray(image, format="rgb24")
            video_frame.pict_type = "NONE"
            mux_all(video_stream.encode(video_frame))

        # Flush the video encoder.
        mux_all(video_stream.encode())
def _read_from_stream(
    container: "av.container.Container",
    start_offset: float,
    end_offset: float,
    pts_unit: str,
    stream: "av.stream.Stream",
    stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
) -> List["av.frame.Frame"]:
    """
    Decode the frames of ``stream`` whose pts fall in ``[start_offset, end_offset]``.

    Args:
        container: opened PyAV input container.
        start_offset / end_offset: requested range. With ``pts_unit == "sec"``
            these are seconds, converted to stream pts via ``floor``/``ceil``
            (an infinite ``end_offset`` is kept as-is). Otherwise they are
            used as stream pts directly, and a deprecation warning is emitted.
        pts_unit: "sec" or "pts".
        stream: the stream to seek on.
        stream_name: keyword arguments for ``container.decode``,
            e.g. ``{"video": 0}``.

    Returns:
        The matching frames sorted by pts. When ``start_offset > 0`` and no
        decoded frame has exactly that pts, the last decoded frame before it
        is prepended. An empty list is returned if seeking fails.
    """
    global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
    _CALLED_TIMES += 1
    # Periodically break PyAV reference cycles.
    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
        gc.collect()

    if pts_unit == "sec":
        start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
        if end_offset != float("inf"):
            end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
    else:
        warnings.warn(
            "The pts_unit 'pts' gives wrong results and will be removed in a "
            + "follow-up version. Please use pts_unit 'sec'."
        )

    frames = {}
    should_buffer = True
    max_buffer_size = 5
    if stream.type == "video":
        # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
        # so need to buffer some extra frames to sort everything
        # properly
        extradata = stream.codec_context.extradata
        # overly complicated way of finding if `divx_packed` is set, following
        # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
        if extradata and b"DivX" in extradata:
            # can't use regex directly because of some weird characters sometimes...
            pos = extradata.find(b"DivX")
            d = extradata[pos:]
            o = re.search(br"DivX(\d+)Build(\d+)(\w)", d)
            if o is None:
                o = re.search(br"DivX(\d+)b(\d+)(\w)", d)
            if o is not None:
                should_buffer = o.group(3) == b"p"
    seek_offset = start_offset
    # some files don't seek to the right location, so better be safe here
    seek_offset = max(seek_offset - 1, 0)
    if should_buffer:
        # FIXME this is kind of a hack, but we will jump to the previous keyframe
        # so this will be safe
        seek_offset = max(seek_offset - max_buffer_size, 0)
    try:
        # TODO check if stream needs to always be the video stream here or not
        container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
    except av.AVError:
        # TODO add some warnings in this case
        # print("Corrupted file?", container.name)
        return []
    buffer_count = 0
    try:
        for frame in container.decode(**stream_name):
            frames[frame.pts] = frame
            if frame.pts >= end_offset:
                # Keep decoding a few frames past the end so out-of-order
                # pts (see above) can still be collected and sorted.
                if should_buffer and buffer_count < max_buffer_size:
                    buffer_count += 1
                    continue
                break
    except av.AVError:
        # TODO add a warning
        pass
    # ensure that the results are sorted wrt the pts
    result = [
        frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset
    ]
    if len(frames) > 0 and start_offset > 0 and start_offset not in frames:
        # if there is no frame that exactly matches the pts of start_offset
        # add the last frame smaller than start_offset, to guarantee that
        # we will have all the necessary data. This is most useful for audio
        preceding_frames = [i for i in frames if i < start_offset]
        if len(preceding_frames) > 0:
            first_frame_pts = max(preceding_frames)
            result.insert(0, frames[first_frame_pts])
    return result


def _align_audio_frames(
    aframes: torch.Tensor, audio_frames: List["av.frame.Frame"], ref_start: int, ref_end: float
) -> torch.Tensor:
    """
    Trim concatenated audio samples to the pts range ``[ref_start, ref_end]``.

    Samples are assumed to be evenly spaced in pts between the first and the
    last audio frame's pts. ``aframes`` is indexed along dim 1 as samples
    (read_video concatenates per-frame arrays along axis 1 —
    presumably ``(channels, samples)``).
    ``ref_start``/``ref_end`` must be in the same unit as the frames' pts.
    Partial-sample offsets are truncated toward zero, so boundary samples
    are kept rather than dropped.
    """
    total_aframes = aframes.shape[1]
    if total_aframes == 0:
        return aframes
    start, end = audio_frames[0].pts, audio_frames[-1].pts
    step_per_aframe = (end - start + 1) / total_aframes
    s_idx = 0
    e_idx = total_aframes
    if start < ref_start:
        s_idx = int((ref_start - start) / step_per_aframe)
    if end > ref_end:
        # Express the end as a positive index: a negative slice bound that
        # truncates to 0 would otherwise produce an empty result.
        e_idx = max(total_aframes + int((ref_end - end) / step_per_aframe), 0)
    return aframes[:, s_idx:e_idx]
def read_video(
    filename: str, start_pts: int = 0, end_pts: Optional[float] = None, pts_unit: str = "pts"
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    """
    Reads a video from a file, returning both the video frames as well as
    the audio frames

    Args:
        filename (str): path to the video file
        start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The start presentation time of the video
        end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The end presentation time
        pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
            either 'pts' or 'sec'. Defaults to 'pts'.

    Returns:
        vframes (Tensor[T, H, W, C]): the `T` video frames
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the
            number of points, trimmed to the requested time range
        info (Dict): metadata for the video and audio. Can contain the fields video_fps (float)
            and audio_fps (int)

    Raises:
        ValueError: if ``end_pts`` is smaller than ``start_pts``.
    """

    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video(filename, start_pts, end_pts, pts_unit)

    _check_av_available()

    if end_pts is None:
        end_pts = float("inf")

    if end_pts < start_pts:
        raise ValueError(
            "end_pts should be larger than start_pts, got "
            "start_pts={} and end_pts={}".format(start_pts, end_pts)
        )

    info = {}
    video_frames = []
    audio_frames = []
    # Time base of the audio stream; only set when the container has audio.
    audio_timebase = None

    try:
        with av.open(filename, metadata_errors="ignore") as container:
            time_base = _video_opt.default_timebase
            if container.streams.video:
                time_base = container.streams.video[0].time_base
            elif container.streams.audio:
                time_base = container.streams.audio[0].time_base
            # video_timebase is the default time_base
            start_pts_sec, end_pts_sec, pts_unit = _video_opt._convert_to_sec(
                start_pts, end_pts, pts_unit, time_base
            )
            if container.streams.video:
                video_frames = _read_from_stream(
                    container,
                    start_pts_sec,
                    end_pts_sec,
                    pts_unit,
                    container.streams.video[0],
                    {"video": 0},
                )
                video_fps = container.streams.video[0].average_rate
                # guard against potentially corrupted files
                if video_fps is not None:
                    info["video_fps"] = float(video_fps)

            if container.streams.audio:
                audio_timebase = container.streams.audio[0].time_base
                audio_frames = _read_from_stream(
                    container,
                    start_pts_sec,
                    end_pts_sec,
                    pts_unit,
                    container.streams.audio[0],
                    {"audio": 0},
                )
                info["audio_fps"] = container.streams.audio[0].rate

    except av.AVError:
        # TODO raise a warning?
        pass

    vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
    aframes_list = [frame.to_ndarray() for frame in audio_frames]

    if vframes_list:
        vframes = torch.as_tensor(np.stack(vframes_list))
    else:
        vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

    if aframes_list:
        aframes = np.concatenate(aframes_list, 1)
        aframes = torch.as_tensor(aframes)
        # Audio frame pts are in the audio stream's own time base, so the
        # requested range must be expressed in that unit before trimming
        # (same floor/ceil conversion that _read_from_stream applies).
        if pts_unit == "sec":
            ref_start = int(math.floor(start_pts_sec * (1 / audio_timebase)))
            ref_end = end_pts_sec
            if end_pts_sec != float("inf"):
                ref_end = int(math.ceil(end_pts_sec * (1 / audio_timebase)))
        else:
            ref_start, ref_end = start_pts, end_pts
        aframes = _align_audio_frames(aframes, audio_frames, ref_start, ref_end)
    else:
        aframes = torch.empty((1, 0), dtype=torch.float32)

    return vframes, aframes, info
def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool:
    """
    Decide whether video timestamps can be taken from demuxed packets.

    The check inspects the first *video* stream — the one that
    _decode_video_timestamps demuxes — rather than whatever stream happens to
    be at index 0. It returns True only when that stream's codec extradata
    contains the ``b"Lavc"`` marker. A container without a video stream, or
    a stream without extradata, returns False.
    """
    video_streams = container.streams.video
    if not video_streams:
        return False
    extradata = video_streams[0].codec_context.extradata
    if extradata is None:
        return False
    return b"Lavc" in extradata


def _decode_video_timestamps(container: "av.container.Container") -> List[int]:
    """
    Return the pts of every packet/frame of the first video stream, in stream order.

    Packets or frames without a pts are skipped. Packet demuxing is used when
    _can_read_timestamps_from_packets allows it; otherwise every frame is decoded.
    """
    if _can_read_timestamps_from_packets(container):
        # fast path
        return [x.pts for x in container.demux(video=0) if x.pts is not None]
    return [x.pts for x in container.decode(video=0) if x.pts is not None]
def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
    """
    List the video frames timestamps.

    Note that the function decodes the whole video frame-by-frame.

    Args:
        filename (str): path to the video file
        pts_unit (str, optional): unit in which timestamp values will be returned
            either 'pts' or 'sec'. Defaults to 'pts'.

    Returns:
        pts (List[int] if pts_unit = 'pts', List[Fraction] if pts_unit = 'sec'):
            presentation timestamps for each one of the frames in the video, sorted.
        video_fps (float, optional): the frame rate for the video, or None if it
            is unavailable.
    """
    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video_timestamps(filename, pts_unit)

    _check_av_available()

    video_fps = None
    video_time_base = None
    pts = []

    try:
        with av.open(filename, metadata_errors="ignore") as container:
            if container.streams.video:
                video_stream = container.streams.video[0]
                video_time_base = video_stream.time_base
                try:
                    pts = _decode_video_timestamps(container)
                except av.AVError:
                    warnings.warn(f"Failed decoding frames for file {filename}")
                # guard against potentially corrupted files, as read_video does
                if video_stream.average_rate is not None:
                    video_fps = float(video_stream.average_rate)
    except av.AVError as e:
        # Best-effort: report the failure instead of silently returning nothing.
        warnings.warn(
            f"Failed to open container for {filename}; Caught error: {e}", RuntimeWarning
        )

    pts.sort()

    # pts is only non-empty when a video stream was found, so
    # video_time_base is set whenever this conversion does any work.
    if pts_unit == "sec":
        pts = [x * video_time_base for x in pts]

    return pts, video_fps

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources