.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_examples_apps_lightning_model.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_apps_lightning_model.py:


Tiny ImageNet Model
====================

This is a toy model for doing image classification on the Tiny ImageNet
dataset. It's used by the apps in the same folder.


.. code-block:: default


    import os.path
    import subprocess
    from typing import List, Optional, Tuple

    import fsspec
    import pytorch_lightning as pl
    import torch
    import torch.jit
    from torch.nn import functional as F
    from torchmetrics import Accuracy
    from torchvision.models.resnet import BasicBlock, ResNet


    class TinyImageNetModel(pl.LightningModule):
        """
        A very simple ResNet-based classifier for the Tiny ImageNet dataset
        (200 classes).

        Args:
            layer_sizes: number of ``BasicBlock`` layers in each of the four
                ResNet stages. Defaults to ``[1, 1, 1, 1]`` when empty/None.
            lr: learning rate for the AdamW optimizer. Defaults to ``0.001``
                when falsy.
        """

        def __init__(
            self, layer_sizes: Optional[List[int]] = None, lr: Optional[float] = None
        ) -> None:
            super().__init__()

            if not layer_sizes:
                layer_sizes = [1, 1, 1, 1]
            self.lr: float = lr or 0.001

            # We use the torchvision resnet model with some small tweaks to match
            # TinyImageNet. The classifier head has to be sized for the 200
            # classes when it is constructed: assigning ``fc.out_features``
            # afterwards only changes the attribute, not the shape of the
            # layer's weight matrix, so the model would still emit 1000 logits.
            m = ResNet(BasicBlock, layer_sizes, num_classes=200)
            m.avgpool = torch.nn.AdaptiveAvgPool2d(1)
            self.model: ResNet = m

            self.train_acc = Accuracy()
            self.val_acc = Accuracy()

        # pyre-fixme[14]
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.model(x)

        # pyre-fixme[14]
        def training_step(
            self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
        ) -> torch.Tensor:
            return self._step("train", self.train_acc, batch, batch_idx)

        # pyre-fixme[14]
        def validation_step(
            self, val_batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
        ) -> torch.Tensor:
            return self._step("val", self.val_acc, val_batch, batch_idx)

        def _step(
            self,
            step_name: str,
            acc_metric: Accuracy,
            batch: Tuple[torch.Tensor, torch.Tensor],
            batch_idx: int,
        ) -> torch.Tensor:
            """
            Shared train/validation step: computes the cross-entropy loss and
            logs it together with the accuracy under ``{step_name}_loss`` and
            ``{step_name}_acc``. Returns the loss.
            """
            x, y = batch
            y_pred = self(x)
            loss = F.cross_entropy(y_pred, y)
            self.log(f"{step_name}_loss", loss)
            # Update the metric state, then log the Metric object itself so
            # Lightning computes and resets it at epoch boundaries. Logging
            # ``acc_metric.compute()`` instead never resets the state, so the
            # reported accuracy would keep accumulating across epochs.
            acc_metric(y_pred, y)
            self.log(f"{step_name}_acc", acc_metric)
            return loss

        # pyre-fixme[3]: TODO(aivanou): Figure out why oss pyre can identify type but fb cannot.
        def configure_optimizers(self):
            return torch.optim.AdamW(self.parameters(), lr=self.lr)


    def export_inference_model(
        model: TinyImageNetModel, out_path: str, tmpdir: str
    ) -> None:
        """
        export_inference_model uses TorchScript JIT to serialize the
        TinyImageNetModel into a standalone file that can be used during inference.
        TorchServe can also handle interpreted models with just the model.py file if
        your model can't be JITed.

        Args:
            model: the model to script and package.
            out_path: destination directory understood by fsspec; the archive is
                uploaded to ``<out_path>/model.mar``.
            tmpdir: local scratch directory for the JIT file and the archive.

        Raises:
            ValueError: if ``out_path`` does not resolve to exactly one path.
            subprocess.CalledProcessError: if ``torch-model-archiver`` fails.
        """

        print("exporting inference model")
        jit_path = os.path.join(tmpdir, "model_jit.pt")
        jitted = torch.jit.script(model)
        print(f"saving JIT model to {jit_path}")
        torch.jit.save(jitted, jit_path)

        model_name = "tiny_image_net"
        # mar_path assumes the archiver names its output after --model-name in
        # --export-path, so the same ``model_name`` is passed below.
        mar_path = os.path.join(tmpdir, f"{model_name}.mar")
        print(f"creating model archive at {mar_path}")
        subprocess.run(
            [
                "torch-model-archiver",
                "--model-name",
                model_name,
                "--handler",
                # NOTE(review): relative path, resolved against the current
                # working directory (presumably the repository root).
                "torchx/examples/apps/lightning/handler/handler.py",
                "--version",
                "1",
                "--serialized-file",
                jit_path,
                "--export-path",
                tmpdir,
            ],
            check=True,
        )

        remote_path = os.path.join(out_path, "model.mar")
        print(f"uploading to {remote_path}")
        fs, _, rpaths = fsspec.get_fs_token_paths(remote_path)
        # Explicit check rather than ``assert`` so it is not stripped under -O.
        if len(rpaths) != 1:
            raise ValueError("must have single path")
        fs.put(mar_path, rpaths[0])


    # sphinx_gallery_thumbnail_path = '_static/img/gallery-lib.png'


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.000 seconds)


.. _sphx_glr_download_examples_apps_lightning_model.py:


.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: model.py <model.py>`


  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: model.ipynb <model.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_