.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_video_api.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_plot_video_api.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_video_api.py:


=======================
Video API
=======================

This example illustrates some of the APIs that torchvision offers for
videos, together with examples of how to build datasets and more.

.. GENERATED FROM PYTHON SOURCE LINES 11-16

1. Introduction: building a new video object and examining the properties
--------------------------------------------------------------------------------
First, we select a video to test the object out. For the sake of argument,
we're using one from the Kinetics-400 dataset.
To create the video object, we need to define the path and the stream we want to use.

.. GENERATED FROM PYTHON SOURCE LINES 18-31

Chosen video statistics:

- WUzgd7C1pWA.mp4

  - source:

    - kinetics-400

  - video:

    - H.264
    - MPEG-4 AVC (part 10) (avc1)
    - fps: 29.97

  - audio:

    - MPEG AAC audio (mp4a)
    - sample rate: 48 kHz

.. GENERATED FROM PYTHON SOURCE LINES 31-45

.. code-block:: default

    import torch
    import torchvision
    from torchvision.datasets.utils import download_url
    torchvision.set_video_backend("video_reader")

    # Download the sample video
    download_url(
        "https://github.com/pytorch/vision/blob/main/test/assets/videos/WUzgd7C1pWA.mp4?raw=true",
        ".",
        "WUzgd7C1pWA.mp4"
    )
    video_path = "./WUzgd7C1pWA.mp4"

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Downloading https://raw.githubusercontent.com/pytorch/vision/main/test/assets/videos/WUzgd7C1pWA.mp4 to ./WUzgd7C1pWA.mp4
.. GENERATED FROM PYTHON SOURCE LINES 46-50

Streams are defined in a similar fashion to torch devices. We encode them as strings in the form
of ``stream_type:stream_id``, where ``stream_type`` is a string and ``stream_id`` is a long int.
The constructor also accepts a ``stream_type`` alone, in which case the stream is auto-discovered.
First, let's get the metadata for our particular video:

.. GENERATED FROM PYTHON SOURCE LINES 50-55

.. code-block:: default

    stream = "video"
    video = torchvision.io.VideoReader(video_path, stream)
    video.get_metadata()

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    {'video': {'duration': [10.9109], 'fps': [29.97002997002997]}, 'audio': {'duration': [10.9], 'framerate': [48000.0]}, 'subtitles': {'duration': []}, 'cc': {'duration': []}}

.. GENERATED FROM PYTHON SOURCE LINES 56-63

Here we can see that the video has two streams: a video stream and an audio stream.
Currently available stream types include ``['video', 'audio']``.
Each descriptor consists of two parts: the stream type (e.g. ``'video'``) and a unique stream id
(which is determined by the video encoding).
This way, if the video container contains multiple streams of the same type,
users can access the one they want.
If only the stream type is passed, the decoder auto-detects the first stream of that type and returns it.

.. GENERATED FROM PYTHON SOURCE LINES 65-72

Let's read all the frames from the audio stream. By default, the return value of
``next(video_reader)`` is a dict containing the following fields:

- ``data``: a ``torch.Tensor`` holding the decoded frame
- ``pts``: a float holding the presentation timestamp of this particular frame, in seconds

.. GENERATED FROM PYTHON SOURCE LINES 72-88

.. code-block:: default

    metadata = video.get_metadata()
    video.set_current_stream("audio")

    frames = []  # we are going to save the frames here.
    ptss = []  # pts is a presentation timestamp in seconds (float) of each frame
    for frame in video:
        frames.append(frame['data'])
        ptss.append(frame['pts'])

    print("PTS for first five frames ", ptss[:5])
    print("Total number of frames: ", len(frames))
    approx_nf = metadata['audio']['duration'][0] * metadata['audio']['framerate'][0]
    print("Approx total number of datapoints we can expect: ", approx_nf)
    print("Read data size: ", frames[0].size(0) * len(frames))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    PTS for first five frames  [0.0, 0.021332999999999998, 0.042667, 0.064, 0.08533299999999999]
    Total number of frames:  511
    Approx total number of datapoints we can expect:  523200.0
    Read data size:  523264

Each audio frame is a block of samples rather than a single sample: here
523264 / 511 = 1024 samples per frame, which at 48 kHz spans about 0.0213 s,
matching the spacing between the PTS values above.

.. GENERATED FROM PYTHON SOURCE LINES 89-97

But what if we only want to read a certain time segment of the video?
That can be done easily by combining our ``seek`` function with the fact that each call
to ``next`` returns the presentation timestamp of the returned frame in seconds.

Given that our implementation relies on Python iterators,
we can leverage ``itertools`` to simplify the process and make it more Pythonic.

For example, if we wanted to read ten frames starting at the two-second mark:

.. GENERATED FROM PYTHON SOURCE LINES 97-110

.. code-block:: default

    import itertools
    video.set_current_stream("video")

    frames = []  # we are going to save the frames here.

    # Seek to the two-second mark and use islice to take the next 10 frames.
    # Each item yielded by the reader is a dict with 'data' and 'pts' keys.
    for frame in itertools.islice(video.seek(2), 10):
        frames.append(frame['data'])

    print("Total number of frames: ", len(frames))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Total number of frames:  10

.. GENERATED FROM PYTHON SOURCE LINES 111-115

Or, if we wanted to read from the 2nd to the 5th second,
we seek to the two-second mark of the video and then use ``itertools.takewhile``
to keep reading frames until the first one whose timestamp is past five seconds:

.. GENERATED FROM PYTHON SOURCE LINES 115-128
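
Because the reader is a plain Python iterator, both ``itertools`` patterns can be tried
without decoding any video. The sketch below is not part of the torchvision example: it
substitutes a list of dicts with the same ``data``/``pts`` keys for the decoder output,
assumes a constant 30 fps, and uses a hypothetical ``seek`` helper that simply skips frames
before the start time (a real decoder may not land exactly on the requested timestamp).
``islice`` counts frames, while ``takewhile`` checks ``pts`` and stops at the first frame
past the end time instead of running to the end of the stream.

.. code-block:: python

    import itertools

    # Hypothetical stand-in for decoder output: a constant 30 fps over 10 seconds,
    # one dict per frame with the same keys VideoReader yields ("data" is a
    # placeholder frame index here, not a tensor).
    FPS = 30
    fake_frames = [{"data": i, "pts": i / FPS} for i in range(10 * FPS)]


    def seek(frames, start):
        """Hypothetical helper: iterate from the first frame with pts >= start."""
        return (f for f in frames if f["pts"] >= start)


    # Ten frames starting at the two-second mark: islice counts frames.
    ten = [f["data"] for f in itertools.islice(seek(fake_frames, 2), 10)]

    # Frames from second 2 up to and including second 5: takewhile checks pts
    # and stops at the first frame whose timestamp is past the end time.
    segment = [
        f["data"]
        for f in itertools.takewhile(lambda x: x["pts"] <= 5, seek(fake_frames, 2))
    ]

    print(len(ten), ten[0], ten[-1])  # 10 60 69
    print(len(segment), segment[0], segment[-1])  # 91 60 150
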
.. code-block:: default

    video.set_current_stream("video")

    frames = []  # we are going to save the frames here.
    video = video.seek(2)

    for frame in itertools.takewhile(lambda x: x['pts'] <= 5, video):
        frames.append(frame['data'])

    print("Total number of frames: ", len(frames))
    approx_nf = (5 - 2) * video.get_metadata()['video']['fps'][0]
    print("We can expect approx: ", approx_nf)
    print("Tensor size: ", frames[0].size())

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Total number of frames:  90
    We can expect approx:  89.91008991008991
    Tensor size:  torch.Size([3, 256, 340])

.. GENERATED FROM PYTHON SOURCE LINES 129-133

2. Building a sample read_video function
----------------------------------------------------------------------------------------
We can use the methods above to build a read-video function that follows
a similar API to the existing ``read_video`` function.

.. GENERATED FROM PYTHON SOURCE LINES 133-173

.. code-block:: default

    def example_read_video(video_object, start=0, end=None, read_video=True, read_audio=True):
        if end is None:
            end = float("inf")
        if end < start:
            raise ValueError(
                "end time should be larger than start time, got "
                f"start time={start} and end time={end}"
            )

        video_frames = torch.empty(0)
        video_pts = []
        if read_video:
            video_object.set_current_stream("video")
            frames = []
            for frame in itertools.takewhile(lambda x: x['pts'] <= end, video_object.seek(start)):
                frames.append(frame['data'])
                video_pts.append(frame['pts'])
            if len(frames) > 0:
                video_frames = torch.stack(frames, 0)

        audio_frames = torch.empty(0)
        audio_pts = []
        if read_audio:
            video_object.set_current_stream("audio")
            frames = []
            for frame in itertools.takewhile(lambda x: x['pts'] <= end, video_object.seek(start)):
                frames.append(frame['data'])
                audio_pts.append(frame['pts'])
            if len(frames) > 0:
                audio_frames = torch.cat(frames, 0)

        return video_frames, audio_frames, (video_pts, audio_pts), video_object.get_metadata()


    # Total number of frames should be 327 for video and 523264
    # datapoints for audio
    vf, af, info, meta = example_read_video(video)
    print(vf.size(), af.size())

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    torch.Size([327, 3, 256, 340]) torch.Size([523264, 1])

.. GENERATED FROM PYTHON SOURCE LINES 174-179

3. Building an example randomly sampled dataset (can be applied to the Kinetics-400 training set)
-------------------------------------------------------------------------------------------------------
Cool, so now we can use the same principle to make the sample dataset.
We suggest using an iterable dataset for this purpose.
Here, we are going to build an example dataset that reads randomly selected clips of
``clip_len`` consecutive frames (16 by default) from the videos.

.. GENERATED FROM PYTHON SOURCE LINES 181-182

Make a sample dataset:

.. GENERATED FROM PYTHON SOURCE LINES 182-187

.. code-block:: default

    import os
    os.makedirs("./dataset", exist_ok=True)
    os.makedirs("./dataset/1", exist_ok=True)
    os.makedirs("./dataset/2", exist_ok=True)

.. GENERATED FROM PYTHON SOURCE LINES 188-189

Download the videos:

.. GENERATED FROM PYTHON SOURCE LINES 189-215

.. code-block:: default

    from torchvision.datasets.utils import download_url
    download_url(
        "https://github.com/pytorch/vision/blob/main/test/assets/videos/WUzgd7C1pWA.mp4?raw=true",
        "./dataset/1", "WUzgd7C1pWA.mp4"
    )
    download_url(
        "https://github.com/pytorch/vision/blob/main/test/assets/videos/RATRACE_wave_f_nm_np1_fr_goo_37.avi?raw=true",
        "./dataset/1",
        "RATRACE_wave_f_nm_np1_fr_goo_37.avi"
    )
    download_url(
        "https://github.com/pytorch/vision/blob/main/test/assets/videos/SOX5yA1l24A.mp4?raw=true",
        "./dataset/2",
        "SOX5yA1l24A.mp4"
    )
    download_url(
        "https://github.com/pytorch/vision/blob/main/test/assets/videos/v_SoccerJuggling_g23_c01.avi?raw=true",
        "./dataset/2",
        "v_SoccerJuggling_g23_c01.avi"
    )
    download_url(
        "https://github.com/pytorch/vision/blob/main/test/assets/videos/v_SoccerJuggling_g24_c01.avi?raw=true",
        "./dataset/2",
        "v_SoccerJuggling_g24_c01.avi"
    )
.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Downloading https://raw.githubusercontent.com/pytorch/vision/main/test/assets/videos/WUzgd7C1pWA.mp4 to ./dataset/1/WUzgd7C1pWA.mp4
    Downloading https://raw.githubusercontent.com/pytorch/vision/main/test/assets/videos/RATRACE_wave_f_nm_np1_fr_goo_37.avi to ./dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi
    Downloading https://raw.githubusercontent.com/pytorch/vision/main/test/assets/videos/SOX5yA1l24A.mp4 to ./dataset/2/SOX5yA1l24A.mp4
    Downloading https://raw.githubusercontent.com/pytorch/vision/main/test/assets/videos/v_SoccerJuggling_g23_c01.avi to ./dataset/2/v_SoccerJuggling_g23_c01.avi
    Downloading https://raw.githubusercontent.com/pytorch/vision/main/test/assets/videos/v_SoccerJuggling_g24_c01.avi to ./dataset/2/v_SoccerJuggling_g24_c01.avi

.. GENERATED FROM PYTHON SOURCE LINES 216-217

Housekeeping and utilities:

.. GENERATED FROM PYTHON SOURCE LINES 217-235

.. code-block:: default

    import os
    import random

    from torchvision.datasets.folder import make_dataset
    from torchvision import transforms as t


    def _find_classes(dir):
        classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        classes.sort()
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx


    def get_samples(root, extensions=(".mp4", ".avi")):
        _, class_to_idx = _find_classes(root)
        return make_dataset(root, class_to_idx, extensions=extensions)
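
``get_samples`` flattens the class folders into a list of ``(path, class_index)`` pairs,
with indices taken from ``class_to_idx``, which ``_find_classes`` builds from the sorted
folder names; files without a listed extension are skipped. To show that output shape
without torchvision, the hypothetical ``list_video_samples`` helper below walks a temporary
tree laid out like ``./dataset`` (with empty placeholder files). It is a simplified
approximation for illustration, not torchvision's ``make_dataset`` implementation.

.. code-block:: python

    import os
    import tempfile


    def list_video_samples(root, extensions=(".mp4", ".avi")):
        """Hypothetical stdlib stand-in for get_samples(): (path, class_index) pairs."""
        # Class names are the sorted sub-directory names; indices follow that order.
        classes = sorted(d.name for d in os.scandir(root) if d.is_dir())
        class_to_idx = {name: i for i, name in enumerate(classes)}
        samples = []
        for name in classes:
            class_dir = os.path.join(root, name)
            for dirpath, _, fnames in sorted(os.walk(class_dir)):
                for fname in sorted(fnames):
                    # Keep only files with a video extension (case-insensitive).
                    if fname.lower().endswith(extensions):
                        samples.append((os.path.join(dirpath, fname), class_to_idx[name]))
        return samples


    with tempfile.TemporaryDirectory() as root:
        # Recreate a two-class layout with empty placeholder files.
        layout = {
            "1": ["WUzgd7C1pWA.mp4", "RATRACE_wave_f_nm_np1_fr_goo_37.avi"],
            "2": ["SOX5yA1l24A.mp4", "notes.txt"],  # notes.txt should be skipped
        }
        for cls, files in layout.items():
            os.makedirs(os.path.join(root, cls))
            for fname in files:
                open(os.path.join(root, cls, fname), "w").close()

        samples = list_video_samples(root)
        for path, target in samples:
            print(os.path.relpath(path, root), target)
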
.. GENERATED FROM PYTHON SOURCE LINES 236-246

We are going to define the dataset and some basic arguments.
We assume the class-per-folder layout used by ``torchvision.datasets.DatasetFolder``,
and add the following parameters:

- ``clip_len``: length of a clip in frames
- ``frame_transform``: transform applied to every frame individually
- ``video_transform``: transform applied to the whole video sequence

.. note::
    We add an ``epoch_size`` parameter because the :class:`~torch.utils.data.IterableDataset`
    class allows us to naturally oversample clips or images from each video if needed.

.. GENERATED FROM PYTHON SOURCE LINES 246-290

.. code-block:: default

    class RandomDataset(torch.utils.data.IterableDataset):
        def __init__(self, root, epoch_size=None, frame_transform=None, video_transform=None, clip_len=16):
            super().__init__()

            self.samples = get_samples(root)

            # Allow for temporal jittering
            if epoch_size is None:
                epoch_size = len(self.samples)
            self.epoch_size = epoch_size

            self.clip_len = clip_len
            self.frame_transform = frame_transform
            self.video_transform = video_transform

        def __iter__(self):
            for i in range(self.epoch_size):
                # Get random sample
                path, target = random.choice(self.samples)
                # Get video object
                vid = torchvision.io.VideoReader(path, "video")
                metadata = vid.get_metadata()
                video_frames = []  # video frame buffer

                # Seek to a random start that leaves room for clip_len frames
                max_seek = metadata["video"]['duration'][0] - (self.clip_len / metadata["video"]['fps'][0])
                start = random.uniform(0., max_seek)
                for frame in itertools.islice(vid.seek(start), self.clip_len):
                    data = frame['data']
                    if self.frame_transform:
                        data = self.frame_transform(data)
                    video_frames.append(data)
                    current_pts = frame['pts']
                # Stack it into a tensor
                video = torch.stack(video_frames, 0)
                if self.video_transform:
                    video = self.video_transform(video)
                output = {
                    'path': path,
                    'video': video,
                    'target': target,
                    'start': start,
                    'end': current_pts}
                yield output

.. GENERATED FROM PYTHON SOURCE LINES 291-305

The dataset expects the videos to be organized in a folder structure like this:

- dataset

  - class 1

    - file 0
    - file 1
    - ...

  - class 2

    - file 0
    - file 1
    - ...

  - ...

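
In ``RandomDataset.__iter__``, the start time is drawn uniformly from ``[0, max_seek]`` with
``max_seek = duration - clip_len / fps``, so that ``clip_len`` frames remain after the seek
point. For the sample video (duration 10.9109 s and fps 29.97 in the metadata printed earlier)
and ``clip_len=16``, a clip spans about 0.53 s, leaving ``max_seek`` at about 10.38 s. The
hypothetical ``sample_clip_start`` helper below isolates just that computation, and adds a
clamp for videos shorter than one clip that the class does not have.

.. code-block:: python

    import random


    def sample_clip_start(duration, fps, clip_len, rng=random):
        """Hypothetical helper mirroring how RandomDataset picks a clip start.

        Returns (start, max_seek): start is drawn uniformly from [0, max_seek],
        where max_seek = duration - clip_len / fps leaves room for clip_len frames.
        """
        clip_duration = clip_len / fps  # seconds spanned by clip_len frames
        # Clamp at 0 for videos shorter than one clip (not done in RandomDataset).
        max_seek = max(0.0, duration - clip_duration)
        return rng.uniform(0.0, max_seek), max_seek


    # Metadata of the sample video; 30000 / 1001 is about 29.97003 fps.
    rng = random.Random(0)  # seeded so the sketch is reproducible
    start, max_seek = sample_clip_start(10.9109, 30000 / 1001, 16, rng)
    print(f"max_seek = {max_seek:.4f} s, start = {start:.4f} s")
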
We can generate a dataloader and test the dataset.

.. GENERATED FROM PYTHON SOURCE LINES 305-312

.. code-block:: default

    transforms = [t.Resize((112, 112))]
    frame_transform = t.Compose(transforms)

    dataset = RandomDataset("./dataset", epoch_size=None, frame_transform=frame_transform)

.. GENERATED FROM PYTHON SOURCE LINES 313-324

.. code-block:: default

    from torch.utils.data import DataLoader
    loader = DataLoader(dataset, batch_size=12)
    data = {"video": [], 'start': [], 'end': [], 'tensorsize': []}
    for batch in loader:
        for i in range(len(batch['path'])):
            data['video'].append(batch['path'][i])
            data['start'].append(batch['start'][i].item())
            data['end'].append(batch['end'][i].item())
            data['tensorsize'].append(batch['video'][i].size())
    print(data)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /home/circleci/project/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).
      warnings.warn(
    {'video': ['./dataset/2/v_SoccerJuggling_g23_c01.avi', './dataset/1/WUzgd7C1pWA.mp4', './dataset/2/SOX5yA1l24A.mp4', './dataset/2/v_SoccerJuggling_g23_c01.avi', './dataset/1/WUzgd7C1pWA.mp4'], 'start': [7.277870922485677, 3.5101824393661527, 3.533753864502439, 5.864848793058175, 0.806605869755374], 'end': [7.807799999999999, 4.037367, 4.037367, 6.3730329999999995, 1.334667], 'tensorsize': [torch.Size([16, 3, 112, 112]), torch.Size([16, 3, 112, 112]), torch.Size([16, 3, 112, 112]), torch.Size([16, 3, 112, 112]), torch.Size([16, 3, 112, 112])]}

.. GENERATED FROM PYTHON SOURCE LINES 325-328

4. Data Visualization
----------------------------------
Example of a visualized video clip:

.. GENERATED FROM PYTHON SOURCE LINES 328-337

.. code-block:: default

    import matplotlib.pyplot as plt

    plt.figure(figsize=(12, 12))
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow(batch["video"][0, i, ...].permute(1, 2, 0))
        plt.axis("off")

.. image-sg:: /auto_examples/images/sphx_glr_plot_video_api_001.png
   :alt: plot video api
   :srcset: /auto_examples/images/sphx_glr_plot_video_api_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 338-339

Clean up the video and dataset:

.. GENERATED FROM PYTHON SOURCE LINES 339-343

.. code-block:: default

    import os
    import shutil
    os.remove("./WUzgd7C1pWA.mp4")
    shutil.rmtree("./dataset")

.. rst-class:: sphx-glr-timing

**Total running time of the script:** ( 0 minutes 3.843 seconds)

.. _sphx_glr_download_auto_examples_plot_video_api.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_video_api.py <plot_video_api.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_video_api.ipynb <plot_video_api.ipynb>`

.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_