# Source code for torchvision's PyAV-based video reading/writing helpers.

import gc
import math
import re
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from . import _video_opt
from ._video_opt import VideoMetaData

    import av

    if not hasattr(, "pict_type"):
        av = ImportError(
Your version of PyAV is too old for the necessary video operations in torchvision.
If you are on Python 3.5, you will have to build from source (the conda-forge
packages are not up-to-date).  See for instructions on how to
install PyAV on your system.
except ImportError:
    av = ImportError(
PyAV is not installed, and is necessary for the video operations in torchvision.
See for instructions on how to
install PyAV on your system.

def _check_av_available() -> None:
    """Raise the module-level ``av`` object when it is an exception.

    Does nothing when ``av`` is not an ``Exception`` instance.
    """
    if not isinstance(av, Exception):
        return
    raise av

def _av_available() -> bool:
    """Return ``True`` unless the module-level ``av`` object is an exception."""
    import_failed = isinstance(av, Exception)
    return not import_failed

# PyAV has some reference cycles, so _read_from_stream periodically forces a gc.collect()

def write_video(
    filename: str,
    video_array: torch.Tensor,
    fps: float,
    video_codec: str = "libx264",
    options: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Writes a 4d tensor in [T, H, W, C] format in a video file

    Parameters
    ----------
    filename : str
        path where the video will be saved
    video_array : Tensor[T, H, W, C]
        tensor containing the individual frames, as a uint8 tensor in [T, H, W, C] format
    fps : Number
        frames per second
    video_codec : str, optional
        codec name handed to the container's ``add_stream``. Defaults to "libx264".
    options : Dict[str, Any], optional
        assigned to the stream's ``options``; an empty dict is used when omitted

    Raises
    ------
    ImportError
        if PyAV is missing or too old (raised by ``_check_av_available``)
    """
    _check_av_available()
    video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()

    # PyAV does not support floating point numbers with decimal point
    # and will throw OverflowException in case this is not the case
    if isinstance(fps, float):
        fps = np.round(fps)

    with av.open(filename, mode="w") as container:
        stream = container.add_stream(video_codec, rate=fps)
        stream.width = video_array.shape[2]
        stream.height = video_array.shape[1]
        stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
        stream.options = options or {}

        for img in video_array:
            frame = av.VideoFrame.from_ndarray(img, format="rgb24")
            frame.pict_type = "NONE"
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush stream: encode() without a frame drains what the encoder still holds
        for packet in stream.encode():
            container.mux(packet)
# Bookkeeping for the periodic garbage collection in ``_read_from_stream``:
# a call counter and the number of calls between two forced collections.
_CALLED_TIMES = 0
_GC_COLLECTION_INTERVAL = 10


def _read_from_stream(
    container: "av.container.Container",
    start_offset: float,
    end_offset: float,
    pts_unit: str,
    stream: "av.stream.Stream",
    stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
) -> List["av.frame.Frame"]:
    """
    Decode the frames of one stream whose pts lies in [start_offset, end_offset].

    Parameters
    ----------
    container : av.container.Container
        opened container; it is seeked and then decoded
    start_offset, end_offset : number
        bounds of the requested range, in the unit given by ``pts_unit``.
        ``end_offset`` may be ``float("inf")``.
    pts_unit : str
        when "sec", both offsets are converted to pts using ``stream.time_base``;
        any other value emits a deprecation-style warning and uses them as-is
    stream : av.stream.Stream
        stream used for the time base, the DivX check and as the seek target
    stream_name : dict
        keyword arguments forwarded to ``container.decode`` (e.g. ``{"video": 0}``)

    Returns
    -------
    List[av.frame.Frame]
        frames sorted by pts.  When no frame starts exactly at ``start_offset``
        the closest earlier frame is prepended.  An empty list is returned when
        seeking fails with ``av.AVError``.
    """
    global _CALLED_TIMES
    _CALLED_TIMES += 1
    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
        gc.collect()

    if pts_unit == "sec":
        start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
        if end_offset != float("inf"):
            end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
    else:
        warnings.warn(
            "The pts_unit 'pts' gives wrong results and will be removed in a "
            + "follow-up version. Please use pts_unit 'sec'."
        )

    frames = {}
    should_buffer = True
    max_buffer_size = 5
    if stream.type == "video":
        # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
        # so need to buffer some extra frames to sort everything
        # properly
        extradata = stream.codec_context.extradata
        # overly complicated way of finding if `divx_packed` is set
        if extradata and b"DivX" in extradata:
            # can't use regex directly because of some weird characters sometimes...
            pos = extradata.find(b"DivX")
            d = extradata[pos:]
            o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
            if o is None:
                o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
            if o is not None:
                # only a trailing "p" marks the stream as needing extra buffering
                should_buffer = o.group(3) == b"p"
    seek_offset = start_offset
    # some files don't seek to the right location, so better be safe here
    seek_offset = max(seek_offset - 1, 0)
    if should_buffer:
        # FIXME this is kind of a hack, but we will jump to the previous keyframe
        # so this will be safe
        seek_offset = max(seek_offset - max_buffer_size, 0)
    try:
        # TODO check if stream needs to always be the video stream here or not
        container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
    except av.AVError:
        # TODO add some warnings in this case
        return []
    buffer_count = 0
    try:
        for frame in container.decode(**stream_name):
            frames[frame.pts] = frame
            if frame.pts >= end_offset:
                # keep decoding a few frames past the end so that out-of-order
                # pts values still make it into ``frames`` before sorting
                if should_buffer and buffer_count < max_buffer_size:
                    buffer_count += 1
                    continue
                break
    except av.AVError:
        # TODO add a warning
        pass
    # ensure that the results are sorted wrt the pts
    result = [
        frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset
    ]
    if len(frames) > 0 and start_offset > 0 and start_offset not in frames:
        # if there is no frame that exactly matches the pts of start_offset
        # add the last frame smaller than start_offset, to guarantee that
        # we will have all the necessary data. This is most useful for audio
        preceding_frames = [i for i in frames if i < start_offset]
        if len(preceding_frames) > 0:
            first_frame_pts = max(preceding_frames)
            result.insert(0, frames[first_frame_pts])
    return result


def _align_audio_frames(
    aframes: torch.Tensor, audio_frames: List["av.frame.Frame"], ref_start: int, ref_end: float
) -> torch.Tensor:
    """
    Trim concatenated audio samples so they cover only [ref_start, ref_end].

    Parameters
    ----------
    aframes : Tensor
        samples indexed as ``aframes[:, sample]``
    audio_frames : List[av.frame.Frame]
        the decoded frames ``aframes`` was built from; only the pts of the first
        and last entry are read
    ref_start, ref_end : number
        requested range, in the same unit as the frames' pts

    Returns
    -------
    Tensor
        a slice of ``aframes`` along the second dimension
    """
    start, end = audio_frames[0].pts, audio_frames[-1].pts
    total_aframes = aframes.shape[1]
    step_per_aframe = (end - start + 1) / total_aframes
    s_idx = 0
    e_idx = total_aframes
    if start < ref_start:
        s_idx = int((ref_start - start) / step_per_aframe)
    if end > ref_end:
        # Number of trailing samples to drop.  int() truncates toward zero, so an
        # overshoot smaller than one step gives 0; counting the end index from
        # the front keeps that case from turning into the empty slice [:, s_idx:0].
        trailing = int((end - ref_end) / step_per_aframe)
        e_idx = max(total_aframes - trailing, 0)
    return aframes[:, s_idx:e_idx]
def read_video(
    filename: str, start_pts: int = 0, end_pts: Optional[float] = None, pts_unit: str = "pts"
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    """
    Reads a video from a file, returning both the video frames as well as
    the audio frames

    When the active torchvision video backend is not "pyav" the call is
    forwarded unchanged to ``_video_opt._read_video``.

    Parameters
    ----------
    filename : str
        path to the video file
    start_pts : int if pts_unit = 'pts', optional
        float / Fraction if pts_unit = 'sec', optional
        the start presentation time of the video
    end_pts : int if pts_unit = 'pts', optional
        float / Fraction if pts_unit = 'sec', optional
        the end presentation time; ``None`` means "until the end"
    pts_unit : str, optional
        unit in which start_pts and end_pts values will be interpreted, either 'pts' or 'sec'. Defaults to 'pts'.

    Returns
    -------
    vframes : Tensor[T, H, W, C]
        the `T` video frames
    aframes : Tensor[K, L]
        the audio frames, where `K` is the number of channels and `L` is the
        number of points
    info : Dict
        metadata for the video and audio. Can contain the fields video_fps (float)
        and audio_fps (int)

    Raises
    ------
    ValueError
        if ``end_pts`` is smaller than ``start_pts``
    ImportError
        if PyAV is missing or too old (raised by ``_check_av_available``)
    """

    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video(filename, start_pts, end_pts, pts_unit)

    _check_av_available()

    if end_pts is None:
        end_pts = float("inf")

    if end_pts < start_pts:
        raise ValueError(
            "end_pts should be larger than start_pts, got "
            "start_pts={} and end_pts={}".format(start_pts, end_pts)
        )

    info = {}
    video_frames = []
    audio_frames = []
    # time base of the audio stream, needed below to express the requested
    # range in audio pts when pts_unit is "sec"
    audio_time_base = None

    try:
        with av.open(filename, metadata_errors="ignore") as container:
            if container.streams.video:
                video_stream = container.streams.video[0]
                video_frames = _read_from_stream(
                    container,
                    start_pts,
                    end_pts,
                    pts_unit,
                    video_stream,
                    {"video": 0},
                )
                video_fps = video_stream.average_rate
                # guard against potentially corrupted files
                if video_fps is not None:
                    info["video_fps"] = float(video_fps)

            if container.streams.audio:
                audio_stream = container.streams.audio[0]
                audio_time_base = audio_stream.time_base
                audio_frames = _read_from_stream(
                    container,
                    start_pts,
                    end_pts,
                    pts_unit,
                    audio_stream,
                    {"audio": 0},
                )
                info["audio_fps"] = audio_stream.rate

    except av.AVError:
        # TODO raise a warning?
        pass

    vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
    aframes_list = [frame.to_ndarray() for frame in audio_frames]

    if vframes_list:
        vframes = torch.as_tensor(np.stack(vframes_list))
    else:
        vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

    if aframes_list:
        aframes = np.concatenate(aframes_list, 1)
        aframes = torch.as_tensor(aframes)
        # _align_audio_frames compares its bounds with the frames' pts, so a
        # range given in seconds has to be converted to audio pts first (the
        # same conversion _read_from_stream applies before decoding).
        align_start, align_end = start_pts, end_pts
        if pts_unit == "sec":
            align_start = int(math.floor(start_pts * (1 / audio_time_base)))
            if end_pts != float("inf"):
                align_end = int(math.ceil(end_pts * (1 / audio_time_base)))
        aframes = _align_audio_frames(aframes, audio_frames, align_start, align_end)
    else:
        aframes = torch.empty((1, 0), dtype=torch.float32)

    return vframes, aframes, info
def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool:
    """
    Tell whether packet timestamps can be used instead of decoding frames.

    Returns ``True`` only when the codec extradata of ``container.streams[0]``
    is present and contains the marker ``b"Lavc"``.
    """
    extradata = container.streams[0].codec_context.extradata
    return extradata is not None and b"Lavc" in extradata


def _decode_video_timestamps(container: "av.container.Container") -> List[int]:
    """
    Collect the pts of the first video stream, skipping entries without a pts.

    Timestamps are read from demuxed packets when
    ``_can_read_timestamps_from_packets`` allows it, otherwise from fully
    decoded frames.  The order is the one produced by the container.
    """
    if _can_read_timestamps_from_packets(container):
        # fast path: no frame decoding needed
        source = container.demux(video=0)
    else:
        source = container.decode(video=0)
    return [x.pts for x in source if x.pts is not None]
def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
    """
    List the video frames timestamps.

    Note that the function decodes the whole video frame-by-frame.

    When the active torchvision video backend is not "pyav" the call is
    forwarded unchanged to ``_video_opt._read_video_timestamps``.

    Parameters
    ----------
    filename : str
        path to the video file
    pts_unit : str, optional
        unit in which timestamp values will be returned either 'pts' or 'sec'. Defaults to 'pts'.

    Returns
    -------
    pts : List[int] if pts_unit = 'pts'
        List[Fraction] if pts_unit = 'sec'
        presentation timestamps for each one of the frames in the video.
    video_fps : float, optional
        the frame rate for the video; ``None`` when there is no video stream,
        the file cannot be opened, or the stream reports no average rate

    Raises
    ------
    ImportError
        if PyAV is missing or too old (raised by ``_check_av_available``)
    """
    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video_timestamps(filename, pts_unit)

    _check_av_available()

    video_fps = None
    video_time_base = None
    pts = []

    try:
        with av.open(filename, metadata_errors="ignore") as container:
            if container.streams.video:
                video_stream = container.streams.video[0]
                video_time_base = video_stream.time_base
                try:
                    pts = _decode_video_timestamps(container)
                except av.AVError:
                    warnings.warn(f"Failed decoding frames for file {filename}")
                # guard against potentially corrupted files, as read_video does
                average_rate = video_stream.average_rate
                if average_rate is not None:
                    video_fps = float(average_rate)
    except av.AVError:
        # TODO add a warning
        pass

    pts.sort()

    if pts_unit == "sec":
        # pts is only non-empty when a video stream was found, in which case
        # video_time_base has been set above
        pts = [x * video_time_base for x in pts]

    return pts, video_fps


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources