Transforms
Data does not always come in the final processed form required for training machine learning algorithms. We use transforms to manipulate the data and make it suitable for training.
All TorchVision datasets have two parameters, transform to modify the features and target_transform to modify the labels, that accept callables containing the transformation logic.
The torchvision.transforms module offers several commonly-used transforms out of the box.
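The way a dataset applies these two callables can be sketched with a small in-memory dataset. This is a simplified illustration, not TorchVision's actual implementation: the class TinyDataset and the helper scale_to_unit are hypothetical names introduced only for this example.

```python
import torch
from torch.utils.data import Dataset


class TinyDataset(Dataset):
    """Hypothetical dataset that mimics the transform / target_transform pattern."""

    def __init__(self, features, labels, transform=None, target_transform=None):
        self.features = features
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x, y = self.features[idx], self.labels[idx]
        # The feature callable runs on the sample, the target callable on the label.
        if self.transform is not None:
            x = self.transform(x)
        if self.target_transform is not None:
            y = self.target_transform(y)
        return x, y


def scale_to_unit(x):
    # A plain function used as a feature transform: map 0..255 to 0..1.
    return torch.tensor(x, dtype=torch.float) / 255.0


tiny = TinyDataset(
    features=[[0, 255], [128, 64]],
    labels=[0, 1],
    transform=scale_to_unit,
    target_transform=lambda y: y * 10,  # any callable works, including a lambda
)

x, y = tiny[1]
print(x, y)
```

Because both parameters only need to be callable, a named function, a lambda, or an object from torchvision.transforms can be passed interchangeably.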
The FashionMNIST features are in PIL Image format, and the labels are integers.
For training, we need the features as normalized tensors, and the labels as one-hot encoded tensors.
To make these transformations, we use ToTensor and Lambda.
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
ToTensor()
ToTensor converts a PIL image or NumPy ndarray into a FloatTensor and scales the image’s pixel intensity values to the range [0., 1.].
Lambda Transforms
Lambda transforms apply any user-defined lambda function. Here, we define a function to turn the integer label into a one-hot encoded tensor. It first creates a zero tensor of size 10 (the number of labels in our dataset) and then calls scatter_, which assigns value=1 at the index given by the label y.
target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))