Using tensorclasses for datasets
In this tutorial we demonstrate how tensorclasses can be used to efficiently and transparently load and manage data inside a training pipeline. The tutorial is based heavily on the PyTorch Quickstart Tutorial, but modified to demonstrate the use of tensorclass. See the related tutorial using TensorDict.
import torch
import torch.nn as nn
from tensordict import MemoryMappedTensor, tensorclass
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
Using device: cpu
The torchvision.datasets module contains a number of convenient pre-prepared datasets. In this tutorial we’ll use the relatively simple FashionMNIST dataset. Each image shows an item of clothing, and the objective is to classify the type of clothing in the image (e.g. “Bag”, “Sneaker”, etc.).
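The targets in FashionMNIST are integers 0 to 9. The small sketch below maps them back to readable names; the list follows the order of the `classes` attribute that torchvision's `FashionMNIST` dataset exposes (so `training_data.classes` should give the same list), while `FASHION_MNIST_CLASSES` and `target_to_name` are hypothetical helper names used only for illustration.

```python
# Index -> name mapping for the ten FashionMNIST targets, assumed to match the
# order of torchvision.datasets.FashionMNIST.classes (hypothetical helper names)
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

def target_to_name(target: int) -> str:
    """Translate an integer target (0-9) into a human-readable class name."""
    return FASHION_MNIST_CLASSES[target]

print(target_to_name(8), target_to_name(7))
```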
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
Tensorclasses are dataclasses that, much like TensorDict, expose dedicated tensor methods over their contents. They are a good choice when the structure of the data you want to store is fixed and predictable.
As well as specifying the contents, we can also encapsulate related logic as custom methods when defining the class. In this case we’ll write a from_dataset classmethod that takes a dataset as input and creates a tensorclass containing the data from the dataset. We create memory-mapped tensors to hold the data. This allows us to efficiently load batches of already-transformed data from disk, rather than repeatedly loading and transforming individual images.
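The transform-once, read-many pattern described above can be illustrated without tensordict at all. The following is only a standard-library analogy, not how `MemoryMappedTensor` is implemented: every hypothetical raw "image" is transformed once, written to a file as float32, and batches are later read back through `mmap` without calling the transform again. The tiny dataset, file name and helper functions are all made up for the sketch.

```python
import array
import mmap
import os
import tempfile

N_ITEMS, N_PIXELS = 8, 4  # hypothetical tiny dataset: 8 "images" of 4 pixels each
ITEMSIZE = array.array("f").itemsize  # bytes per float32 value

def transform(raw_pixels):
    """Stand-in for ToTensor(): scale uint8 pixel values into [0, 1]."""
    return [p / 255.0 for p in raw_pixels]

def preprocess_to_file(raw_items, path):
    """One-off pass: transform every item once and store the result on disk as float32."""
    with open(path, "wb") as f:
        for item in raw_items:
            array.array("f", transform(item)).tofile(f)

def read_batch(path, start, stop):
    """Read items [start, stop) through a memory map, without re-running transform()."""
    row_bytes = N_PIXELS * ITEMSIZE
    with open(path, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        values = array.array("f")
        values.frombytes(mm[start * row_bytes : stop * row_bytes])
    return [values[i * N_PIXELS : (i + 1) * N_PIXELS].tolist() for i in range(stop - start)]

raw_items = [[(i * N_PIXELS + j) % 256 for j in range(N_PIXELS)] for i in range(N_ITEMS)]
path = os.path.join(tempfile.mkdtemp(), "images.f32")
preprocess_to_file(raw_items, path)
batch = read_batch(path, 2, 5)  # items 2, 3 and 4, read straight from the file
```

The cost of `transform` is paid once in `preprocess_to_file`; every later epoch only pays for reading contiguous bytes, which is the trade-off the tensorclass version makes as well.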
@tensorclass
class FashionMNISTData:
    images: torch.Tensor
    targets: torch.Tensor
    @classmethod
    def from_dataset(cls, dataset, device=None):
        # Pre-allocate memory-mapped storage for all images (squeezed to 28x28) and targets
        data = cls(
            images=MemoryMappedTensor.empty(
                (len(dataset), *dataset[0][0].squeeze().shape), dtype=torch.float32
            ),
            targets=MemoryMappedTensor.empty((len(dataset),), dtype=torch.int64),
            batch_size=[len(dataset)],
            device=device,
        )
        # Transform each item once and write it into its slot of the storage
        for i, (image, target) in enumerate(dataset):
            data[i] = cls(images=image, targets=torch.tensor(target), batch_size=[])
        return data
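A detail worth spelling out: ToTensor() returns each greyscale image with shape (1, 28, 28), whereas from_dataset sizes the image storage from the squeezed shape, so the stored batch is (N, 28, 28), as the printed batch later in this tutorial confirms. The plain-PyTorch sketch below, using random stand-in images, shows that shape bookkeeping; it relies on ordinary tensor indexed assignment dropping the extra leading size-1 dimension, and does not claim to reproduce exactly what tensordict does inside `data[i] = ...`.

```python
import torch

n_items = 4
# Stand-in for ToTensor() output on 28x28 greyscale images: shape (1, 28, 28) each
images = [torch.rand(1, 28, 28) for _ in range(n_items)]

# Size the buffer from the squeezed shape, as from_dataset does: (n_items, 28, 28)
buffer = torch.empty((n_items, *images[0].squeeze().shape), dtype=torch.float32)

for i, image in enumerate(images):
    # Indexed assignment accepts the (1, 28, 28) source for a (28, 28) slot
    # by dropping its leading size-1 dimension
    buffer[i] = image

print(buffer.shape)
```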
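We will create two tensorclasses, one each for the training and test data. Note that we incur some overhead here, as we loop over the entire dataset once, transforming each item and saving it to disk.
training_data_tc = FashionMNISTData.from_dataset(training_data, device=device)
test_data_tc = FashionMNISTData.from_dataset(test_data, device=device)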
DataLoaders
We’ll create DataLoaders from the torchvision-provided Datasets, as well as from our memory-mapped tensorclasses.
Since a tensorclass, like TensorDict, implements __len__ and __getitem__ (and also __getitems__), we can use it like a map-style Dataset and create a DataLoader directly from it. Note that because a tensorclass can already handle batched indices, there is no need for collation, so we pass the identity function as collate_fn.
batch_size = 64
train_dataloader = DataLoader(training_data, batch_size=batch_size) # noqa: TOR401
test_dataloader = DataLoader(test_data, batch_size=batch_size) # noqa: TOR401
train_dataloader_tc = DataLoader( # noqa: TOR401
training_data_tc, batch_size=batch_size, collate_fn=lambda x: x
)
test_dataloader_tc = DataLoader( # noqa: TOR401
test_data_tc, batch_size=batch_size, collate_fn=lambda x: x
)
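To see why an identity collate_fn is enough, here is a minimal sketch with a hypothetical toy dataset (`ToyBatchedDataset` is made up for illustration, not part of tensordict). In recent PyTorch releases, a map-style dataset that defines `__getitems__` receives the whole list of batch indices in one call, and whatever it returns is handed straight to collate_fn; with `lambda x: x` the batch object comes out untouched.

```python
import torch
from torch.utils.data import DataLoader

class ToyBatchedDataset:
    """Hypothetical map-style dataset that can serve a whole batch in one call."""

    def __init__(self, n):
        self.images = torch.arange(n * 4, dtype=torch.float32).reshape(n, 4)
        self.targets = torch.arange(n)
        self.calls = 0  # counts how many times __getitems__ was used

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        # Works for a single int or for a list of indices (advanced indexing)
        return {"images": self.images[idx], "targets": self.targets[idx]}

    def __getitems__(self, indices):
        # Called by the DataLoader with the full list of indices for one batch
        self.calls += 1
        return self[indices]

ds = ToyBatchedDataset(10)
loader = DataLoader(ds, batch_size=4, collate_fn=lambda x: x)
batches = list(loader)  # batch sizes 4, 4, 2 (drop_last defaults to False)
print([b["targets"].tolist() for b in batches])
```

Each batch costs one indexing call instead of `batch_size` separate `__getitem__` calls followed by stacking, which is the same shortcut the tensorclass DataLoaders above take.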
Model
We use the same model from the Quickstart Tutorial.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )
    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
model = Net().to(device)
model_tc = Net().to(device)
model, model_tc
(Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
), Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
))
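The same model can serve both DataLoaders because nn.Flatten (default start_dim=1) keeps the batch dimension and flattens everything else: default collation of ToTensor() images gives batches of shape (64, 1, 28, 28), while the tensorclass stores squeezed images and yields (64, 28, 28); both become (64, 784). The sketch below checks this with dummy zero tensors and also counts the parameters of a layer stack identical to `Net.linear_relu_stack`, once via `numel()` and once by hand.

```python
import torch
import torch.nn as nn

flatten = nn.Flatten()  # start_dim=1: keep the batch dimension, flatten the rest
x_torchvision = torch.zeros(64, 1, 28, 28)  # dummy batch as produced by default collation
x_tensorclass = torch.zeros(64, 28, 28)  # dummy batch of squeezed tensorclass images
flat_a = flatten(x_torchvision)
flat_b = flatten(x_tensorclass)

# Same layer stack as Net.linear_relu_stack
stack = nn.Sequential(
    nn.Linear(28 * 28, 512), nn.ReLU(),
    nn.Linear(512, 512), nn.ReLU(),
    nn.Linear(512, 10),
)
n_params = sum(p.numel() for p in stack.parameters())
# By hand: weights plus biases of each Linear layer
n_params_by_hand = (784 * 512 + 512) + (512 * 512 + 512) + (512 * 10 + 10)
print(flat_a.shape, flat_b.shape, n_params)  # torch.Size([64, 784]) torch.Size([64, 784]) 669706
```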
Optimizing the parameters
We’ll optimise the parameters of the model using stochastic gradient descent and cross-entropy loss.
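For reference, the quantities being optimised can be written out. For a batch of $B$ images with logits $z^{(b)} \in \mathbb{R}^{10}$ and integer targets $y^{(b)}$, nn.CrossEntropyLoss with its default mean reduction computes $L$ below, and each optimizer.step() of plain SGD (torch.optim.SGD defaults to no momentum) applies the update with learning rate $\eta = 10^{-3}$:

```latex
\ell\big(z^{(b)}, y^{(b)}\big) = -\log \frac{\exp\big(z^{(b)}_{y^{(b)}}\big)}{\sum_{j=0}^{9} \exp\big(z^{(b)}_{j}\big)},
\qquad
L = \frac{1}{B} \sum_{b=1}^{B} \ell\big(z^{(b)}, y^{(b)}\big),
\qquad
\theta \leftarrow \theta - \eta \, \nabla_{\theta} L
```

As a rough sanity check: if the freshly initialised network produces nearly uniform logits, each term is close to $\log 10 \approx 2.3026$, which is in line with the first logged losses of about 2.307.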
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer_tc = torch.optim.SGD(model_tc.parameters(), lr=1e-3)
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
The training loop for our tensorclass-based DataLoader is very similar; we just unpack the data using the more explicit attribute-based retrieval offered by the tensorclass. The .contiguous() method loads the data stored in the memmap tensor.
def train_tc(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, data in enumerate(dataloader):
        X, y = data.images.contiguous(), data.targets.contiguous()
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )
def test_tc(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X, y = batch.images.contiguous(), batch.targets.contiguous()
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )
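Both test functions report accuracy as total correct predictions divided by the dataset size, but the "Avg loss" as the mean of per-batch mean losses. With the 10,000-image FashionMNIST test split and a batch size of 64, the final batch holds only 16 images yet is weighted like a full batch, so the reported loss is a close approximation of the per-sample mean rather than exactly equal to it. The pure-Python sketch below mirrors that bookkeeping; `summarize` is a hypothetical helper, and the numbers passed to it in the usage line are made up.

```python
import math

def summarize(batch_losses, batch_correct, size):
    """Mirror test()/test_tc(): mean of per-batch mean losses, accuracy over all samples."""
    avg_loss = sum(batch_losses) / len(batch_losses)  # each batch weighted equally
    accuracy = sum(batch_correct) / size  # each sample weighted equally
    return avg_loss, accuracy

size, batch_size = 10_000, 64  # FashionMNIST test split and the batch size used above
num_batches = math.ceil(size / batch_size)  # what len(test_dataloader) returns
last_batch = size - (num_batches - 1) * batch_size  # samples in the final, partial batch
print(num_batches, last_batch)  # 157 16
print(summarize([1.0, 3.0], [60, 4], 100))  # (2.0, 0.64)
```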
for d in train_dataloader_tc:
    print(d)
    break
import time
t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train_tc(train_dataloader_tc, model_tc, loss_fn, optimizer_tc)
    test_tc(test_dataloader_tc, model_tc, loss_fn)
print(f"Tensorclass training done! time: {time.time() - t0: 4.4f} s")
t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print(f"Training done! time: {time.time() - t0: 4.4f} s")
FashionMNISTData(
    images=Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
    targets=Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([64]),
    device=cpu,
    is_shared=False)
Epoch 1
-------------------------
loss: 2.306911 [ 0/60000]
loss: 2.290546 [ 6400/60000]
loss: 2.267823 [12800/60000]
loss: 2.268296 [19200/60000]
loss: 2.239312 [25600/60000]
loss: 2.222285 [32000/60000]
loss: 2.223311 [38400/60000]
loss: 2.189771 [44800/60000]
loss: 2.196405 [51200/60000]
loss: 2.161141 [57600/60000]
Test Error:
Accuracy: 48.7%, Avg loss: 2.154104
Epoch 2
-------------------------
loss: 2.168520 [ 0/60000]
loss: 2.153506 [ 6400/60000]
loss: 2.092801 [12800/60000]
loss: 2.109224 [19200/60000]
loss: 2.055430 [25600/60000]
loss: 2.006157 [32000/60000]
loss: 2.025786 [38400/60000]
loss: 1.945255 [44800/60000]
loss: 1.956291 [51200/60000]
loss: 1.879478 [57600/60000]
Test Error:
Accuracy: 59.4%, Avg loss: 1.879949
Epoch 3
-------------------------
loss: 1.919434 [ 0/60000]
loss: 1.884455 [ 6400/60000]
loss: 1.763246 [12800/60000]
loss: 1.796725 [19200/60000]
loss: 1.699208 [25600/60000]
loss: 1.656230 [32000/60000]
loss: 1.662529 [38400/60000]
loss: 1.567700 [44800/60000]
loss: 1.596351 [51200/60000]
loss: 1.477721 [57600/60000]
Test Error:
Accuracy: 61.8%, Avg loss: 1.509531
Epoch 4
-------------------------
loss: 1.582884 [ 0/60000]
loss: 1.547020 [ 6400/60000]
loss: 1.392604 [12800/60000]
loss: 1.457310 [19200/60000]
loss: 1.349238 [25600/60000]
loss: 1.348273 [32000/60000]
loss: 1.347054 [38400/60000]
loss: 1.278697 [44800/60000]
loss: 1.319937 [51200/60000]
loss: 1.205683 [57600/60000]
Test Error:
Accuracy: 63.1%, Avg loss: 1.246076
Epoch 5
-------------------------
loss: 1.329433 [ 0/60000]
loss: 1.309014 [ 6400/60000]
loss: 1.141340 [12800/60000]
loss: 1.241332 [19200/60000]
loss: 1.118123 [25600/60000]
loss: 1.152766 [32000/60000]
loss: 1.158768 [38400/60000]
loss: 1.103974 [44800/60000]
loss: 1.149251 [51200/60000]
loss: 1.051577 [57600/60000]
Test Error:
Accuracy: 64.4%, Avg loss: 1.084601
Tensorclass training done! time: 8.6816 s
Epoch 1
-------------------------
loss: 2.307089 [ 0/60000]
loss: 2.297157 [ 6400/60000]
loss: 2.284588 [12800/60000]
loss: 2.276901 [19200/60000]
loss: 2.242071 [25600/60000]
loss: 2.219066 [32000/60000]
loss: 2.220341 [38400/60000]
loss: 2.188907 [44800/60000]
loss: 2.183793 [51200/60000]
loss: 2.147927 [57600/60000]
Test Error:
Accuracy: 43.2%, Avg loss: 2.146838
Epoch 2
-------------------------
loss: 2.158795 [ 0/60000]
loss: 2.147090 [ 6400/60000]
loss: 2.094091 [12800/60000]
loss: 2.108519 [19200/60000]
loss: 2.044007 [25600/60000]
loss: 1.981364 [32000/60000]
loss: 2.003546 [38400/60000]
loss: 1.929537 [44800/60000]
loss: 1.926183 [51200/60000]
loss: 1.847433 [57600/60000]
Test Error:
Accuracy: 57.0%, Avg loss: 1.857306
Epoch 3
-------------------------
loss: 1.891647 [ 0/60000]
loss: 1.856303 [ 6400/60000]
loss: 1.746082 [12800/60000]
loss: 1.786176 [19200/60000]
loss: 1.664370 [25600/60000]
loss: 1.619249 [32000/60000]
loss: 1.633839 [38400/60000]
loss: 1.550666 [44800/60000]
loss: 1.565413 [51200/60000]
loss: 1.457971 [57600/60000]
Test Error:
Accuracy: 62.1%, Avg loss: 1.485636
Epoch 4
-------------------------
loss: 1.555910 [ 0/60000]
loss: 1.517564 [ 6400/60000]
loss: 1.375235 [12800/60000]
loss: 1.446436 [19200/60000]
loss: 1.323968 [25600/60000]
loss: 1.320687 [32000/60000]
loss: 1.333024 [38400/60000]
loss: 1.270114 [44800/60000]
loss: 1.298299 [51200/60000]
loss: 1.202938 [57600/60000]
Test Error:
Accuracy: 63.9%, Avg loss: 1.229178
Epoch 5
-------------------------
loss: 1.310207 [ 0/60000]
loss: 1.290053 [ 6400/60000]
loss: 1.126103 [12800/60000]
loss: 1.231878 [19200/60000]
loss: 1.110180 [25600/60000]
loss: 1.128504 [32000/60000]
loss: 1.154162 [38400/60000]
loss: 1.097552 [44800/60000]
loss: 1.130474 [51200/60000]
loss: 1.054337 [57600/60000]
Test Error:
Accuracy: 65.1%, Avg loss: 1.072044
Training done! time: 34.6215 s
Total running time of the script: (1 minutes 1.114 seconds)