Using TensorDict for datasets¶
In this tutorial we demonstrate how TensorDict can be used to efficiently and transparently load and manage data inside a training pipeline. The tutorial is based heavily on the PyTorch Quickstart Tutorial, but modified to demonstrate use of TensorDict.
import torch
import torch.nn as nn
from tensordict import MemoryMappedTensor, TensorDict
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
Using device: cpu
The torchvision.datasets module contains a number of convenient pre-prepared datasets. In this tutorial we’ll use the relatively simple FashionMNIST dataset. Each image is an item of clothing, and the objective is to classify the type of clothing in the image (e.g. “Bag”, “Sneaker” etc.).
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
We will create two tensordicts, one each for the training and test data. We create memory-mapped tensors to hold the data. This will allow us to efficiently load batches of transformed data from disk rather than repeatedly loading and transforming individual images.
First we create the MemoryMappedTensor containers.
training_data_td = TensorDict(
    {
        "images": MemoryMappedTensor.empty(
            (len(training_data), *training_data[0][0].squeeze().shape),
            dtype=torch.float32,
        ),
        "targets": MemoryMappedTensor.empty((len(training_data),), dtype=torch.int64),
    },
    batch_size=[len(training_data)],
    device=device,
)
test_data_td = TensorDict(
    {
        "images": MemoryMappedTensor.empty(
            (len(test_data), *test_data[0][0].squeeze().shape), dtype=torch.float32
        ),
        "targets": MemoryMappedTensor.empty((len(test_data),), dtype=torch.int64),
    },
    batch_size=[len(test_data)],
    device=device,
)
Then we can iterate over the data to populate the memory-mapped tensors. This takes a bit of time, but performing the transforms up-front will save repeated effort during training later.
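The population code itself is not shown above; a minimal sketch (assuming the training_data_td and test_data_td containers defined earlier, and not necessarily identical to the original tutorial’s code) is to copy each transformed sample into the corresponding row:

# Copy each transformed sample into its row of the memory-mapped tensors.
# ToTensor() yields images of shape (1, 28, 28), so we squeeze the channel
# dimension to match the (28, 28) layout of the "images" entry.
for i, (img, target) in enumerate(training_data):
    training_data_td["images"][i] = img.squeeze()
    training_data_td["targets"][i] = target

for i, (img, target) in enumerate(test_data):
    test_data_td["images"][i] = img.squeeze()
    test_data_td["targets"][i] = target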
DataLoaders¶
We’ll create DataLoaders from the torchvision-provided Datasets, as well as from our memory-mapped TensorDicts.

Since TensorDict implements __len__ and __getitem__ (and also __getitems__) we can use it like a map-style Dataset and create a DataLoader directly from it. Note that because TensorDict can already handle batched indices, there is no need for collation, so we pass the identity function as collate_fn.
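To see why no collation is needed, note that indexing a TensorDict with a batch of indices returns another TensorDict whose entries are already stacked along the batch dimension. A quick check (a sketch, assuming the containers above have been populated):

# Slicing a TensorDict yields a sub-TensorDict with the sliced batch size,
# so a batch of indices already comes back as batched tensors.
batch_td = training_data_td[:4]
print(batch_td.batch_size)        # torch.Size([4])
print(batch_td["images"].shape)   # torch.Size([4, 28, 28])
print(batch_td["targets"].shape)  # torch.Size([4])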
batch_size = 64
train_dataloader = DataLoader(training_data, batch_size=batch_size) # noqa: TOR401
test_dataloader = DataLoader(test_data, batch_size=batch_size) # noqa: TOR401
train_dataloader_td = DataLoader(  # noqa: TOR401
    training_data_td, batch_size=batch_size, collate_fn=lambda x: x
)
test_dataloader_td = DataLoader(  # noqa: TOR401
    test_data_td, batch_size=batch_size, collate_fn=lambda x: x
)
Model¶
We use the same model from the Quickstart Tutorial.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
model = Net().to(device)
model_td = Net().to(device)
model, model_td
(Net(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=10, bias=True)
)
), Net(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=10, bias=True)
)
))
Optimizing the parameters¶
We’ll optimise the parameters of the model using stochastic gradient descent and cross-entropy loss.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer_td = torch.optim.SGD(model_td.parameters(), lr=1e-3)
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
The training loop for our TensorDict-based DataLoader is very similar; we just adjust how we unpack the data, using the more explicit key-based retrieval offered by TensorDict. The .contiguous() method loads the data stored in the memmap tensor.
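As a quick sanity check (a sketch, assuming the containers and model defined above), we can pull a small slice from the memory-mapped data, materialise it with .contiguous(), and run it through the model:

# Materialise a small batch from the memory-mapped container and run a forward pass.
sample = training_data_td[:4]
images, targets = sample["images"].contiguous(), sample["targets"].contiguous()
print(model_td(images).shape)  # torch.Size([4, 10]) -- one logit per class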
def train_td(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, data in enumerate(dataloader):
        X, y = data["images"].contiguous(), data["targets"].contiguous()

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )
def test_td(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X, y = batch["images"].contiguous(), batch["targets"].contiguous()
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )
for d in train_dataloader_td:
    print(d)
    break
import time

t0 = time.time()

epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train_td(train_dataloader_td, model_td, loss_fn, optimizer_td)
    test_td(test_dataloader_td, model_td, loss_fn)
print(f"TensorDict training done! time: {time.time() - t0: 4.4f} s")
t0 = time.time()

epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print(f"Training done! time: {time.time() - t0: 4.4f} s")
TensorDict(
fields={
images: Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
targets: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([64]),
device=cpu,
is_shared=False)
Epoch 1
-------------------------
loss: 2.293629 [ 0/60000]
loss: 2.283033 [ 6400/60000]
loss: 2.256484 [12800/60000]
loss: 2.256073 [19200/60000]
loss: 2.243077 [25600/60000]
loss: 2.200774 [32000/60000]
loss: 2.217680 [38400/60000]
loss: 2.179253 [44800/60000]
loss: 2.169554 [51200/60000]
loss: 2.139296 [57600/60000]
Test Error:
Accuracy: 44.9%, Avg loss: 2.136127
Epoch 2
-------------------------
loss: 2.144999 [ 0/60000]
loss: 2.140007 [ 6400/60000]
loss: 2.071116 [12800/60000]
loss: 2.092876 [19200/60000]
loss: 2.046267 [25600/60000]
loss: 1.972650 [32000/60000]
loss: 2.009594 [38400/60000]
loss: 1.925318 [44800/60000]
loss: 1.919691 [51200/60000]
loss: 1.849112 [57600/60000]
Test Error:
Accuracy: 59.9%, Avg loss: 1.853141
Epoch 3
-------------------------
loss: 1.884452 [ 0/60000]
loss: 1.857988 [ 6400/60000]
loss: 1.732184 [12800/60000]
loss: 1.780074 [19200/60000]
loss: 1.670048 [25600/60000]
loss: 1.618595 [32000/60000]
loss: 1.644744 [38400/60000]
loss: 1.546432 [44800/60000]
loss: 1.562902 [51200/60000]
loss: 1.465493 [57600/60000]
Test Error:
Accuracy: 61.8%, Avg loss: 1.488911
Epoch 4
-------------------------
loss: 1.553364 [ 0/60000]
loss: 1.522695 [ 6400/60000]
loss: 1.372123 [12800/60000]
loss: 1.449625 [19200/60000]
loss: 1.332421 [25600/60000]
loss: 1.329863 [32000/60000]
loss: 1.347445 [38400/60000]
loss: 1.271675 [44800/60000]
loss: 1.303136 [51200/60000]
loss: 1.211887 [57600/60000]
Test Error:
Accuracy: 63.5%, Avg loss: 1.239002
Epoch 5
-------------------------
loss: 1.314920 [ 0/60000]
loss: 1.296802 [ 6400/60000]
loss: 1.132275 [12800/60000]
loss: 1.240491 [19200/60000]
loss: 1.118523 [25600/60000]
loss: 1.144619 [32000/60000]
loss: 1.168913 [38400/60000]
loss: 1.102033 [44800/60000]
loss: 1.141350 [51200/60000]
loss: 1.064545 [57600/60000]
Test Error:
Accuracy: 64.8%, Avg loss: 1.083915
TensorDict training done! time: 8.6170 s
Epoch 1
-------------------------
loss: 2.290858 [ 0/60000]
loss: 2.286275 [ 6400/60000]
loss: 2.267298 [12800/60000]
loss: 2.265695 [19200/60000]
loss: 2.254766 [25600/60000]
loss: 2.215138 [32000/60000]
loss: 2.229925 [38400/60000]
loss: 2.191596 [44800/60000]
loss: 2.188072 [51200/60000]
loss: 2.166435 [57600/60000]
Test Error:
Accuracy: 43.9%, Avg loss: 2.155688
Epoch 2
-------------------------
loss: 2.158259 [ 0/60000]
loss: 2.155049 [ 6400/60000]
loss: 2.093325 [12800/60000]
loss: 2.112539 [19200/60000]
loss: 2.072581 [25600/60000]
loss: 1.997474 [32000/60000]
loss: 2.036950 [38400/60000]
loss: 1.950100 [44800/60000]
loss: 1.957567 [51200/60000]
loss: 1.895268 [57600/60000]
Test Error:
Accuracy: 55.7%, Avg loss: 1.885085
Epoch 3
-------------------------
loss: 1.911978 [ 0/60000]
loss: 1.891730 [ 6400/60000]
loss: 1.763546 [12800/60000]
loss: 1.807613 [19200/60000]
loss: 1.707991 [25600/60000]
loss: 1.642813 [32000/60000]
loss: 1.683022 [38400/60000]
loss: 1.573177 [44800/60000]
loss: 1.606999 [51200/60000]
loss: 1.511539 [57600/60000]
Test Error:
Accuracy: 61.3%, Avg loss: 1.516370
Epoch 4
-------------------------
loss: 1.578484 [ 0/60000]
loss: 1.552396 [ 6400/60000]
loss: 1.389732 [12800/60000]
loss: 1.469645 [19200/60000]
loss: 1.355540 [25600/60000]
loss: 1.340846 [32000/60000]
loss: 1.369824 [38400/60000]
loss: 1.286581 [44800/60000]
loss: 1.330394 [51200/60000]
loss: 1.239374 [57600/60000]
Test Error:
Accuracy: 63.4%, Avg loss: 1.253140
Epoch 5
-------------------------
loss: 1.325610 [ 0/60000]
loss: 1.314812 [ 6400/60000]
loss: 1.139354 [12800/60000]
loss: 1.251529 [19200/60000]
loss: 1.128213 [25600/60000]
loss: 1.147657 [32000/60000]
loss: 1.176389 [38400/60000]
loss: 1.109371 [44800/60000]
loss: 1.156932 [51200/60000]
loss: 1.078585 [57600/60000]
Test Error:
Accuracy: 64.4%, Avg loss: 1.089056
Training done! time: 34.5017 s
Total running time of the script: (0 minutes 55.621 seconds)