torch.hub
PyTorch Hub is a pre-trained model repository designed to facilitate research reproducibility.
Publishing models
PyTorch Hub supports publishing pre-trained models (model definitions and pre-trained weights)
to a GitHub repository by adding a simple hubconf.py file. A hubconf.py can have multiple
entrypoints. Each entrypoint is defined as a Python function with the following signature.
def entrypoint_name(pretrained=False, *args, **kwargs):
    ...
How to implement an entrypoint?
Here is a code snippet from the pytorch/vision repository, which specifies an entrypoint
for the resnet18 model. You can see the full script in the pytorch/vision repo.
dependencies = ['torch', 'math']

# model_zoo is needed below to download weights hosted outside the repo
import torch.utils.model_zoo as model_zoo

def resnet18(pretrained=False, *args, **kwargs):
    """
    Resnet18 model
    pretrained (bool): a recommended kwargs for all entrypoints
    args & kwargs are arguments for the function
    """
    ######## Call the model in the repo ###############
    from torchvision.models.resnet import resnet18 as _resnet18
    model = _resnet18(*args, **kwargs)
    ######## End of call ##############################
    # The following logic is REQUIRED
    if pretrained:
        # For weights saved in local repo
        # model.load_state_dict(<path_to_saved_file>)
        # For weights saved elsewhere
        checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
        model.load_state_dict(model_zoo.load_url(checkpoint, progress=False))
    return model
- The dependencies variable is a list of package names required to run the model.
- Pretrained weights can either be stored locally in the GitHub repo, or be loadable by model_zoo.load_url().
- pretrained controls whether to load the pre-trained weights provided by repo owners.
- args and kwargs are passed along to the real callable function.
- The docstring of the function works as a help message, explaining what the model does and what the allowed arguments are.
- An entrypoint function should ALWAYS return a model (nn.Module).
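As an illustration of what the dependencies list makes possible, here is a minimal sketch that reports which listed packages cannot be found in the current environment. The helper name missing_dependencies is hypothetical, and this is not necessarily how torch.hub itself performs the check. It relies only on the standard library and assumes top-level package names.

```python
import importlib.util


def missing_dependencies(dependencies):
    """Return the names in `dependencies` that cannot be located.

    Hypothetical helper for illustration; assumes top-level package
    names such as 'torch' or 'math', not dotted submodule paths.
    """
    return [
        name for name in dependencies
        if importlib.util.find_spec(name) is None
    ]


# Example: check a dependencies list like the one in a hubconf.py.
dependencies = ['math', 'json']
print(missing_dependencies(dependencies))
```

A caller could use a result like this to raise a clear error before any entrypoint is invoked, rather than failing partway through model construction.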
Loading models from Hub
Users can load the pre-trained models using the torch.hub.load() API.
torch.hub.load(github, model, force_reload=False, *args, **kwargs)
Load a model from a github repo, with pretrained weights.
Parameters:
- github – Required, a string with format “repo_owner/repo_name[:tag_name]” with an optional tag/branch. The default branch is master if not specified. Example: ‘pytorch/vision[:hub]’
- model – Required, a string of entrypoint name defined in repo’s hubconf.py
- force_reload – Optional, whether to discard the existing cache and force a fresh download. Default is False.
- *args – Optional, the corresponding args for callable model.
- **kwargs – Optional, the corresponding kwargs for callable model.
Returns: a single model with corresponding pretrained weights.
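To make the “repo_owner/repo_name[:tag_name]” format concrete, the following sketch splits such a string into its parts and falls back to master when no tag or branch is given. The function name parse_github_spec is hypothetical and the parsing shown is an assumption for illustration, not the library's own implementation.

```python
def parse_github_spec(spec, default_branch="master"):
    """Split 'repo_owner/repo_name[:tag_name]' into (owner, name, branch).

    Hypothetical helper for illustration only.
    """
    repo_info, sep, branch = spec.partition(":")
    if not sep or not branch:
        # No tag/branch given, so use the documented default.
        branch = default_branch
    owner, slash, name = repo_info.partition("/")
    if not slash or not owner or not name:
        raise ValueError(
            "expected 'repo_owner/repo_name[:tag_name]', got %r" % spec
        )
    return owner, name, branch


print(parse_github_spec("pytorch/vision"))
print(parse_github_spec("pytorch/vision:hub"))
```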
Here’s an example loading the resnet18 entrypoint from the pytorch/vision repo.
from torch import hub

hub_model = hub.load(
    'pytorch/vision:master',  # repo_owner/repo_name:branch
    'resnet18',               # entrypoint
    # 1234,                   # positional args for callable would go here [not applicable to resnet]
    pretrained=True)          # kwargs for callable
Where are my downloaded model & weights saved?
The locations are used in the order of
- hub_dir: user specified path. It can be set in the following ways:
  - Setting the environment variable TORCH_HUB_DIR
  - Calling hub.set_dir(<PATH_TO_HUB_DIR>)
- ~/.torch/hub
torch.hub.set_dir(d)
Optionally set hub_dir to a local dir to save downloaded models & weights.
If this argument is not set, the env variable TORCH_HUB_DIR will be searched first, and ~/.torch/hub will be created and used as a fallback.
Parameters: d – path to a local folder to save downloaded models & weights.
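The lookup order described above can be sketched in a few lines. The function resolve_hub_dir and its parameters are hypothetical, written only to illustrate the precedence (explicitly set directory, then TORCH_HUB_DIR, then ~/.torch/hub). Unlike the real fallback, this sketch only computes the path and does not create the directory.

```python
import os


def resolve_hub_dir(explicit_dir=None, environ=None):
    """Return the directory that would hold downloaded models & weights.

    Hypothetical helper: `explicit_dir` stands in for a path passed to
    set_dir, and `environ` defaults to the process environment.
    """
    environ = os.environ if environ is None else environ
    if explicit_dir:
        # A directory set explicitly takes precedence.
        return os.path.expanduser(explicit_dir)
    env_dir = environ.get("TORCH_HUB_DIR")
    if env_dir:
        # Otherwise use the environment variable, if it is set.
        return os.path.expanduser(env_dir)
    # Fallback location.
    return os.path.expanduser(os.path.join("~", ".torch", "hub"))


print(resolve_hub_dir())
```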
Caching logic
By default, we don’t clean up files after loading them. Hub uses the cache by default if it already exists in hub_dir.
Users can force a reload by calling hub.load(..., force_reload=True). This will delete
the existing github folder and downloaded weights, then reinitialize a fresh download. This is useful
when updates are published to the same branch, so that users can keep up with the latest release.
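The caching behaviour can be pictured with a small sketch: reuse the cached folder if it exists, and delete it first when a reload is forced. The names prepare_repo_dir and fetch are hypothetical, and fetch stands in for whatever actually downloads the repository. This is an illustration of the described logic, not the library's code.

```python
import os
import shutil


def prepare_repo_dir(hub_dir, repo_dir_name, fetch, force_reload=False):
    """Return (repo_dir, cache_hit) for a repo stored under hub_dir.

    Hypothetical helper: `fetch(repo_dir)` is a caller-supplied callable
    that populates the freshly created directory.
    """
    repo_dir = os.path.join(hub_dir, repo_dir_name)
    if force_reload and os.path.isdir(repo_dir):
        # Discard the existing cache so a fresh copy is fetched.
        shutil.rmtree(repo_dir)
    if os.path.isdir(repo_dir):
        # Cache hit: nothing is downloaded again.
        return repo_dir, True
    os.makedirs(repo_dir)
    fetch(repo_dir)
    return repo_dir, False
```

With this shape, a second call without force_reload returns the cached folder untouched, while force_reload=True always triggers fetch again.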