torch.hub
===================================
PyTorch Hub is a pre-trained model repository designed to facilitate research reproducibility.

Publishing models
-----------------

PyTorch Hub supports publishing pre-trained models (model definitions and pre-trained weights) to a GitHub repository by adding a simple ``hubconf.py`` file.

``hubconf.py`` can have multiple entrypoints. Each entrypoint is defined as a Python function with the following signature.

::

    def entrypoint_name(pretrained=False, *args, **kwargs):
        ...

How to implement an entrypoint?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Here is a code snippet from the pytorch/vision repository, which specifies an entrypoint for the ``resnet18`` model. You can see the full script in the `pytorch/vision repo <https://github.com/pytorch/vision/blob/master/hubconf.py>`_.

::

    import torch
    import torch.utils.model_zoo as model_zoo

    dependencies = ['torch', 'math']

    def resnet18(pretrained=False, *args, **kwargs):
        """
        Resnet18 model
        pretrained (bool): a recommended kwarg for all entrypoints
        args & kwargs are arguments for the function
        """
        ######## Call the model in the repo ###############
        from torchvision.models.resnet import resnet18 as _resnet18
        model = _resnet18(*args, **kwargs)
        ######## End of call ##############################
        # The following logic is REQUIRED
        if pretrained:
            # For weights saved in local repo
            # model.load_state_dict(torch.load(<path_to_saved_file>))

            # For weights saved elsewhere
            checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
            model.load_state_dict(model_zoo.load_url(checkpoint, progress=False))
        return model

- ``dependencies`` variable is a **list** of package names required to run the model.
- Pretrained weights can either be stored locally in the GitHub repo, or loaded from a URL with ``model_zoo.load_url()``.
- ``pretrained`` controls whether to load the pre-trained weights provided by repo owners.
- ``args`` and ``kwargs`` are passed along to the real callable function.
- The docstring of the function works as a help message, explaining what the model does and which arguments are allowed.
- The entrypoint function should **ALWAYS** return a model (``nn.Module``).

Important Notice
^^^^^^^^^^^^^^^^

- The published models should be at least in a branch/tag. They can't be in a random commit.

Loading models from Hub
-----------------------

Users can load the pre-trained models using the ``torch.hub.load()`` API.

.. automodule:: torch.hub
.. autofunction:: load

Here's an example loading the ``resnet18`` entrypoint from the ``pytorch/vision`` repo.

::

    import torch

    hub_model = torch.hub.load(
        'pytorch/vision:master',  # repo_owner/repo_name:branch
        'resnet18',               # entrypoint
        # positional args for the callable, if any, would go here
        pretrained=True)          # kwargs for callable

Where are my downloaded model & weights saved?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The locations are used in the following order:

- hub_dir: user-specified path. It can be set in the following ways:

  - Setting the environment variable ``TORCH_HUB_DIR``
  - Calling ``hub.set_dir(<PATH_TO_HUB_DIR>)``

- ``~/.torch/hub``

.. autofunction:: set_dir

Caching logic
^^^^^^^^^^^^^

By default, we don't clean up files after loading them. Hub uses the cache by default if it already exists in ``hub_dir``.

Users can force a reload by calling ``hub.load(..., force_reload=True)``. This deletes the existing GitHub folder and downloaded weights, then starts a fresh download. This is useful when updates are published to the same branch, so users can keep up with the latest release.
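The reuse-or-redownload decision described above can be sketched in plain Python. This is a simplified model under stated assumptions, not torch.hub's actual implementation: the folder naming scheme (``owner_name_branch``), the helper ``get_repo_dir`` and the ``fetch`` callable standing in for the real download are all hypothetical, and only the repository folder is modelled (downloaded weights are left out).

```python
import os
import shutil
import tempfile


def get_repo_dir(hub_dir, repo, branch, fetch, force_reload=False):
    """Return the cached directory for ``repo`` at ``branch`` (hypothetical helper).

    ``fetch`` is a hypothetical stand-in for the real download step; it must
    populate the directory it is given.
    """
    owner, name = repo.split('/')
    # Hypothetical naming scheme: one folder per owner/name/branch
    repo_dir = os.path.join(hub_dir, '_'.join([owner, name, branch]))
    if force_reload and os.path.exists(repo_dir):
        # force_reload: drop the cached copy so the branch is fetched again
        shutil.rmtree(repo_dir)
    if not os.path.exists(repo_dir):
        # Cache miss (or forced reload): download into a fresh folder
        os.makedirs(repo_dir)
        fetch(repo_dir)
    # Otherwise the existing cached folder is reused as-is
    return repo_dir


calls = []


def fake_fetch(target):
    # Stand-in downloader that records each call and writes a dummy hubconf.py
    calls.append(target)
    with open(os.path.join(target, 'hubconf.py'), 'w') as f:
        f.write("dependencies = ['torch']\n")


hub_dir = tempfile.mkdtemp()
get_repo_dir(hub_dir, 'pytorch/vision', 'master', fake_fetch)  # first use: downloads
get_repo_dir(hub_dir, 'pytorch/vision', 'master', fake_fetch)  # cache hit: no download
get_repo_dir(hub_dir, 'pytorch/vision', 'master', fake_fetch,
             force_reload=True)                                # forced: downloads again
print(len(calls))  # 2
```

Keeping the check-then-download logic in one place makes the trade-off visible: reusing the cache is cheap, but it can serve a stale copy of a branch that has since moved, which is the case ``force_reload`` is meant to cover.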