
Using TensorDict for datasets

In this tutorial we demonstrate how TensorDict can be used to efficiently and transparently load and manage data inside a training pipeline. The tutorial is heavily based on the PyTorch Quickstart Tutorial, modified to demonstrate the use of TensorDict.

import torch
import torch.nn as nn

from tensordict import MemoryMappedTensor, TensorDict
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
Using device: cpu

The torchvision.datasets module contains a number of convenient ready-made datasets. In this tutorial we’ll use the relatively simple FashionMNIST dataset. Each image shows an item of clothing, and the objective is to classify the type of clothing in the image (e.g. “Bag”, “Sneaker”, etc.).

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
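Each target is an integer from 0 to 9. To turn targets (or predictions) back into readable names, a small lookup table is enough. The list below assumes the standard FashionMNIST label order, which torchvision also exposes as the `classes` attribute of `datasets.FashionMNIST`; the helper `label_name` is a hypothetical name used only for illustration.

```python
# Class names in label order (0-9), assumed to match datasets.FashionMNIST.classes.
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def label_name(target: int) -> str:
    """Hypothetical helper: map an integer target (0-9) to its class name."""
    return FASHION_MNIST_CLASSES[target]


print(label_name(7), label_name(8))  # Sneaker Bag
```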

We will create two TensorDicts, one for the training data and one for the test data, and back them with memory-mapped tensors. This lets us efficiently load batches of already-transformed data from disk rather than repeatedly loading and transforming individual images.
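Conceptually, a memory-mapped array keeps its storage in a file that the operating system pages in on demand, so data can be written once and later read back batch by batch. The sketch below illustrates that write-once, read-batches-later pattern with NumPy's `np.memmap` rather than `MemoryMappedTensor`; it is an analogy under that assumption, not how TensorDict implements its containers, and the sizes and file name are made up.

```python
import os
import tempfile

import numpy as np

# Hypothetical sizes standing in for FashionMNIST (60000 images of 28x28 pixels).
num_items, height, width = 100, 28, 28
path = os.path.join(tempfile.mkdtemp(), "images.dat")

# 1. Pre-allocate a file-backed array and fill it once (the "up-front" work).
images = np.memmap(path, dtype=np.float32, mode="w+", shape=(num_items, height, width))
for i in range(num_items):
    images[i] = np.full((height, width), i / num_items, dtype=np.float32)
images.flush()  # make sure everything is written to disk
del images

# 2. Later, reopen the same file read-only and gather one batch of rows.
stored = np.memmap(path, dtype=np.float32, mode="r", shape=(num_items, height, width))
batch = np.asarray(stored[[3, 10, 42]])  # fancy indexing copies just these rows
print(batch.shape)  # (3, 28, 28)
```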

First we create the MemoryMappedTensor containers.

training_data_td = TensorDict(
    {
        "images": MemoryMappedTensor.empty(
            (len(training_data), *training_data[0][0].squeeze().shape),
            dtype=torch.float32,
        ),
        "targets": MemoryMappedTensor.empty((len(training_data),), dtype=torch.int64),
    },
    batch_size=[len(training_data)],
    device=device,
)
test_data_td = TensorDict(
    {
        "images": MemoryMappedTensor.empty(
            (len(test_data), *test_data[0][0].squeeze().shape), dtype=torch.float32
        ),
        "targets": MemoryMappedTensor.empty((len(test_data),), dtype=torch.int64),
    },
    batch_size=[len(test_data)],
    device=device,
)

Then we can iterate over the data to populate the memory-mapped tensors. This takes a bit of time, but performing the transforms up-front will save repeated effort during training later.

for i, (img, label) in enumerate(training_data):
    training_data_td[i] = TensorDict({"images": img, "targets": label}, [])

for i, (img, label) in enumerate(test_data):
    test_data_td[i] = TensorDict({"images": img, "targets": label}, [])
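Each `img` produced by `ToTensor()` has shape `(1, 28, 28)` (channel, height, width), while the containers were allocated with the squeezed per-image shape `(28, 28)`; this is why batches drawn from the TensorDict have shape `[64, 28, 28]`. The sketch below reproduces the fill pattern with ordinary `torch` tensors and a made-up in-memory dataset, squeezing the channel dimension explicitly. It illustrates the idea only and is not the TensorDict code path.

```python
import torch

# Hypothetical stand-in for the torchvision dataset: (image, label) pairs where each
# image has ToTensor's channel-first shape (1, 28, 28).
fake_dataset = [(torch.rand(1, 28, 28), i % 10) for i in range(8)]

# Pre-allocate storage using the squeezed per-image shape, as the TensorDicts above do.
images = torch.empty((len(fake_dataset), 28, 28), dtype=torch.float32)
targets = torch.empty((len(fake_dataset),), dtype=torch.int64)

# Fill one row per sample; squeeze(0) drops the singleton channel dimension.
for i, (img, label) in enumerate(fake_dataset):
    images[i] = img.squeeze(0)
    targets[i] = label

print(images.shape, targets.tolist())
# torch.Size([8, 28, 28]) [0, 1, 2, 3, 4, 5, 6, 7]
```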

DataLoaders

We’ll create DataLoaders from the torchvision-provided Datasets, as well as from our memory-mapped TensorDicts.

Since TensorDict implements __len__ and __getitem__ (and also __getitems__), we can use it like a map-style Dataset and create a DataLoader directly from it. Because indexing a TensorDict with a batch of indices already returns a batched TensorDict, no collation is needed, so we pass the identity function as collate_fn.

batch_size = 64

train_dataloader = DataLoader(training_data, batch_size=batch_size)  # noqa: TOR401
test_dataloader = DataLoader(test_data, batch_size=batch_size)  # noqa: TOR401

train_dataloader_td = DataLoader(  # noqa: TOR401
    training_data_td, batch_size=batch_size, collate_fn=lambda x: x
)
test_dataloader_td = DataLoader(  # noqa: TOR401
    test_data_td, batch_size=batch_size, collate_fn=lambda x: x
)
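To see why the identity `collate_fn` is enough: in recent PyTorch releases, when a map-style dataset defines `__getitems__`, the DataLoader calls it once with the whole list of sampled indices for a batch and passes its return value straight to `collate_fn`. The sketch below mimics that contract with a hypothetical `BatchedDataset` that returns a plain dict of tensors; it is not TensorDict's implementation, only an illustration of the same protocol.

```python
import torch
from torch.utils.data import DataLoader


class BatchedDataset:
    """Hypothetical map-style dataset that can serve a whole batch in one call."""

    def __init__(self, images: torch.Tensor, targets: torch.Tensor):
        self.images = images
        self.targets = targets

    def __len__(self) -> int:
        return len(self.targets)

    def __getitem__(self, idx):
        return {"images": self.images[idx], "targets": self.targets[idx]}

    def __getitems__(self, indices):
        # The DataLoader passes the list of sampled indices for one batch; a single
        # advanced-indexing call gathers them, so the result is already batched.
        idx = torch.as_tensor(indices)
        return {"images": self.images[idx], "targets": self.targets[idx]}


data = BatchedDataset(torch.rand(10, 28, 28), torch.arange(10))
loader = DataLoader(data, batch_size=4, collate_fn=lambda x: x)  # identity collate

batches = list(loader)
print([b["images"].shape[0] for b in batches])  # [4, 4, 2]
```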

Model

We use the same model as in the Quickstart Tutorial.

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = Net().to(device)
model_td = Net().to(device)
model, model_td
(Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
), Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
))
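The two DataLoaders yield differently shaped image batches: torchvision's keeps the channel dimension (`[64, 1, 28, 28]`), whereas the TensorDict batches are `[64, 28, 28]` because the containers were allocated with the squeezed image shape. The same `Net` handles both because `nn.Flatten()`, with its defaults `start_dim=1, end_dim=-1` (visible in the printout above), collapses every dimension after the batch dimension into one of size 28 * 28 = 784. The random tensors below are placeholders for real batches.

```python
import torch
import torch.nn as nn

flatten = nn.Flatten()  # start_dim=1, end_dim=-1: keep the batch dimension

# Batch shape from the torchvision DataLoader: (N, C, H, W) with one channel.
torchvision_batch = torch.rand(64, 1, 28, 28)
# Batch shape from the TensorDict DataLoader: the channel dimension was squeezed away.
tensordict_batch = torch.rand(64, 28, 28)

print(flatten(torchvision_batch).shape)  # torch.Size([64, 784])
print(flatten(tensordict_batch).shape)   # torch.Size([64, 784])
```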

Optimizing the parameters

We’ll optimize the parameters of each model using stochastic gradient descent and cross-entropy loss, with a separate optimizer for each model.

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer_td = torch.optim.SGD(model_td.parameters(), lr=1e-3)


def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
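Two quick checks make the printed numbers easier to read. First, a classifier that spreads probability evenly over the 10 classes has a cross-entropy of ln(10) ≈ 2.3026, which is roughly where both runs start (a loss of about 2.30 on the first batch). Second, plain `torch.optim.SGD` without momentum or weight decay updates each parameter as `p - lr * grad`. The sketch below verifies both on small random tensors; the layer size and seed are arbitrary choices, not part of the tutorial.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

# With all-equal logits over 10 classes, softmax assigns 1/10 to every class,
# so cross-entropy is -log(1/10) = log(10), whatever the targets are.
loss_fn = nn.CrossEntropyLoss()
uniform_logits = torch.zeros(64, 10)
targets = torch.randint(0, 10, (64,))
print(loss_fn(uniform_logits, targets).item(), math.log(10))  # both about 2.302585

# One SGD step (no momentum, no weight decay) is simply p <- p - lr * grad.
layer = nn.Linear(784, 10)
optimizer = torch.optim.SGD(layer.parameters(), lr=1e-3)
before = layer.weight.detach().clone()

loss = loss_fn(layer(torch.rand(64, 784)), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()

expected = before - 1e-3 * layer.weight.grad
print(torch.allclose(layer.weight.detach(), expected))  # True
```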

The training loop for our TensorDict-based DataLoader is very similar; we only change how the data is unpacked, using the more explicit key-based retrieval offered by TensorDict. The .contiguous() call loads the data stored in the memory-mapped tensor into a regular, contiguous tensor.

def train_td(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()

    for batch, data in enumerate(dataloader):
        X, y = data["images"].contiguous(), data["targets"].contiguous()

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
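Because `train` and `train_td` differ only in how a batch is unpacked, one option is to factor that step into a helper so a single loop can serve both loaders. The sketch below is a hypothetical `unpack_batch` (not part of TensorDict or torchvision); it assumes the two batch formats used in this tutorial and also moves the tensors to a target device, which `train_td` skips because the TensorDicts were already created on `device`.

```python
import torch


def unpack_batch(batch, device="cpu"):
    """Hypothetical helper returning (X, y) for either kind of batch used here.

    - torchvision's default collate yields a two-element list/tuple [X, y];
    - the TensorDict loader yields an object indexed by the "images"/"targets" keys.
    """
    if isinstance(batch, (tuple, list)):
        X, y = batch
    else:
        X, y = batch["images"], batch["targets"]
    return X.contiguous().to(device), y.contiguous().to(device)


# Usage with stand-in batches (a plain dict plays the role of the TensorDict batch).
X1, y1 = unpack_batch([torch.rand(4, 1, 28, 28), torch.arange(4)])
X2, y2 = unpack_batch({"images": torch.rand(4, 28, 28), "targets": torch.arange(4)})
print(X1.shape, X2.shape)  # torch.Size([4, 1, 28, 28]) torch.Size([4, 28, 28])
```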


def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )


def test_td(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X, y = batch["images"].contiguous(), batch["targets"].contiguous()

            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )
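The accuracy line in `test` and `test_td` compares the index of the largest logit with the target and counts the matches. A worked example on made-up logits for four samples and three classes:

```python
import torch

# Toy logits for 4 samples and 3 classes, plus their true labels.
pred = torch.tensor([
    [2.0, 0.1, 0.3],  # argmax 0
    [0.2, 1.5, 0.1],  # argmax 1
    [0.3, 0.2, 0.9],  # argmax 2
    [1.1, 0.4, 0.2],  # argmax 0
])
y = torch.tensor([0, 1, 1, 0])

# Same expression as in test()/test_td(): count samples whose top logit is the label.
correct = (pred.argmax(1) == y).type(torch.float).sum().item()
accuracy = correct / len(y)
print(correct, accuracy)  # 3.0 0.75
```

Note that `test_loss` is divided by the number of batches, so it is a mean of per-batch means: with 10,000 test images and `batch_size = 64`, the last batch holds only 16 images but is weighted like a full one, whereas `correct` is divided by the total number of samples.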


for d in train_dataloader_td:
    print(d)
    break

import time

t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train_td(train_dataloader_td, model_td, loss_fn, optimizer_td)
    test_td(test_dataloader_td, model_td, loss_fn)
print(f"TensorDict training done! time: {time.time() - t0: 4.4f} s")

t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print(f"Training done! time: {time.time() - t0: 4.4f} s")
TensorDict(
    fields={
        images: Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
        targets: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([64]),
    device=cpu,
    is_shared=False)
Epoch 1
-------------------------
loss: 2.312291 [    0/60000]
loss: 2.299820 [ 6400/60000]
loss: 2.279324 [12800/60000]
loss: 2.268980 [19200/60000]
loss: 2.262558 [25600/60000]
loss: 2.231798 [32000/60000]
loss: 2.238440 [38400/60000]
loss: 2.211598 [44800/60000]
loss: 2.201264 [51200/60000]
loss: 2.174402 [57600/60000]
Test Error:
 Accuracy: 49.6%, Avg loss: 2.177515

Epoch 2
-------------------------
loss: 2.183485 [    0/60000]
loss: 2.179688 [ 6400/60000]
loss: 2.128562 [12800/60000]
loss: 2.143096 [19200/60000]
loss: 2.097443 [25600/60000]
loss: 2.041913 [32000/60000]
loss: 2.067561 [38400/60000]
loss: 1.999742 [44800/60000]
loss: 1.996043 [51200/60000]
loss: 1.932078 [57600/60000]
Test Error:
 Accuracy: 57.9%, Avg loss: 1.938806

Epoch 3
-------------------------
loss: 1.960048 [    0/60000]
loss: 1.941074 [ 6400/60000]
loss: 1.836948 [12800/60000]
loss: 1.876237 [19200/60000]
loss: 1.759434 [25600/60000]
loss: 1.709922 [32000/60000]
loss: 1.725742 [38400/60000]
loss: 1.631767 [44800/60000]
loss: 1.650623 [51200/60000]
loss: 1.542952 [57600/60000]
Test Error:
 Accuracy: 61.7%, Avg loss: 1.568830

Epoch 4
-------------------------
loss: 1.623320 [    0/60000]
loss: 1.597432 [ 6400/60000]
loss: 1.452819 [12800/60000]
loss: 1.523726 [19200/60000]
loss: 1.387065 [25600/60000]
loss: 1.385900 [32000/60000]
loss: 1.390101 [38400/60000]
loss: 1.318126 [44800/60000]
loss: 1.354831 [51200/60000]
loss: 1.246836 [57600/60000]
Test Error:
 Accuracy: 63.1%, Avg loss: 1.280337

Epoch 5
-------------------------
loss: 1.350684 [    0/60000]
loss: 1.340373 [ 6400/60000]
loss: 1.175442 [12800/60000]
loss: 1.281742 [19200/60000]
loss: 1.140790 [25600/60000]
loss: 1.172535 [32000/60000]
loss: 1.184256 [38400/60000]
loss: 1.123956 [44800/60000]
loss: 1.169118 [51200/60000]
loss: 1.078584 [57600/60000]
Test Error:
 Accuracy: 64.6%, Avg loss: 1.103163

TensorDict training done! time:  8.6454 s
Epoch 1
-------------------------
loss: 2.302963 [    0/60000]
loss: 2.286235 [ 6400/60000]
loss: 2.269734 [12800/60000]
loss: 2.267132 [19200/60000]
loss: 2.244352 [25600/60000]
loss: 2.218158 [32000/60000]
loss: 2.233029 [38400/60000]
loss: 2.191653 [44800/60000]
loss: 2.184664 [51200/60000]
loss: 2.168697 [57600/60000]
Test Error:
 Accuracy: 44.3%, Avg loss: 2.153297

Epoch 2
-------------------------
loss: 2.160024 [    0/60000]
loss: 2.150428 [ 6400/60000]
loss: 2.092832 [12800/60000]
loss: 2.111400 [19200/60000]
loss: 2.054063 [25600/60000]
loss: 1.997418 [32000/60000]
loss: 2.030121 [38400/60000]
loss: 1.941688 [44800/60000]
loss: 1.943831 [51200/60000]
loss: 1.891106 [57600/60000]
Test Error:
 Accuracy: 56.6%, Avg loss: 1.877111

Epoch 3
-------------------------
loss: 1.907028 [    0/60000]
loss: 1.880447 [ 6400/60000]
loss: 1.757670 [12800/60000]
loss: 1.801385 [19200/60000]
loss: 1.684063 [25600/60000]
loss: 1.643999 [32000/60000]
loss: 1.669101 [38400/60000]
loss: 1.563666 [44800/60000]
loss: 1.585065 [51200/60000]
loss: 1.494291 [57600/60000]
Test Error:
 Accuracy: 59.8%, Avg loss: 1.505017

Epoch 4
-------------------------
loss: 1.571528 [    0/60000]
loss: 1.537744 [ 6400/60000]
loss: 1.384650 [12800/60000]
loss: 1.455852 [19200/60000]
loss: 1.332916 [25600/60000]
loss: 1.336260 [32000/60000]
loss: 1.349962 [38400/60000]
loss: 1.272937 [44800/60000]
loss: 1.304907 [51200/60000]
loss: 1.213199 [57600/60000]
Test Error:
 Accuracy: 62.7%, Avg loss: 1.238966

Epoch 5
-------------------------
loss: 1.320472 [    0/60000]
loss: 1.296768 [ 6400/60000]
loss: 1.133535 [12800/60000]
loss: 1.232609 [19200/60000]
loss: 1.109402 [25600/60000]
loss: 1.138454 [32000/60000]
loss: 1.158809 [38400/60000]
loss: 1.096638 [44800/60000]
loss: 1.132007 [51200/60000]
loss: 1.054023 [57600/60000]
Test Error:
 Accuracy: 64.3%, Avg loss: 1.076223

Training done! time:  33.5931 s
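The timings above are wall-clock differences measured with `time.time()`, and the TensorDict run does not include the one-off cost of filling the memory-mapped tensors, which was paid earlier. For benchmarking, `time.perf_counter()` is usually preferable because it is monotonic and has the highest available resolution. A minimal sketch of a hypothetical `timed` context manager built on it:

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label, results):
    """Hypothetical helper: store the elapsed seconds of the enclosed block in results."""
    start = time.perf_counter()
    try:
        yield
    finally:
        results[label] = time.perf_counter() - start


timings = {}
with timed("sum", timings):
    total = sum(range(1_000_000))
print(f"sum took {timings['sum']:.4f} s")
```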

Total running time of the script: (0 minutes 51.615 seconds)
