.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/others/plot_video_api.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_others_plot_video_api.py:


=========
Video API
=========

.. note::
    Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_video_api.ipynb>`_
    or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_video_api.py>` to download the full example code.

This example illustrates some of the APIs that torchvision offers for
videos, together with examples of how to build datasets and more.

.. GENERATED FROM PYTHON SOURCE LINES 15-20

1. Introduction: building a new video object and examining the properties
-------------------------------------------------------------------------
First, we select a video to test the object. For the sake of argument,
we use one from the Kinetics-400 dataset.
To create the video object, we need to define the path and the stream we want to use.

.. GENERATED FROM PYTHON SOURCE LINES 22-35

Chosen video statistics:

- WUzgd7C1pWA.mp4
    - source:
        - kinetics-400
    - video:
        - H.264
        - MPEG-4 AVC (part 10) (avc1)
        - fps: 29.97
    - audio:
        - MPEG AAC audio (mp4a)
        - sample rate: 48 kHz


.. GENERATED FROM PYTHON SOURCE LINES 35-49

.. code-block:: Python


    import torch
    import torchvision
    from torchvision.datasets.utils import download_url
    torchvision.set_video_backend("video_reader")

    # Download the sample video
    download_url(
        "https://github.com/pytorch/vision/blob/main/test/assets/videos/WUzgd7C1pWA.mp4?raw=true",
        ".",
        "WUzgd7C1pWA.mp4"
    )
    video_path = "./WUzgd7C1pWA.mp4"





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/test/assets/videos/WUzgd7C1pWA.mp4 to ./WUzgd7C1pWA.mp4

    3.7%
    7.4%
    11.1%
    14.7%
    18.4%
    22.1%
    25.8%
    29.5%
    33.2%
    36.8%
    40.5%
    44.2%
    47.9%
    51.6%
    55.3%
    58.9%
    62.6%
    66.3%
    70.0%
    73.7%
    77.4%
    81.0%
    84.7%
    88.4%
    92.1%
    95.8%
    99.5%
    100.0%




.. GENERATED FROM PYTHON SOURCE LINES 50-54

Streams are defined in a similar fashion to torch devices. We encode them as strings of the form
``stream_type:stream_id``, where ``stream_type`` is a string and ``stream_id`` is a long int.
The constructor also accepts a ``stream_type`` alone, in which case the stream is auto-discovered.
First, let's get the metadata for our particular video:

.. GENERATED FROM PYTHON SOURCE LINES 54-59

.. code-block:: Python


    stream = "video"
    video = torchvision.io.VideoReader(video_path, stream)
    video.get_metadata()





.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    {'video': {'duration': [10.9109], 'fps': [29.97002997002997]}, 'audio': {'duration': [10.9], 'framerate': [48000.0]}, 'subtitles': {'duration': []}, 'cc': {'duration': []}}



.. GENERATED FROM PYTHON SOURCE LINES 60-67

Here we can see that the video has two streams: a video stream and an audio stream
(the ``subtitles`` and ``cc`` entries are empty).
Currently available stream types include ['video', 'audio'].
Each descriptor consists of two parts: the stream type (e.g. 'video') and a unique stream id
(which is determined by the video encoding).
This way, if the video container holds multiple streams of the same type,
users can access the one they want.
If only the stream type is passed, the decoder auto-detects the first stream of that type and returns it.
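
To make the descriptor format concrete, here is a small stdlib-only sketch of how a
``stream_type:stream_id`` string splits into its two parts. ``parse_stream_descriptor`` is a
hypothetical helper, not torchvision's internal parser; returning ``None`` for a missing id simply
stands in for "let the decoder pick the first stream of that type".

.. code-block:: Python

    def parse_stream_descriptor(descriptor):
        """Split a "stream_type[:stream_id]" string (hypothetical helper).

        Returns (stream_type, stream_id); stream_id is None when only the
        type is given, meaning the first stream of that type is used.
        """
        stream_type, sep, stream_id = descriptor.partition(":")
        if not stream_type:
            raise ValueError(f"missing stream type in {descriptor!r}")
        if not sep:
            return stream_type, None
        return stream_type, int(stream_id)


    print(parse_stream_descriptor("video"))    # ('video', None)
    print(parse_stream_descriptor("audio:1"))  # ('audio', 1)

With the real reader you pass the string directly, for example
``torchvision.io.VideoReader(video_path, "video:0")``.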

.. GENERATED FROM PYTHON SOURCE LINES 69-76

Let's switch to the audio stream and read all of its frames. By default,
``next(video_reader)`` returns a dict with the following fields:

- ``data``: a ``torch.Tensor`` holding the frame's content
- ``pts``: a float timestamp (in seconds) of this particular frame

.. GENERATED FROM PYTHON SOURCE LINES 76-92

.. code-block:: Python


    metadata = video.get_metadata()
    video.set_current_stream("audio")

    frames = []  # we are going to save the frames here.
    ptss = []  # pts is a presentation timestamp in seconds (float) of each frame
    for frame in video:
        frames.append(frame['data'])
        ptss.append(frame['pts'])

    print("PTS for first five frames ", ptss[:5])
    print("Total number of frames: ", len(frames))
    approx_nf = metadata['audio']['duration'][0] * metadata['audio']['framerate'][0]
    print("Approx total number of datapoints we can expect: ", approx_nf)
    print("Read data size: ", frames[0].size(0) * len(frames))





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    PTS for first five frames  [0.0, 0.021332999999999998, 0.042667, 0.064, 0.08533299999999999]
    Total number of frames:  511
    Approx total number of datapoints we can expect:  523200.0
    Read data size:  523264
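
The gap between the estimate and the data actually read comes down to rounding. Working through the
printed numbers, and assuming (as the ``frames[0].size(0) * len(frames)`` expression above does) that
every audio frame has the same length: 523264 / 511 gives 1024 samples per frame, which matches the
standard AAC-LC frame length, while the metadata duration of 10.9 s is rounded, so
``duration * framerate`` falls slightly short.

.. code-block:: Python

    # Numbers copied from the output above.
    num_frames = 511          # audio frames returned by the reader
    samples_read = 523264     # total samples read
    duration_s = 10.9         # metadata['audio']['duration'][0], rounded
    sample_rate = 48000.0     # metadata['audio']['framerate'][0]

    samples_per_frame = samples_read // num_frames
    estimate = duration_s * sample_rate

    print(samples_per_frame)                     # 1024
    print(round(estimate))                       # 523200
    print(samples_read - round(estimate))        # 64 extra samples
    print(round(samples_read / sample_rate, 4))  # 10.9013 s implied by the data

The 64 extra samples correspond to about 1.3 ms of audio.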




.. GENERATED FROM PYTHON SOURCE LINES 93-101

But what if we only want to read a certain time segment of the video?
That can be done easily by combining the ``seek`` method with the fact that each call
to ``next`` returns the presentation timestamp of the returned frame in seconds.

Given that our implementation relies on Python iterators,
we can leverage ``itertools`` to simplify the process and make it more Pythonic.

For example, if we wanted to read ten frames starting at the two-second mark:

.. GENERATED FROM PYTHON SOURCE LINES 101-114

.. code-block:: Python



    import itertools
    video.set_current_stream("video")

    frames = []  # we are going to save the frames here.

    # Seek to the two-second mark and use islice to take the next 10 frames.
    # Each item is a dict, so we pick out its 'data' field.
    for frame in itertools.islice(video.seek(2), 10):
        frames.append(frame['data'])

    print("Total number of frames: ", len(frames))





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Total number of frames:  10
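
``seek`` hands back the reader itself (the code above iterates over ``video.seek(2)`` directly), so
any object that implements ``seek`` plus the iterator protocol can be sliced the same way. The
``FakeReader`` class below is a hypothetical stand-in, not part of torchvision: it yields
``{'data', 'pts'}`` dicts at a fixed frame rate and approximates seeking by jumping to the first frame
at or after the requested time.

.. code-block:: Python

    import itertools
    import math


    class FakeReader:
        """Hypothetical stand-in for a video reader (not part of torchvision).

        Yields dicts with 'data' and 'pts' (seconds) at a constant frame rate.
        """

        def __init__(self, num_frames, fps):
            self.num_frames = num_frames
            self.fps = fps
            self._index = 0

        def seek(self, time_s):
            # Jump to the first frame whose timestamp is >= time_s.
            self._index = min(self.num_frames, max(0, math.ceil(time_s * self.fps)))
            return self  # returning self is what makes `reader.seek(t)` iterable

        def __iter__(self):
            return self

        def __next__(self):
            if self._index >= self.num_frames:
                raise StopIteration
            frame = {"data": f"frame-{self._index}", "pts": self._index / self.fps}
            self._index += 1
            return frame


    reader = FakeReader(num_frames=300, fps=30.0)

    # Ten frames starting at the two-second mark: frames 60..69.
    frames = [f["data"] for f in itertools.islice(reader.seek(2), 10)]
    print(len(frames), frames[0], frames[-1])  # 10 frame-60 frame-69

    # Near the end of the stream islice simply stops early: frames 293..299.
    tail = list(itertools.islice(reader.seek(9.75), 10))
    print(len(tail))  # 7

Because ``islice`` counts frames rather than seconds, the time span it covers depends on the frame rate.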




.. GENERATED FROM PYTHON SOURCE LINES 115-119

Or, if we wanted to read from the 2nd to the 5th second, we seek to the two-second mark
of the video, then use ``itertools.takewhile`` to get the
correct number of frames:

.. GENERATED FROM PYTHON SOURCE LINES 119-132

.. code-block:: Python


    video.set_current_stream("video")
    frames = []  # we are going to save the frames here.

    # seek() returns the reader itself, so we can iterate over it directly
    for frame in itertools.takewhile(lambda x: x['pts'] <= 5, video.seek(2)):
        frames.append(frame['data'])

    print("Total number of frames: ", len(frames))
    approx_nf = (5 - 2) * video.get_metadata()['video']['fps'][0]
    print("We can expect approx: ", approx_nf)
    print("Tensor size: ", frames[0].size())





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Total number of frames:  90
    We can expect approx:  89.91008991008991
    Tensor size:  torch.Size([3, 256, 340])
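
Note that the ``<=`` comparison makes the window inclusive at the end, so the exact count can differ by
a frame or so from ``(end - start) * fps``. The stdlib-only sketch below illustrates this with a plain
generator of hypothetical ``{'data', 'pts'}`` dicts at 25 fps instead of a real decoder, using
``itertools.dropwhile`` as a stand-in for ``seek``. It also shows that ``takewhile`` consumes the first
frame past ``end`` in order to detect the stop point.

.. code-block:: Python

    import itertools


    def fake_stream(num_frames, fps):
        """Hypothetical decoder output: dicts with 'data' and 'pts' (seconds)."""
        for i in range(num_frames):
            yield {"data": i, "pts": i / fps}


    fps, start, end = 25.0, 2.0, 5.0
    stream = fake_stream(num_frames=300, fps=fps)

    # dropwhile stands in for seek(); takewhile stops after `end` (inclusive).
    window = itertools.takewhile(
        lambda f: f["pts"] <= end,
        itertools.dropwhile(lambda f: f["pts"] < start, stream),
    )
    frames = [f["data"] for f in window]

    print(len(frames))            # 76: both endpoints (pts 2.0 and 5.0) included
    print((end - start) * fps)    # 75.0
    print(frames[0], frames[-1])  # 50 125

    # takewhile had to consume frame 126 to find the stop point.
    print(next(stream)["data"])   # 127

This is one reason to call ``seek`` again before each new read rather than resuming a partially
consumed iterator.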




.. GENERATED FROM PYTHON SOURCE LINES 133-137

2. Building a sample read_video function
----------------------------------------------------------------------------------------
We can use the methods above to build a read-video function that closely mirrors
the existing ``read_video`` function.

.. GENERATED FROM PYTHON SOURCE LINES 137-177

.. code-block:: Python



    def example_read_video(video_object, start=0, end=None, read_video=True, read_audio=True):
        if end is None:
            end = float("inf")
        if end < start:
            raise ValueError(
                "end time should be larger than start time, got "
                f"start time={start} and end time={end}"
            )

        video_frames = torch.empty(0)
        video_pts = []
        if read_video:
            video_object.set_current_stream("video")
            frames = []
            for frame in itertools.takewhile(lambda x: x['pts'] <= end, video_object.seek(start)):
                frames.append(frame['data'])
                video_pts.append(frame['pts'])
            if len(frames) > 0:
                video_frames = torch.stack(frames, 0)

        audio_frames = torch.empty(0)
        audio_pts = []
        if read_audio:
            video_object.set_current_stream("audio")
            frames = []
            for frame in itertools.takewhile(lambda x: x['pts'] <= end, video_object.seek(start)):
                frames.append(frame['data'])
                audio_pts.append(frame['pts'])
            if len(frames) > 0:
                audio_frames = torch.cat(frames, 0)

        return video_frames, audio_frames, (video_pts, audio_pts), video_object.get_metadata()


    # Total number of frames should be 327 for video and 523264 datapoints for audio
    vf, af, info, meta = example_read_video(video)
    print(vf.size(), af.size())





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    torch.Size([327, 3, 256, 340]) torch.Size([523264, 1])
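
The expected 327 video frames follow from the metadata printed at the start: the frame rate
29.97002997... is the NTSC rate 30000/1001, and 10.9109 s multiplied by 30000/1001 is exactly 327.
A short check using only those metadata values:

.. code-block:: Python

    from fractions import Fraction

    # Values copied from video.get_metadata() above.
    duration_s = 10.9109
    fps = 29.97002997002997

    # The float frame rate is the NTSC rate 30000/1001.
    print(Fraction(fps).limit_denominator(2000))  # 30000/1001

    expected_frames = duration_s * fps
    print(round(expected_frames))  # 327, matching torch.Size([327, ...]) above

    # Exact arithmetic: 10.9109 * 30000 / 1001 == 327
    print(Fraction("10.9109") * Fraction(30000, 1001))  # 327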




.. GENERATED FROM PYTHON SOURCE LINES 178-183

3. Building an example randomly sampled dataset (applicable to the Kinetics-400 training set)
-------------------------------------------------------------------------------------------------------
Now we can use the same principle to build a sample dataset.
We suggest using an iterable dataset (:class:`~torch.utils.data.IterableDataset`) for this purpose.
Here, we are going to build an example dataset that reads a randomly located clip of 16 consecutive
frames (the default ``clip_len``) from each sampled video.

.. GENERATED FROM PYTHON SOURCE LINES 185-186

Create the folder structure for a sample dataset, with one sub-folder per class:

.. GENERATED FROM PYTHON SOURCE LINES 186-191

.. code-block:: Python

    import os
    os.makedirs("./dataset", exist_ok=True)
    os.makedirs("./dataset/1", exist_ok=True)
    os.makedirs("./dataset/2", exist_ok=True)








.. GENERATED FROM PYTHON SOURCE LINES 192-193

Download the videos

.. GENERATED FROM PYTHON SOURCE LINES 193-219

.. code-block:: Python

    from torchvision.datasets.utils import download_url
    download_url(
        "https://github.com/pytorch/vision/blob/main/test/assets/videos/WUzgd7C1pWA.mp4?raw=true",
        "./dataset/1", "WUzgd7C1pWA.mp4"
    )
    download_url(
        "https://github.com/pytorch/vision/blob/main/test/assets/videos/RATRACE_wave_f_nm_np1_fr_goo_37.avi?raw=true",
        "./dataset/1",
        "RATRACE_wave_f_nm_np1_fr_goo_37.avi"
    )
    download_url(
        "https://github.com/pytorch/vision/blob/main/test/assets/videos/SOX5yA1l24A.mp4?raw=true",
        "./dataset/2",
        "SOX5yA1l24A.mp4"
    )
    download_url(
        "https://github.com/pytorch/vision/blob/main/test/assets/videos/v_SoccerJuggling_g23_c01.avi?raw=true",
        "./dataset/2",
        "v_SoccerJuggling_g23_c01.avi"
    )
    download_url(
        "https://github.com/pytorch/vision/blob/main/test/assets/videos/v_SoccerJuggling_g24_c01.avi?raw=true",
        "./dataset/2",
        "v_SoccerJuggling_g24_c01.avi"
    )





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/test/assets/videos/WUzgd7C1pWA.mp4 to ./dataset/1/WUzgd7C1pWA.mp4

    3.7%
    7.4%
    11.1%
    14.7%
    18.4%
    22.1%
    25.8%
    29.5%
    33.2%
    36.8%
    40.5%
    44.2%
    47.9%
    51.6%
    55.3%
    58.9%
    62.6%
    66.3%
    70.0%
    73.7%
    77.4%
    81.0%
    84.7%
    88.4%
    92.1%
    95.8%
    99.5%
    100.0%
    Downloading https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/test/assets/videos/RATRACE_wave_f_nm_np1_fr_goo_37.avi to ./dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi

    12.4%
    24.9%
    37.3%
    49.7%
    62.1%
    74.6%
    87.0%
    99.4%
    100.0%
    Downloading https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/test/assets/videos/SOX5yA1l24A.mp4 to ./dataset/2/SOX5yA1l24A.mp4

    5.8%
    11.7%
    17.5%
    23.4%
    29.2%
    35.1%
    40.9%
    46.8%
    52.6%
    58.5%
    64.3%
    70.2%
    76.0%
    81.9%
    87.7%
    93.6%
    99.4%
    100.0%
    Downloading https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/test/assets/videos/v_SoccerJuggling_g23_c01.avi to ./dataset/2/v_SoccerJuggling_g23_c01.avi

    6.4%
    12.9%
    19.3%
    25.8%
    32.2%
    38.7%
    45.1%
    51.6%
    58.0%
    64.5%
    70.9%
    77.3%
    83.8%
    90.2%
    96.7%
    100.0%
    Downloading https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/test/assets/videos/v_SoccerJuggling_g24_c01.avi to ./dataset/2/v_SoccerJuggling_g24_c01.avi

    5.3%
    10.5%
    15.8%
    21.0%
    26.3%
    31.6%
    36.8%
    42.1%
    47.3%
    52.6%
    57.9%
    63.1%
    68.4%
    73.6%
    78.9%
    84.2%
    89.4%
    94.7%
    99.9%
    100.0%




.. GENERATED FROM PYTHON SOURCE LINES 220-221

Housekeeping and utilities

.. GENERATED FROM PYTHON SOURCE LINES 221-239

.. code-block:: Python

    import os
    import random

    from torchvision.datasets.folder import make_dataset
    from torchvision import transforms as t


    def _find_classes(dir):
        classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        classes.sort()
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx


    def get_samples(root, extensions=(".mp4", ".avi")):
        _, class_to_idx = _find_classes(root)
        return make_dataset(root, class_to_idx, extensions=extensions)
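
To see the shape of the data ``get_samples`` produces, below is a stdlib-only approximation.
``list_video_samples`` is a hypothetical helper, not torchvision's ``make_dataset``, and it skips
details such as validation and symlink options; like ``make_dataset``, it returns a list of
``(path, class_index)`` tuples, with class indices assigned in sorted order of the class folder names.

.. code-block:: Python

    import os
    import tempfile


    def list_video_samples(root, extensions=(".mp4", ".avi")):
        """Hypothetical stdlib-only approximation of get_samples().

        Returns (path, class_index) tuples; class indices follow the sorted
        order of the sub-folder names, one folder per class.
        """
        classes = sorted(d.name for d in os.scandir(root) if d.is_dir())
        class_to_idx = {name: i for i, name in enumerate(classes)}
        samples = []
        for name in classes:
            class_dir = os.path.join(root, name)
            for dirpath, _, filenames in sorted(os.walk(class_dir)):
                for fname in sorted(filenames):
                    if fname.lower().endswith(extensions):
                        samples.append((os.path.join(dirpath, fname), class_to_idx[name]))
        return samples


    # Build a tiny throwaway tree mirroring ./dataset/1 and ./dataset/2.
    with tempfile.TemporaryDirectory() as root:
        for rel in ("1/a.mp4", "1/b.avi", "2/c.mp4", "2/notes.txt"):
            path = os.path.join(root, rel)
            os.makedirs(os.path.dirname(path), exist_ok=True)
            open(path, "w").close()  # empty placeholder file
        samples = [
            (os.path.relpath(p, root).replace(os.sep, "/"), c)
            for p, c in list_video_samples(root)
        ]

    print(samples)  # [('1/a.mp4', 0), ('1/b.avi', 0), ('2/c.mp4', 1)]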








.. GENERATED FROM PYTHON SOURCE LINES 240-250

We are going to define the dataset and some basic arguments.
We assume the folder structure used by :class:`~torchvision.datasets.DatasetFolder`, and add the following parameters:

- ``clip_len``: length of a clip in frames
- ``frame_transform``: transform applied to every frame individually
- ``video_transform``: transform applied to the whole clip (the stacked frames)

.. note::
  We also add an ``epoch_size`` parameter, since using the :class:`~torch.utils.data.IterableDataset`
  class allows us to naturally oversample clips or images from each video if needed.
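
As a rough illustration of that oversampling, the stdlib-only sketch below mimics the sampling loop of
the dataset's ``__iter__``: each step draws a video with ``random.choice``, so an ``epoch_size`` larger
than the number of videos simply revisits some of them (each visit would get its own random start
time). The file names and the fixed seed are illustrative only.

.. code-block:: Python

    import random
    from collections import Counter

    videos = ["a.mp4", "b.avi", "c.mp4", "d.avi", "e.avi"]  # 5 hypothetical files
    epoch_size = 20  # more draws than videos: oversampling

    rng = random.Random(0)  # fixed seed so the run is repeatable
    draws = [rng.choice(videos) for _ in range(epoch_size)]

    print(len(draws))      # 20 clips per "epoch"
    print(Counter(draws))  # how often each video was picked (depends on the seed)

Because the draws are made with replacement, an epoch may also skip some videos entirely; guaranteeing
full coverage would need a shuffle-based scheme instead.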

.. GENERATED FROM PYTHON SOURCE LINES 250-294

.. code-block:: Python



    class RandomDataset(torch.utils.data.IterableDataset):
        def __init__(self, root, epoch_size=None, frame_transform=None, video_transform=None, clip_len=16):
            super().__init__()

            self.samples = get_samples(root)

            # Allow for oversampling: an epoch may draw more clips than there are videos
            if epoch_size is None:
                epoch_size = len(self.samples)
            self.epoch_size = epoch_size

            self.clip_len = clip_len
            self.frame_transform = frame_transform
            self.video_transform = video_transform

        def __iter__(self):
            for _ in range(self.epoch_size):
                # Get random sample
                path, target = random.choice(self.samples)
                # Get video object
                vid = torchvision.io.VideoReader(path, "video")
                metadata = vid.get_metadata()
                video_frames = []  # video frame buffer

                # Pick a random start time that leaves room for clip_len frames
                max_seek = metadata["video"]['duration'][0] - (self.clip_len / metadata["video"]['fps'][0])
                start = random.uniform(0., max(0., max_seek))
                for frame in itertools.islice(vid.seek(start), self.clip_len):
                    data = frame['data']
                    if self.frame_transform:
                        data = self.frame_transform(data)
                    video_frames.append(data)
                    current_pts = frame['pts']
                # Stack it into a tensor
                video = torch.stack(video_frames, 0)
                if self.video_transform:
                    video = self.video_transform(video)
                output = {
                    'path': path,
                    'video': video,
                    'target': target,
                    'start': start,
                    'end': current_pts}
                yield output
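
The start time in ``__iter__`` is drawn uniformly from ``[0, max_seek]``, where ``max_seek`` leaves room
for ``clip_len`` frames before the end of the video. Worked through with the metadata of
``WUzgd7C1pWA.mp4`` shown earlier (duration 10.9109 s, about 29.97 fps) and the default
``clip_len=16``, one clip spans roughly 0.53 s, so starts are drawn from about [0, 10.377] s.
``clip_start_range`` is a hypothetical helper; this sketch also clamps at 0 so that a clip longer than
the video does not produce a negative range.

.. code-block:: Python

    import random


    def clip_start_range(duration_s, fps, clip_len):
        """Hypothetical helper: latest start time (s) that still fits clip_len frames."""
        clip_span = clip_len / fps  # seconds covered by one clip
        return max(0.0, duration_s - clip_span)


    max_seek = clip_start_range(duration_s=10.9109, fps=29.97002997002997, clip_len=16)
    print(round(max_seek, 4))  # 10.377

    rng = random.Random(42)
    starts = [rng.uniform(0.0, max_seek) for _ in range(5)]
    print(all(0.0 <= s <= max_seek for s in starts))  # True

    # A clip longer than the video clamps to 0 instead of going negative.
    print(clip_start_range(duration_s=0.3, fps=30.0, clip_len=16))  # 0.0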








.. GENERATED FROM PYTHON SOURCE LINES 295-309

Given a path to videos organized in a folder structure, e.g.:

- dataset
    - class 1
        - file 0
        - file 1
        - ...
    - class 2
        - file 0
        - file 1
        - ...
    - ...

We can generate a dataloader and test the dataset.

.. GENERATED FROM PYTHON SOURCE LINES 309-316

.. code-block:: Python



    transforms = [t.Resize((112, 112))]
    frame_transform = t.Compose(transforms)

    dataset = RandomDataset("./dataset", epoch_size=None, frame_transform=frame_transform)








.. GENERATED FROM PYTHON SOURCE LINES 317-328

.. code-block:: Python

    from torch.utils.data import DataLoader
    loader = DataLoader(dataset, batch_size=12)
    data = {"video": [], 'start': [], 'end': [], 'tensorsize': []}
    for batch in loader:
        for i in range(len(batch['path'])):
            data['video'].append(batch['path'][i])
            data['start'].append(batch['start'][i].item())
            data['end'].append(batch['end'][i].item())
            data['tensorsize'].append(batch['video'][i].size())
    print(data)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    {'video': ['./dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi', './dataset/2/SOX5yA1l24A.mp4', './dataset/1/WUzgd7C1pWA.mp4', './dataset/2/v_SoccerJuggling_g24_c01.avi', './dataset/2/v_SoccerJuggling_g23_c01.avi'], 'start': [0.5340898311348212, 6.514303497110079, 4.142249898746436, 7.604998914489408, 2.865074628013495], 'end': [1.066667, 7.040367, 4.671333, 8.1081, 3.370033], 'tensorsize': [torch.Size([16, 3, 112, 112]), torch.Size([16, 3, 112, 112]), torch.Size([16, 3, 112, 112]), torch.Size([16, 3, 112, 112]), torch.Size([16, 3, 112, 112])]}




.. GENERATED FROM PYTHON SOURCE LINES 329-332

4. Data Visualization
----------------------------------
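An example of a visualized clip: we plot all 16 frames of the first clip in the last batch
loaded above. Each frame is a ``C x H x W`` tensor, so we ``permute`` it to the ``H x W x C``
layout that ``plt.imshow`` expects.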

.. GENERATED FROM PYTHON SOURCE LINES 332-341

.. code-block:: Python


    import matplotlib.pyplot as plt

    plt.figure(figsize=(12, 12))
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow(batch["video"][0, i, ...].permute(1, 2, 0))
        plt.axis("off")




.. image-sg:: /auto_examples/others/images/sphx_glr_plot_video_api_001.png
   :alt: plot video api
   :srcset: /auto_examples/others/images/sphx_glr_plot_video_api_001.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 342-343

Clean up the video and dataset:

.. GENERATED FROM PYTHON SOURCE LINES 343-347

.. code-block:: Python

    import os
    import shutil
    os.remove("./WUzgd7C1pWA.mp4")
    shutil.rmtree("./dataset")








.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 4.896 seconds)


.. _sphx_glr_download_auto_examples_others_plot_video_api.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_video_api.ipynb <plot_video_api.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_video_api.py <plot_video_api.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_video_api.zip <plot_video_api.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_