Trainer Datasets Example
This file defines the datasets used by the training example. It uses the PyTorch Lightning libraries.
import os.path
import tarfile
from typing import Callable, Optional
import fsspec
import numpy
import pytorch_lightning as pl
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.datasets.folder import is_image_file
from tqdm import tqdm
This uses torchvision to define a dataset that we will later use in our PyTorch Lightning data module.
class ImageFolderSamplesDataset(datasets.ImageFolder):
    """
    ImageFolderSamplesDataset is a wrapper around ImageFolder that allows you to
    limit the number of samples.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable[..., object]] = None,
        num_samples: Optional[int] = None,
        **kwargs: object,
    ) -> None:
        """
        Args:
            num_samples: optional. limits the size of the dataset
        """
        super().__init__(root, transform=transform)

        self.num_samples = num_samples

    def __len__(self) -> int:
        if self.num_samples is not None:
            return self.num_samples
        return super().__len__()
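As a quick sketch (the dataset root below is a placeholder, not part of this example), the wrapper can be used directly to cap a folder dataset for smoke tests:

# Illustrative only: point the dataset at any ImageFolder-style directory
# and cap it at 100 samples regardless of how many images the folder holds.
ds = ImageFolderSamplesDataset(
    root="/tmp/data/train",  # placeholder path
    transform=transforms.ToTensor(),
    num_samples=100,
)
print(len(ds))  # prints 100 rather than the full folder size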
For ease of use, we define a Lightning data module so we can reuse it across our trainer and other components that need to load data.
# pyre-fixme[13]: Attribute `test_ds` is never initialized.
# pyre-fixme[13]: Attribute `train_ds` is never initialized.
# pyre-fixme[13]: Attribute `val_ds` is never initialized.
class TinyImageNetDataModule(pl.LightningDataModule):
    """
    TinyImageNetDataModule is a pytorch LightningDataModule for the tiny
    imagenet dataset.
    """

    train_ds: ImageFolderSamplesDataset
    val_ds: ImageFolderSamplesDataset
    test_ds: ImageFolderSamplesDataset

    def __init__(
        self, data_dir: str, batch_size: int = 16, num_samples: Optional[int] = None
    ) -> None:
        super().__init__()

        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_samples = num_samples

    def setup(self, stage: Optional[str] = None) -> None:
        # Setup data loader and transforms
        img_transform = transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )
        self.train_ds = ImageFolderSamplesDataset(
            root=os.path.join(self.data_dir, "train"),
            transform=img_transform,
            num_samples=self.num_samples,
        )
        self.val_ds = ImageFolderSamplesDataset(
            root=os.path.join(self.data_dir, "val"),
            transform=img_transform,
            num_samples=self.num_samples,
        )
        self.test_ds = ImageFolderSamplesDataset(
            root=os.path.join(self.data_dir, "test"),
            transform=img_transform,
            num_samples=self.num_samples,
        )

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_ds, batch_size=self.batch_size)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_ds, batch_size=self.batch_size)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.test_ds, batch_size=self.batch_size)

    def teardown(self, stage: Optional[str] = None) -> None:
        pass
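As a minimal sketch, assuming a dataset already exists at the placeholder path below, the data module can be exercised on its own before handing it to a Lightning Trainer:

# Hypothetical local dataset directory; swap in your own path.
dm = TinyImageNetDataModule(data_dir="/tmp/tiny-imagenet", batch_size=16, num_samples=32)
dm.setup("fit")
images, labels = next(iter(dm.train_dataloader()))
print(images.shape)  # e.g. torch.Size([16, 3, 64, 64]) for 64x64 RGB inputs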
To pass data between the different components, we use fsspec, which allows us to read from and write to cloud or local file storage.
def download_data(remote_path: str, tmpdir: str) -> str:
    """
    download_data downloads the training data from the specified remote path via
    fsspec and places it in the tmpdir unextracted.
    """
    if os.path.isdir(remote_path):
        print("dataset path is a directory, using as is")
        return remote_path

    tar_path = os.path.join(tmpdir, "data.tar.gz")
    print(f"downloading dataset from {remote_path} to {tar_path}...")
    fs, _, rpaths = fsspec.get_fs_token_paths(remote_path)
    assert len(rpaths) == 1, "must have single path"
    fs.get(rpaths[0], tar_path)

    data_path = os.path.join(tmpdir, "data")
    print(f"extracting {tar_path} to {data_path}...")
    with tarfile.open(tar_path, mode="r") as f:
        f.extractall(data_path)

    return data_path
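A hedged usage sketch; the remote URL below is made up, but any fsspec-supported scheme (s3://, gs://, file://, ...) or a plain local directory should work:

import tempfile

# Placeholder remote tarball; download_data falls back to using the path
# directly when it is already a local directory.
with tempfile.TemporaryDirectory() as tmpdir:
    data_path = download_data("s3://my-bucket/tiny-imagenet.tar.gz", tmpdir)
    print(f"dataset ready at {data_path}")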
def create_random_data(output_path: str, num_images: int = 250) -> None:
    """
    Fills the given path with randomly generated 64x64 images.
    This can be used for quick testing of the workflow of the model.
    Does NOT pack the files into a tar, but does preprocess them.
    """
    train_path = os.path.join(output_path, "train")
    class1_train_path = os.path.join(train_path, "class1")
    class2_train_path = os.path.join(train_path, "class2")

    val_path = os.path.join(output_path, "val")
    class1_val_path = os.path.join(val_path, "class1")
    class2_val_path = os.path.join(val_path, "class2")

    test_path = os.path.join(output_path, "test")
    class1_test_path = os.path.join(test_path, "class1")
    class2_test_path = os.path.join(test_path, "class2")

    paths = [
        class1_train_path,
        class1_val_path,
        class1_test_path,
        class2_train_path,
        class2_val_path,
        class2_test_path,
    ]

    for path in paths:
        try:
            os.makedirs(path)
        except FileExistsError:
            pass

        for i in range(num_images):
            pixels = numpy.random.rand(64, 64, 3) * 255
            im = Image.fromarray(pixels.astype("uint8")).convert("RGB")
            im.save(os.path.join(path, f"rand_image_{i}.jpeg"))

    process_images(output_path)
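For a quick local smoke test (the output directory below is just an example), the generator can be called directly and its output inspected:

# Writes 10 random images into each class folder under the example path,
# then preprocesses them in place via process_images.
create_random_data("/tmp/random-tiny-imagenet", num_images=10)
print(len(os.listdir("/tmp/random-tiny-imagenet/train/class1")))  # 10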
def process_images(img_root: str) -> None:
    print("transforming images...")
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
            transforms.ToPILImage(),
        ]
    )

    image_files = []
    for root, _, fnames in os.walk(img_root):
        for fname in fnames:
            path = os.path.join(root, fname)
            if not is_image_file(path):
                continue
            image_files.append(path)

    for path in tqdm(image_files, miniters=int(len(image_files) / 2000)):
        f = Image.open(path)
        f = transform(f)
        f.save(path)
# sphinx_gallery_thumbnail_path = '_static/img/gallery-lib.png'