Using tensorclasses for datasets

In this tutorial we demonstrate how tensorclasses can be used to load and manage data efficiently and transparently inside a training pipeline. The tutorial is based heavily on the PyTorch Quickstart Tutorial, modified to demonstrate the use of tensorclass. See also the related tutorial using TensorDict.

import torch
import torch.nn as nn

from tensordict import MemoryMappedTensor, tensorclass
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
Using device: cpu

The torchvision.datasets module contains a number of convenient pre-prepared datasets. In this tutorial we’ll use the relatively simple FashionMNIST dataset. Each image is an item of clothing, and the objective is to classify the type of clothing in the image (e.g. “Bag”, “Sneaker”).

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

Tensorclasses are dataclasses that expose dedicated tensor methods over their contents, much like TensorDict. They are a good choice when the structure of the data you want to store is fixed and predictable.

As well as specifying the contents, we can also encapsulate related logic as custom methods when defining the class. In this case we’ll write a from_dataset classmethod that takes a dataset as input and creates a tensorclass containing the data from the dataset. We create memory-mapped tensors to hold the data. This will allow us to efficiently load batches of transformed data from disk rather than repeatedly load and transform individual images.

@tensorclass
class FashionMNISTData:
    images: torch.Tensor
    targets: torch.Tensor

    @classmethod
    def from_dataset(cls, dataset, device=None):
        data = cls(
            images=MemoryMappedTensor.empty(
                (len(dataset), *dataset[0][0].squeeze().shape), dtype=torch.float32
            ),
            targets=MemoryMappedTensor.empty((len(dataset),), dtype=torch.int64),
            batch_size=[len(dataset)],
            device=device,
        )
        for i, (image, target) in enumerate(dataset):
            data[i] = cls(images=image, targets=torch.tensor(target), batch_size=[])
        return data

We will create two tensorclasses, one each for the training and test data. Note that we incur some overhead here as we are looping over the entire dataset, transforming and saving to disk.

training_data_tc = FashionMNISTData.from_dataset(training_data, device=device)
test_data_tc = FashionMNISTData.from_dataset(test_data, device=device)

DataLoaders

We’ll create DataLoaders from the torchvision-provided Datasets, as well as from our memory-mapped tensorclasses.

Since a tensorclass, like TensorDict, implements __len__ and __getitem__ (and also __getitems__), we can use it like a map-style Dataset and create a DataLoader directly from it. Note that because it can already handle batched indices, there is no need for collation, so we pass the identity function as collate_fn.

batch_size = 64

train_dataloader = DataLoader(training_data, batch_size=batch_size)  # noqa: TOR401
test_dataloader = DataLoader(test_data, batch_size=batch_size)  # noqa: TOR401

train_dataloader_tc = DataLoader(  # noqa: TOR401
    training_data_tc, batch_size=batch_size, collate_fn=lambda x: x
)
test_dataloader_tc = DataLoader(  # noqa: TOR401
    test_data_tc, batch_size=batch_size, collate_fn=lambda x: x
)
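The role of __getitems__ and the identity collate_fn can be illustrated without tensordict. The sketch below uses a hypothetical `BatchedToyData` class (not part of the tutorial) and assumes a PyTorch version whose DataLoader calls __getitems__ when the dataset defines it (PyTorch 2.0 or later, to the best of our knowledge). The dataset returns an already-batched tensor, so the DataLoader has nothing left to collate.

```python
import torch
from torch.utils.data import DataLoader


class BatchedToyData:
    """Hypothetical map-style dataset that can serve a whole batch at once."""

    def __init__(self, n):
        # n samples with one feature each, shape (n, 1)
        self.x = torch.arange(n, dtype=torch.float32).unsqueeze(-1)

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, index):
        return self.x[index]

    def __getitems__(self, indices):
        # Receives the list of indices for one batch and returns
        # a single batched tensor instead of a list of samples.
        return self.x[indices]


# The identity collate_fn passes the batched tensor through unchanged.
loader = DataLoader(BatchedToyData(10), batch_size=4, collate_fn=lambda batch: batch)
batches = list(loader)
print([tuple(b.shape) for b in batches])
```

With the default collate_fn, the DataLoader would instead try to stack the returned value as if it were a list of individual samples, which is unnecessary work when the dataset already returns a batch.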

Model

We use the same model from the Quickstart Tutorial.

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = Net().to(device)
model_tc = Net().to(device)
model, model_tc
(Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
), Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
))

Optimizing the parameters

We’ll optimize the parameters of the model using stochastic gradient descent and cross-entropy loss.

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer_tc = torch.optim.SGD(model_tc.parameters(), lr=1e-3)


def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
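As a side note on the loss used in the loop above, the following sketch checks on made-up logits that nn.CrossEntropyLoss, with its default mean reduction, matches the negative log-softmax probability of the true class averaged over the batch. The logits and labels are invented for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical logits for 2 samples over 3 classes, and their true labels.
logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
labels = torch.tensor([0, 2])

# Manual computation: log-softmax, pick the true-class entry, average.
log_probs = torch.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(labels)), labels].mean()

# Built-in loss, as used in the training loop.
builtin = nn.CrossEntropyLoss()(logits, labels)

print(manual.item(), builtin.item())
```

This is why the model's forward method returns raw logits: the softmax is applied inside the loss.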

The training loop for our tensorclass-based DataLoader closely mirrors the train function: we only change how the data is unpacked, using the more explicit attribute-based retrieval offered by the tensorclass. The .contiguous() method loads the data stored in the memmap tensor.

def train_tc(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()

    for batch, data in enumerate(dataloader):
        X, y = data.images.contiguous(), data.targets.contiguous()

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")


def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )


def test_tc(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X, y = batch.images.contiguous(), batch.targets.contiguous()

            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )


for d in train_dataloader_tc:
    print(d)
    break

import time

t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train_tc(train_dataloader_tc, model_tc, loss_fn, optimizer_tc)
    test_tc(test_dataloader_tc, model_tc, loss_fn)
print(f"Tensorclass training done! time: {time.time() - t0: 4.4f} s")

t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print(f"Training done! time: {time.time() - t0: 4.4f} s")
FashionMNISTData(
    images=Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
    targets=Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([64]),
    device=cpu,
    is_shared=False)
Epoch 1
-------------------------
loss: 2.306911 [    0/60000]
loss: 2.290546 [ 6400/60000]
loss: 2.267823 [12800/60000]
loss: 2.268296 [19200/60000]
loss: 2.239312 [25600/60000]
loss: 2.222285 [32000/60000]
loss: 2.223311 [38400/60000]
loss: 2.189771 [44800/60000]
loss: 2.196405 [51200/60000]
loss: 2.161141 [57600/60000]
Test Error:
 Accuracy: 48.7%, Avg loss: 2.154104

Epoch 2
-------------------------
loss: 2.168520 [    0/60000]
loss: 2.153506 [ 6400/60000]
loss: 2.092801 [12800/60000]
loss: 2.109224 [19200/60000]
loss: 2.055430 [25600/60000]
loss: 2.006157 [32000/60000]
loss: 2.025786 [38400/60000]
loss: 1.945255 [44800/60000]
loss: 1.956291 [51200/60000]
loss: 1.879478 [57600/60000]
Test Error:
 Accuracy: 59.4%, Avg loss: 1.879949

Epoch 3
-------------------------
loss: 1.919434 [    0/60000]
loss: 1.884455 [ 6400/60000]
loss: 1.763246 [12800/60000]
loss: 1.796725 [19200/60000]
loss: 1.699208 [25600/60000]
loss: 1.656230 [32000/60000]
loss: 1.662529 [38400/60000]
loss: 1.567700 [44800/60000]
loss: 1.596351 [51200/60000]
loss: 1.477721 [57600/60000]
Test Error:
 Accuracy: 61.8%, Avg loss: 1.509531

Epoch 4
-------------------------
loss: 1.582884 [    0/60000]
loss: 1.547020 [ 6400/60000]
loss: 1.392604 [12800/60000]
loss: 1.457310 [19200/60000]
loss: 1.349238 [25600/60000]
loss: 1.348273 [32000/60000]
loss: 1.347054 [38400/60000]
loss: 1.278697 [44800/60000]
loss: 1.319937 [51200/60000]
loss: 1.205683 [57600/60000]
Test Error:
 Accuracy: 63.1%, Avg loss: 1.246076

Epoch 5
-------------------------
loss: 1.329433 [    0/60000]
loss: 1.309014 [ 6400/60000]
loss: 1.141340 [12800/60000]
loss: 1.241332 [19200/60000]
loss: 1.118123 [25600/60000]
loss: 1.152766 [32000/60000]
loss: 1.158768 [38400/60000]
loss: 1.103974 [44800/60000]
loss: 1.149251 [51200/60000]
loss: 1.051577 [57600/60000]
Test Error:
 Accuracy: 64.4%, Avg loss: 1.084601

Tensorclass training done! time:  8.6816 s
Epoch 1
-------------------------
loss: 2.307089 [    0/60000]
loss: 2.297157 [ 6400/60000]
loss: 2.284588 [12800/60000]
loss: 2.276901 [19200/60000]
loss: 2.242071 [25600/60000]
loss: 2.219066 [32000/60000]
loss: 2.220341 [38400/60000]
loss: 2.188907 [44800/60000]
loss: 2.183793 [51200/60000]
loss: 2.147927 [57600/60000]
Test Error:
 Accuracy: 43.2%, Avg loss: 2.146838

Epoch 2
-------------------------
loss: 2.158795 [    0/60000]
loss: 2.147090 [ 6400/60000]
loss: 2.094091 [12800/60000]
loss: 2.108519 [19200/60000]
loss: 2.044007 [25600/60000]
loss: 1.981364 [32000/60000]
loss: 2.003546 [38400/60000]
loss: 1.929537 [44800/60000]
loss: 1.926183 [51200/60000]
loss: 1.847433 [57600/60000]
Test Error:
 Accuracy: 57.0%, Avg loss: 1.857306

Epoch 3
-------------------------
loss: 1.891647 [    0/60000]
loss: 1.856303 [ 6400/60000]
loss: 1.746082 [12800/60000]
loss: 1.786176 [19200/60000]
loss: 1.664370 [25600/60000]
loss: 1.619249 [32000/60000]
loss: 1.633839 [38400/60000]
loss: 1.550666 [44800/60000]
loss: 1.565413 [51200/60000]
loss: 1.457971 [57600/60000]
Test Error:
 Accuracy: 62.1%, Avg loss: 1.485636

Epoch 4
-------------------------
loss: 1.555910 [    0/60000]
loss: 1.517564 [ 6400/60000]
loss: 1.375235 [12800/60000]
loss: 1.446436 [19200/60000]
loss: 1.323968 [25600/60000]
loss: 1.320687 [32000/60000]
loss: 1.333024 [38400/60000]
loss: 1.270114 [44800/60000]
loss: 1.298299 [51200/60000]
loss: 1.202938 [57600/60000]
Test Error:
 Accuracy: 63.9%, Avg loss: 1.229178

Epoch 5
-------------------------
loss: 1.310207 [    0/60000]
loss: 1.290053 [ 6400/60000]
loss: 1.126103 [12800/60000]
loss: 1.231878 [19200/60000]
loss: 1.110180 [25600/60000]
loss: 1.128504 [32000/60000]
loss: 1.154162 [38400/60000]
loss: 1.097552 [44800/60000]
loss: 1.130474 [51200/60000]
loss: 1.054337 [57600/60000]
Test Error:
 Accuracy: 65.1%, Avg loss: 1.072044

Training done! time:  34.6215 s
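The accuracy figures in the logs above come from the expression (pred.argmax(1) == y).type(torch.float).sum().item() in test and test_tc, accumulated over batches and divided by the dataset size. A minimal sketch on invented logits for a single batch:

```python
import torch

# Hypothetical logits for 4 samples over 3 classes, with their true labels.
pred = torch.tensor(
    [
        [2.0, 0.1, 0.3],  # predicted class 0
        [0.2, 1.5, 0.1],  # predicted class 1
        [0.3, 0.2, 0.9],  # predicted class 2
        [1.1, 0.4, 0.2],  # predicted class 0
    ]
)
y = torch.tensor([0, 1, 1, 0])

# Compare the highest-scoring class with the label and count the matches.
correct = (pred.argmax(1) == y).type(torch.float).sum().item()
accuracy = correct / len(y)
print(f"Accuracy: {100 * accuracy:>0.1f}%")
```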

Total running time of the script: (1 minutes 1.114 seconds)
