.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/tensorclass_fashion.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_tensorclass_fashion.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_tensorclass_fashion.py:


Using tensorclasses for datasets
================================

.. GENERATED FROM PYTHON SOURCE LINES 7-13

In this tutorial we demonstrate how tensorclasses can be used to
efficiently and transparently load and manage data inside a training
pipeline. The tutorial is based heavily on the `PyTorch Quickstart
Tutorial <https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html>`__,
but modified to demonstrate the use of a tensorclass. See also the related
tutorial that uses ``TensorDict``.

.. GENERATED FROM PYTHON SOURCE LINES 13-27

.. code-block:: Python



    import torch
    import torch.nn as nn

    from tensordict import MemoryMappedTensor, tensorclass
    from torch.utils.data import DataLoader
    from torchvision import datasets
    from torchvision.transforms import ToTensor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Using device: cpu




.. GENERATED FROM PYTHON SOURCE LINES 28-32

The ``torchvision.datasets`` module contains a number of convenient, pre-prepared
datasets. In this tutorial we'll use the relatively simple FashionMNIST dataset. Each
image depicts an item of clothing, and the objective is to classify the type of
clothing in the image (e.g. "Bag", "Sneaker", etc.).

.. GENERATED FROM PYTHON SOURCE LINES 32-46

.. code-block:: Python


    training_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor(),
    )
    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor(),
    )





.. rst-class:: sphx-glr-script-out

 .. code-block:: none


      0%|          | 0.00/26.4M [00:00<?, ?B/s]
      0%|          | 65.5k/26.4M [00:00<01:12, 362kB/s]
      1%|          | 229k/26.4M [00:00<00:38, 683kB/s] 
      3%|▎         | 918k/26.4M [00:00<00:09, 2.56MB/s]
      7%|▋         | 1.93M/26.4M [00:00<00:05, 4.12MB/s]
     25%|██▍       | 6.59M/26.4M [00:00<00:01, 15.1MB/s]
     38%|███▊      | 9.93M/26.4M [00:00<00:00, 17.2MB/s]
     59%|█████▉    | 15.5M/26.4M [00:01<00:00, 26.2MB/s]
     72%|███████▏  | 19.0M/26.4M [00:01<00:00, 24.6MB/s]
     93%|█████████▎| 24.4M/26.4M [00:01<00:00, 30.9MB/s]
    100%|██████████| 26.4M/26.4M [00:01<00:00, 19.4MB/s]

      0%|          | 0.00/29.5k [00:00<?, ?B/s]
    100%|██████████| 29.5k/29.5k [00:00<00:00, 325kB/s]

      0%|          | 0.00/4.42M [00:00<?, ?B/s]
      1%|▏         | 65.5k/4.42M [00:00<00:11, 364kB/s]
      5%|▌         | 229k/4.42M [00:00<00:06, 685kB/s] 
     21%|██▏       | 950k/4.42M [00:00<00:01, 2.20MB/s]
     87%|████████▋ | 3.83M/4.42M [00:00<00:00, 7.65MB/s]
    100%|██████████| 4.42M/4.42M [00:00<00:00, 6.12MB/s]

      0%|          | 0.00/5.15k [00:00<?, ?B/s]
    100%|██████████| 5.15k/5.15k [00:00<00:00, 66.8MB/s]
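
Each item of the dataset is an ``(image, target)`` pair: ``ToTensor`` converts the
image to a ``[1, 28, 28]`` float tensor with values in ``[0, 1]``, and the target is
an integer class index. A quick check of this structure, which we rely on below when
sizing the memory-mapped tensors (the shapes in the comments are what we expect to
see):

.. code-block:: Python

    image, target = training_data[0]
    print(image.shape, image.dtype)  # torch.Size([1, 28, 28]) torch.float32
    print(target, training_data.classes[target])  # integer label and its class name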




.. GENERATED FROM PYTHON SOURCE LINES 47-58

Tensorclasses are dataclasses that expose dedicated tensor methods over
their contents, much like ``TensorDict``. They are a good choice when the
structure of the data you want to store is fixed and predictable.

As well as specifying the contents, we can also encapsulate related
logic as custom methods when defining the class. In this case we'll
write a ``from_dataset`` classmethod that takes a dataset as input and
creates a tensorclass containing the data from that dataset. We create
memory-mapped tensors to hold the data. This will allow us to
efficiently load batches of transformed data from disk rather than
repeatedly loading and transforming individual images.
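
For reference, a ``MemoryMappedTensor`` behaves like a regular tensor whose storage
lives in a file on disk, so indexing it only reads the requested slice. A minimal
sketch, using ``MemoryMappedTensor.from_tensor`` (available in recent ``tensordict``
releases) to wrap an existing tensor:

.. code-block:: Python

    mmap_example = MemoryMappedTensor.from_tensor(torch.arange(12).reshape(3, 4))
    print(mmap_example[1])   # indexing reads from the memory-mapped storage
    print(mmap_example * 2)  # standard tensor operations still apply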

.. GENERATED FROM PYTHON SOURCE LINES 58-80

.. code-block:: Python



    @tensorclass
    class FashionMNISTData:
        images: torch.Tensor
        targets: torch.Tensor

        @classmethod
        def from_dataset(cls, dataset, device=None):
            data = cls(
                images=MemoryMappedTensor.empty(
                    (len(dataset), *dataset[0][0].squeeze().shape), dtype=torch.float32
                ),
                targets=MemoryMappedTensor.empty((len(dataset),), dtype=torch.int64),
                batch_size=[len(dataset)],
                device=device,
            )
            for i, (image, target) in enumerate(dataset):
                data[i] = cls(images=image, targets=torch.tensor(target), batch_size=[])
            return data









.. GENERATED FROM PYTHON SOURCE LINES 81-84

We will create two tensorclasses, one each for the training and test data. Note that
we incur some overhead here because we loop over the entire dataset, transforming
each item and saving it to disk.

.. GENERATED FROM PYTHON SOURCE LINES 84-88

.. code-block:: Python


    training_data_tc = FashionMNISTData.from_dataset(training_data, device=device)
    test_data_tc = FashionMNISTData.from_dataset(test_data, device=device)
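
The item-by-item loop above is the generic approach and works with any map-style
dataset. When the dataset exposes its raw tensors directly (``FashionMNIST`` stores
them as ``.data`` and ``.targets``), the same tensorclass could instead be filled with
a single bulk copy. A sketch, assuming those attributes and reproducing the
``ToTensor`` scaling by hand:

.. code-block:: Python

    # Hypothetical bulk alternative: relies on FashionMNIST exposing ``.data``
    # (uint8 images) and ``.targets`` (int64 labels); dividing by 255 mirrors
    # the ToTensor transform.
    bulk_training_data_tc = FashionMNISTData(
        images=MemoryMappedTensor.from_tensor(training_data.data.float() / 255.0),
        targets=MemoryMappedTensor.from_tensor(training_data.targets),
        batch_size=[len(training_data)],
        device=device,
    )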








.. GENERATED FROM PYTHON SOURCE LINES 89-100

DataLoaders
----------------

We'll create DataLoaders from the ``torchvision``-provided Datasets, as
well as from our memory-mapped tensorclasses.

Since the tensorclass, like ``TensorDict``, implements ``__len__`` and
``__getitem__`` (and also ``__getitems__``), we can use it like a
map-style Dataset and create a ``DataLoader`` directly from it. Note that
because the tensorclass can already handle batched indices, there is no
need for collation, so we pass the identity function as ``collate_fn``.
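
For example, indexing the tensorclass with a slice already returns a batched
``FashionMNISTData``, which is exactly what we want each ``DataLoader`` batch to look
like, so there is nothing left for a collate function to do:

.. code-block:: Python

    batch = training_data_tc[:4]
    print(batch.batch_size)    # torch.Size([4])
    print(batch.images.shape)  # torch.Size([4, 28, 28])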

.. GENERATED FROM PYTHON SOURCE LINES 100-113

.. code-block:: Python


    batch_size = 64

    train_dataloader = DataLoader(training_data, batch_size=batch_size)  # noqa: TOR401
    test_dataloader = DataLoader(test_data, batch_size=batch_size)  # noqa: TOR401

    train_dataloader_tc = DataLoader(  # noqa: TOR401
        training_data_tc, batch_size=batch_size, collate_fn=lambda x: x
    )
    test_dataloader_tc = DataLoader(  # noqa: TOR401
        test_data_tc, batch_size=batch_size, collate_fn=lambda x: x
    )








.. GENERATED FROM PYTHON SOURCE LINES 114-120

Model
-------

We use the same model from the
`Quickstart Tutorial <https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html>`__.


.. GENERATED FROM PYTHON SOURCE LINES 120-144

.. code-block:: Python



    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.flatten = nn.Flatten()
            self.linear_relu_stack = nn.Sequential(
                nn.Linear(28 * 28, 512),
                nn.ReLU(),
                nn.Linear(512, 512),
                nn.ReLU(),
                nn.Linear(512, 10),
            )

        def forward(self, x):
            x = self.flatten(x)
            logits = self.linear_relu_stack(x)
            return logits


    model = Net().to(device)
    model_tc = Net().to(device)
    model, model_tc





.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    (Net(
      (flatten): Flatten(start_dim=1, end_dim=-1)
      (linear_relu_stack): Sequential(
        (0): Linear(in_features=784, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): ReLU()
        (4): Linear(in_features=512, out_features=10, bias=True)
      )
    ), Net(
      (flatten): Flatten(start_dim=1, end_dim=-1)
      (linear_relu_stack): Sequential(
        (0): Linear(in_features=784, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): ReLU()
        (4): Linear(in_features=512, out_features=10, bias=True)
      )
    ))



.. GENERATED FROM PYTHON SOURCE LINES 145-151

Optimizing the parameters
---------------------------------

We'll optimize the parameters of the model using stochastic gradient descent and
cross-entropy loss.


.. GENERATED FROM PYTHON SOURCE LINES 151-176

.. code-block:: Python


    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    optimizer_tc = torch.optim.SGD(model_tc.parameters(), lr=1e-3)


    def train(dataloader, model, loss_fn, optimizer):
        size = len(dataloader.dataset)
        model.train()

        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)

            pred = model(X)
            loss = loss_fn(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch % 100 == 0:
                loss, current = loss.item(), batch * len(X)
                print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")









.. GENERATED FROM PYTHON SOURCE LINES 177-181

The training loop for our tensorclass-based DataLoader is very similar; we just
adjust how we unpack the data, using the more explicit attribute-based retrieval
offered by the tensorclass. The ``.contiguous()`` method loads the data stored in the
memory-mapped tensor.

.. GENERATED FROM PYTHON SOURCE LINES 181-267

.. code-block:: Python



    def train_tc(dataloader, model, loss_fn, optimizer):
        size = len(dataloader.dataset)
        model.train()

        for batch, data in enumerate(dataloader):
            X, y = data.images.contiguous(), data.targets.contiguous()

            pred = model(X)
            loss = loss_fn(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch % 100 == 0:
                loss, current = loss.item(), batch * len(X)
                print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")


    def test(dataloader, model, loss_fn):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        model.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)

                pred = model(X)

                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()

        test_loss /= num_batches
        correct /= size

        print(
            f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
        )


    def test_tc(dataloader, model, loss_fn):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        model.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for batch in dataloader:
                X, y = batch.images.contiguous(), batch.targets.contiguous()

                pred = model(X)

                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()

        test_loss /= num_batches
        correct /= size

        print(
            f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
        )


    for d in train_dataloader_tc:
        print(d)
        break

    import time

    t0 = time.time()
    epochs = 5
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------")
        train_tc(train_dataloader_tc, model_tc, loss_fn, optimizer_tc)
        test_tc(test_dataloader_tc, model_tc, loss_fn)
    print(f"Tensorclass training done! time: {time.time() - t0: 4.4f} s")

    t0 = time.time()
    epochs = 5
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------")
        train(train_dataloader, model, loss_fn, optimizer)
        test(test_dataloader, model, loss_fn)
    print(f"Training done! time: {time.time() - t0: 4.4f} s")




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    FashionMNISTData(
        images=Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
        targets=Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False),
        batch_size=torch.Size([64]),
        device=cpu,
        is_shared=False)
    Epoch 1
    -------------------------
    loss: 2.299970 [    0/60000]
    loss: 2.283205 [ 6400/60000]
    loss: 2.276911 [12800/60000]
    loss: 2.273206 [19200/60000]
    loss: 2.257264 [25600/60000]
    loss: 2.232306 [32000/60000]
    loss: 2.233045 [38400/60000]
    loss: 2.198727 [44800/60000]
    loss: 2.197931 [51200/60000]
    loss: 2.177116 [57600/60000]
    Test Error: 
     Accuracy: 48.4%, Avg loss: 2.166463 

    Epoch 2
    -------------------------
    loss: 2.173323 [    0/60000]
    loss: 2.155926 [ 6400/60000]
    loss: 2.109603 [12800/60000]
    loss: 2.127738 [19200/60000]
    loss: 2.079188 [25600/60000]
    loss: 2.033805 [32000/60000]
    loss: 2.049941 [38400/60000]
    loss: 1.971063 [44800/60000]
    loss: 1.978915 [51200/60000]
    loss: 1.923310 [57600/60000]
    Test Error: 
     Accuracy: 56.8%, Avg loss: 1.910700 

    Epoch 3
    -------------------------
    loss: 1.941580 [    0/60000]
    loss: 1.903466 [ 6400/60000]
    loss: 1.797659 [12800/60000]
    loss: 1.835826 [19200/60000]
    loss: 1.730940 [25600/60000]
    loss: 1.697201 [32000/60000]
    loss: 1.700962 [38400/60000]
    loss: 1.599460 [44800/60000]
    loss: 1.627200 [51200/60000]
    loss: 1.532528 [57600/60000]
    Test Error: 
     Accuracy: 60.7%, Avg loss: 1.542216 

    Epoch 4
    -------------------------
    loss: 1.607462 [    0/60000]
    loss: 1.561305 [ 6400/60000]
    loss: 1.421173 [12800/60000]
    loss: 1.491235 [19200/60000]
    loss: 1.373170 [25600/60000]
    loss: 1.376845 [32000/60000]
    loss: 1.375176 [38400/60000]
    loss: 1.294547 [44800/60000]
    loss: 1.334866 [51200/60000]
    loss: 1.244027 [57600/60000]
    Test Error: 
     Accuracy: 63.2%, Avg loss: 1.265811 

    Epoch 5
    -------------------------
    loss: 1.342325 [    0/60000]
    loss: 1.311911 [ 6400/60000]
    loss: 1.158748 [12800/60000]
    loss: 1.262786 [19200/60000]
    loss: 1.138063 [25600/60000]
    loss: 1.167982 [32000/60000]
    loss: 1.176513 [38400/60000]
    loss: 1.109314 [44800/60000]
    loss: 1.152379 [51200/60000]
    loss: 1.080400 [57600/60000]
    Test Error: 
     Accuracy: 64.8%, Avg loss: 1.095863 

    Tensorclass training done! time:  8.5706 s
    Epoch 1
    -------------------------
    loss: 2.296645 [    0/60000]
    loss: 2.287549 [ 6400/60000]
    loss: 2.264265 [12800/60000]
    loss: 2.256798 [19200/60000]
    loss: 2.246523 [25600/60000]
    loss: 2.219641 [32000/60000]
    loss: 2.219994 [38400/60000]
    loss: 2.184665 [44800/60000]
    loss: 2.184806 [51200/60000]
    loss: 2.153459 [57600/60000]
    Test Error: 
     Accuracy: 47.2%, Avg loss: 2.141223 

    Epoch 2
    -------------------------
    loss: 2.153499 [    0/60000]
    loss: 2.149285 [ 6400/60000]
    loss: 2.078696 [12800/60000]
    loss: 2.086227 [19200/60000]
    loss: 2.050871 [25600/60000]
    loss: 1.994655 [32000/60000]
    loss: 2.008462 [38400/60000]
    loss: 1.928025 [44800/60000]
    loss: 1.933047 [51200/60000]
    loss: 1.864366 [57600/60000]
    Test Error: 
     Accuracy: 59.0%, Avg loss: 1.852625 

    Epoch 3
    -------------------------
    loss: 1.891956 [    0/60000]
    loss: 1.871375 [ 6400/60000]
    loss: 1.735402 [12800/60000]
    loss: 1.762655 [19200/60000]
    loss: 1.676425 [25600/60000]
    loss: 1.636148 [32000/60000]
    loss: 1.639726 [38400/60000]
    loss: 1.544173 [44800/60000]
    loss: 1.575122 [51200/60000]
    loss: 1.474586 [57600/60000]
    Test Error: 
     Accuracy: 60.7%, Avg loss: 1.484805 

    Epoch 4
    -------------------------
    loss: 1.554681 [    0/60000]
    loss: 1.536302 [ 6400/60000]
    loss: 1.370105 [12800/60000]
    loss: 1.435142 [19200/60000]
    loss: 1.328619 [25600/60000]
    loss: 1.332409 [32000/60000]
    loss: 1.337541 [38400/60000]
    loss: 1.264824 [44800/60000]
    loss: 1.306673 [51200/60000]
    loss: 1.211130 [57600/60000]
    Test Error: 
     Accuracy: 63.3%, Avg loss: 1.231135 

    Epoch 5
    -------------------------
    loss: 1.306077 [    0/60000]
    loss: 1.306447 [ 6400/60000]
    loss: 1.125408 [12800/60000]
    loss: 1.225559 [19200/60000]
    loss: 1.103323 [25600/60000]
    loss: 1.138372 [32000/60000]
    loss: 1.152909 [38400/60000]
    loss: 1.094877 [44800/60000]
    loss: 1.137386 [51200/60000]
    loss: 1.053415 [57600/60000]
    Test Error: 
     Accuracy: 64.6%, Avg loss: 1.072053 

    Training done! time:  35.3641 s





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (1 minute 1.976 seconds)


.. _sphx_glr_download_tutorials_tensorclass_fashion.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tensorclass_fashion.ipynb <tensorclass_fashion.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tensorclass_fashion.py <tensorclass_fashion.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: tensorclass_fashion.zip <tensorclass_fashion.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_