PyTorch Hub is a pre-trained model repository designed to facilitate research reproducibility.
PyTorch Hub supports publishing pre-trained models (model definitions and pre-trained weights)
to a GitHub repository by adding a simple hubconf.py file.
hubconf.py can have multiple entrypoints. Each entrypoint is defined as a Python function
(example: a pre-trained model you want to publish).
def entrypoint_name(*args, **kwargs):
    # args & kwargs are optional, for models which take positional/keyword arguments.
    ...
How to implement an entrypoint?
The following code snippet specifies an entrypoint for the
resnet18 model, with its implementation expanded.
In most cases, importing the right function in
hubconf.py is sufficient. The expanded version is used here
only as an example to show how it works.
You can see the full script in
dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18

# resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
    """ # This docstring shows up in hub.help()
    Resnet18 model
    pretrained (bool): kwargs, load pretrained weights into the model
    """
    # Call the model, load pretrained weights
    model = _resnet18(pretrained=pretrained, **kwargs)
    return model
dependencies variable is a list of package names required to load the model. Note this might be slightly different from dependencies required for training a model.
kwargs are passed along to the real callable function.
Docstring of the function works as a help message. It explains what the model does and what the allowed positional/keyword arguments are. It’s highly recommended to add a few examples here.
Entrypoint function can either return a model (nn.Module), or auxiliary tools to make the user workflow smoother, e.g. tokenizers.
Callables prefixed with an underscore are considered helper functions, which won’t show up in the list of entrypoints.
Pretrained weights can either be stored locally in the GitHub repo, or be loadable by
torch.hub.load_state_dict_from_url(). If the weights are less than 2GB, it’s recommended to attach them to a project release and use the URL from the release. In the example above, the imported _resnet18 function handles
pretrained; alternatively, you can put the following logic in the entrypoint definition.
if pretrained:
    # For checkpoint saved in local github repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
    dirname = os.path.dirname(__file__)
    checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
    state_dict = torch.load(checkpoint)
    model.load_state_dict(state_dict)

    # For checkpoint saved elsewhere
    checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
The published models should be at least in a branch/tag. It can’t be a random commit.
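The points above can be put together in a minimal sketch. It assumes a hypothetical hubconf.py whose only entrypoint, simple_tokenizer, returns an auxiliary tool rather than a model, so it runs with the standard library alone. The names simple_tokenizer, _normalize and missing_dependencies are illustrative and not part of PyTorch Hub, and the dependency check is only one plausible way to verify that the listed packages are importable.

```python
import importlib.util

# Packages required to *load* the entrypoints below (nothing beyond the stdlib here).
dependencies = []


def _normalize(text, lowercase):
    """Helper: the leading underscore marks it as not being an entrypoint."""
    return text.lower() if lowercase else text


def simple_tokenizer(lowercase=True):
    """Whitespace tokenizer (hypothetical entrypoint).

    lowercase (bool): if True, lowercase the text before splitting.

    Example:
        tokenize = simple_tokenizer()
        tokenize("Hello Hub")
    """
    def tokenize(text):
        return _normalize(text, lowercase).split()

    return tokenize


def missing_dependencies(names):
    """Return the top-level package names in `names` that cannot be imported."""
    return [name for name in names if importlib.util.find_spec(name) is None]
```

The docstring of simple_tokenizer carries a short example, as recommended above, so that a help message alone is enough to get started.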
Loading models from Hub
PyTorch Hub provides convenient APIs to explore all available models in a hub through
torch.hub.list(), show docstrings and examples through
torch.hub.help(), and load the pre-trained models using torch.hub.load().
list(github, force_reload=False)
List all entrypoints available in the GitHub repo’s hubconf.
github (string) – a string with format “repo_owner/repo_name[:tag_name]” with an optional tag/branch. The default branch is master if not specified. Example: ‘pytorch/vision[:hub]’
force_reload (bool, optional) – whether to discard the existing cache and force a fresh download. Default is False.
Returns: a list of available entrypoint names
>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
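As a rough illustration of what "available entrypoints" means, the sketch below collects the public callables of a module object. It is a simplified stand-in, not the actual implementation of torch.hub.list(): the list_entrypoints function and the contents of the fake hubconf module are assumptions made for the example.

```python
import types


def list_entrypoints(module):
    """Return the sorted names of public callables defined in `module`."""
    return sorted(
        name
        for name, value in vars(module).items()
        if callable(value) and not name.startswith("_")
    )


# Build a stand-in for an imported hubconf module (hypothetical contents).
hubconf = types.ModuleType("hubconf")
hubconf.dependencies = ["torch"]                      # not callable, so not an entrypoint
hubconf.resnet18 = lambda pretrained=False, **kwargs: None
hubconf._load_weights = lambda model: model           # helper, hidden by the underscore

print(list_entrypoints(hubconf))
```

Only resnet18 is reported: the dependencies list is not callable and _load_weights is hidden by its leading underscore.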
help(github, model, force_reload=False)
Show the docstring of entrypoint model.
github (string) – a string with format <repo_owner/repo_name[:tag_name]> with an optional tag/branch. The default branch is master if not specified. Example: ‘pytorch/vision[:hub]’
model (string) – the name of an entrypoint defined in the repo’s hubconf.py
>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
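Both list() and help() take a github string of the form repo_owner/repo_name[:tag_name]. The following sketch shows one way such a string could be split into its parts, with the branch falling back to master as described above. The parse_repo_spec helper is hypothetical and is not a function exposed by torch.hub.

```python
def parse_repo_spec(github, default_branch="master"):
    """Split 'repo_owner/repo_name[:tag_name]' into (owner, name, branch)."""
    repo_info, _, branch = github.partition(":")
    owner, sep, name = repo_info.partition("/")
    if not sep or not owner or not name:
        raise ValueError(
            f"expected 'repo_owner/repo_name[:tag_name]', got {github!r}"
        )
    # An empty branch means no tag/branch was given, so use the default.
    return owner, name, branch or default_branch


print(parse_repo_spec("pytorch/vision"))
print(parse_repo_spec("pytorch/vision:hub"))
```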
load(repo_or_dir, model, *args, **kwargs)
Load a model from a github repo or a local directory.
Note: Loading a model is the typical use case, but this can also be used for loading other objects such as tokenizers, loss functions, etc.
If source is 'github', repo_or_dir is expected to be of the form
repo_owner/repo_name[:tag_name] with an optional tag/branch.
If source is 'local', repo_or_dir is expected to be a path to a local directory.
repo_or_dir (string) – repo name (repo_owner/repo_name[:tag_name]), if
source = 'github'; or a path to a local directory, if
source = 'local'.
model (string) – the name of a callable (entrypoint) defined in the repo/dir’s hubconf.py
*args (optional) – the corresponding args for callable model
source (string, optional) – 'github' or
'local'. Specifies how
repo_or_dir is to be interpreted. Default is 'github'.
force_reload (bool, optional) – whether to force a fresh download of the github repo unconditionally. Does not have any effect if
source = 'local'. Default is False.
verbose (bool, optional) – If
False, mute messages about hitting local caches. Note that the message about first download cannot be muted. Does not have any effect if
source = 'local'. Default is True.
**kwargs (optional) – the corresponding kwargs for callable model
The output of the
model callable when called with the given *args and **kwargs.
>>> # from a github repo
>>> repo = 'pytorch/vision'
>>> model = torch.hub.load(repo, 'resnet50', pretrained=True)
>>> # from a local directory
>>> path = '/some/local/path/pytorch/vision'
>>> model = torch.hub.load(path, 'resnet50', pretrained=True, source='local')
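To make the local-directory case more concrete, here is a simplified sketch of the idea: import hubconf.py from a directory, look up the entrypoint by name, and call it with the given arguments. This is an illustration under stated assumptions, not the code torch.hub.load() actually runs. The load_local function and the scaled entrypoint are made up for the example.

```python
import importlib.util
import os
import tempfile


def load_local(repo_dir, model, *args, **kwargs):
    """Import <repo_dir>/hubconf.py and call the entrypoint named `model`."""
    path = os.path.join(repo_dir, "hubconf.py")
    spec = importlib.util.spec_from_file_location("hubconf", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    entry = getattr(module, model, None)
    if entry is None or not callable(entry):
        raise RuntimeError(f"Cannot find callable {model} in hubconf")
    # The return value is whatever the entrypoint returns for these arguments.
    return entry(*args, **kwargs)


# Usage with a throwaway directory holding a hypothetical hubconf.py.
with tempfile.TemporaryDirectory() as repo_dir:
    with open(os.path.join(repo_dir, "hubconf.py"), "w") as f:
        f.write("def scaled(factor=1, **kwargs):\n    return lambda x: x * factor\n")
    double = load_local(repo_dir, "scaled", factor=2)

print(double(21))
```

The keyword argument factor=2 is forwarded unchanged to the entrypoint, which mirrors how *args and **kwargs are described above.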
download_url_to_file(url, dst, hash_prefix=None, progress=True)
Download object at the given URL to a local path.
url (string) – URL of the object to download
dst (string) – Full path where object will be saved, e.g. /tmp/temporary_file
hash_prefix (string, optional) – If not None, the SHA256 hash of the downloaded file should start with hash_prefix. Default: None
progress (bool, optional) – whether or not to display a progress bar to stderr. Default: True
>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
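The hash_prefix check can be pictured with the standard library alone. The sketch below hashes a file in chunks and compares the hex digest against a prefix. The helper names sha256_of_file and matches_hash_prefix are assumptions for illustration and are not part of torch.hub.

```python
import hashlib
import os
import tempfile


def sha256_of_file(path, chunk_size=8192):
    """Return the hex SHA256 digest of the file at `path`, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_hash_prefix(path, hash_prefix):
    """True if hash_prefix is None or the file's SHA256 starts with it."""
    if hash_prefix is None:
        return True
    return sha256_of_file(path).startswith(hash_prefix)


# Usage on a small temporary file standing in for a downloaded object.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.bin")
    with open(path, "wb") as f:
        f.write(b"hello")
    expected = hashlib.sha256(b"hello").hexdigest()
    ok = matches_hash_prefix(path, expected[:8])
    bad = matches_hash_prefix(path, "00000000")

print(ok, bad)
```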
load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None)
Loads the Torch serialized object at the given URL.
If downloaded file is a zip file, it will be automatically decompressed.
If the object is already present in model_dir, it’s deserialized and returned. The default value of model_dir is
<hub_dir>/checkpoints, where hub_dir is the Torch Hub cache directory.
url (string) – URL of the object to download
model_dir (string, optional) – directory in which to save the object
map_location (optional) – a function or a dict specifying how to remap storage locations (see torch.load)
progress (bool, optional) – whether or not to display a progress bar to stderr. Default: True
check_hash (bool, optional) – If True, the filename part of the URL should follow a naming convention that embeds <sha256>, where
<sha256> is the first eight or more digits of the SHA256 hash of the contents of the file. The hash is used to ensure unique names and to verify the contents of the file. Default: False
file_name (string, optional) – name for the downloaded file. Filename from url will be used if not set.
>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
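The sketch below shows how the cached file location and the expected hash prefix could be derived from a URL. It assumes the hash appears in the filename as a dash, at least eight lowercase hex digits, and a dot, which matches the resnet18-5c106cde.pth example. The helpers cached_file_for and hash_prefix_from_url are hypothetical, and the pattern is an assumption rather than the library's documented rule.

```python
import os
import re
from urllib.parse import urlparse

# Assumed convention: <name>-<sha256 prefix>.<ext>, with eight or more hex digits.
HASH_RE = re.compile(r"-([a-f0-9]{8,})\.")


def cached_file_for(url, model_dir, file_name=None):
    """Return the path inside model_dir where the object would be stored."""
    filename = file_name or os.path.basename(urlparse(url).path)
    return os.path.join(model_dir, filename)


def hash_prefix_from_url(url):
    """Return the hash prefix embedded in the URL's filename, or None."""
    filename = os.path.basename(urlparse(url).path)
    match = HASH_RE.search(filename)
    return match.group(1) if match else None


url = "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth"
print(cached_file_for(url, "checkpoints"))
print(hash_prefix_from_url(url))
```

If file_name is given it takes precedence over the name taken from the URL, in line with the parameter description above.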
Running a loaded model:
The *args and **kwargs passed to torch.hub.load() are used to
instantiate a model. After you have loaded a model, how can you find out
what you can do with the model?
A suggested workflow is
dir(model) to see all available methods of the model.
help(model.foo) to check what arguments
model.foo takes to run
To help users explore without referring to documentation back and forth, we strongly recommend repo owners make function help messages clear and succinct. It’s also helpful to include a minimal working example.
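The workflow can be tried on any Python object. The sketch below uses a hypothetical TinyModel class in place of something returned by torch.hub.load(), and uses inspect to retrieve the same signature and docstring that help(model.foo) would print.

```python
import inspect


class TinyModel:
    """Hypothetical stand-in for an object returned by torch.hub.load()."""

    def foo(self, x, threshold=0.5):
        """Return True if x exceeds threshold."""
        return x > threshold


model = TinyModel()

# Step 1: list the public attributes, as dir(model) would show them.
public = [name for name in dir(model) if not name.startswith("_")]
print(public)

# Step 2: inspect one method; help(model.foo) prints the same information.
signature = str(inspect.signature(model.foo))
print(signature)
print(inspect.getdoc(model.foo))
```

A clear one-line docstring on foo is what makes the second step useful, which is why succinct help messages are recommended above.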
Where are my downloaded models saved?
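The locations are used in the order of
the directory passed to set_dir(), if it has been called
$TORCH_HOME/hub, if environment variable TORCH_HOME is set
$XDG_CACHE_HOME/torch/hub, if environment variable XDG_CACHE_HOME is set
~/.cache/torch/hub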
Get the Torch Hub cache directory used for storing downloaded models & weights.
If set_dir() is not called, the default path is
$TORCH_HOME/hub, where the environment variable $TORCH_HOME defaults to $XDG_CACHE_HOME/torch.
$XDG_CACHE_HOME follows the X Desktop Group (XDG) specification of the Linux filesystem layout, with a default value of
~/.cache if the environment variable is not set.
Optionally set the Torch Hub directory used to save downloaded models & weights.
d (string) – path to a local folder to save downloaded models & weights.
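The lookup order described in this section can be written down as a small pure function. This is a sketch of the documented order only. The resolve_hub_dir helper and its explicit_dir parameter, which stands in for a directory set through set_dir(), are assumptions for illustration.

```python
import os


def resolve_hub_dir(env, explicit_dir=None):
    """Return the hub directory for the given environment mapping.

    explicit_dir models a directory chosen via set_dir() and wins if given.
    """
    if explicit_dir is not None:
        return explicit_dir
    if env.get("TORCH_HOME"):
        return os.path.join(env["TORCH_HOME"], "hub")
    # Fall back to $XDG_CACHE_HOME, which itself defaults to ~/.cache.
    cache_home = env.get("XDG_CACHE_HOME") or os.path.join("~", ".cache")
    return os.path.expanduser(os.path.join(cache_home, "torch", "hub"))


print(resolve_hub_dir({"TORCH_HOME": "/data/torch"}))
print(resolve_hub_dir({"XDG_CACHE_HOME": "/var/cache"}))
print(resolve_hub_dir(dict(os.environ)))
```

Passing the environment as a plain mapping keeps the function easy to test without modifying os.environ.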
By default, we don’t clean up files after loading them. Hub uses the cache by default if it already exists in the
Torch Hub cache directory.
Users can force a reload by calling
hub.load(..., force_reload=True). This will delete
the existing GitHub folder and downloaded weights, and start a fresh download. This is useful
when updates are published to the same branch, so that users can keep up with the latest release.
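The caching behaviour can be summarised as: reuse the cached copy if it exists, unless a reload is forced. The sketch below models that rule with a stand-in download function. ensure_cached and fake_fetch are hypothetical names, and the real library's bookkeeping is more involved.

```python
import os
import shutil
import tempfile


def ensure_cached(cache_dir, fetch, force_reload=False):
    """Return cache_dir, calling fetch(cache_dir) only when a download is needed."""
    if force_reload and os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)      # discard the existing copy
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)
        fetch(cache_dir)              # stand-in for the real download
    return cache_dir


downloads = []


def fake_fetch(target):
    """Record that a download would have happened."""
    downloads.append(target)


with tempfile.TemporaryDirectory() as root:
    repo = os.path.join(root, "owner_repo_master")
    ensure_cached(repo, fake_fetch)                      # first call downloads
    ensure_cached(repo, fake_fetch)                      # cache hit, no download
    ensure_cached(repo, fake_fetch, force_reload=True)   # forced fresh download

print(len(downloads))
```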
Torch Hub works by importing the package as if it were installed. There are some side effects
introduced by importing in Python. For example, you can see new items in Python caches such as
sys.path_importer_cache, which is normal Python behavior.
A known limitation that is worth mentioning here is that users CANNOT load two different branches of the same repo in the same Python process. It’s just like installing two packages with the same name in Python, which is not good. The cache might join the party and give you surprises if you actually try that. Of course it’s totally fine to load them in separate processes.
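The underlying Python behaviour can be reproduced without Torch Hub. In the sketch below, two temporary directories play the role of two branches that ship a module with the same name. The module name hub_demo_pkg and the directory roles are assumptions for the demonstration, which illustrates generic import caching rather than Torch Hub internals.

```python
import importlib
import os
import sys
import tempfile


def write_module(directory, version):
    """Create hub_demo_pkg.py in `directory` with a VERSION constant."""
    with open(os.path.join(directory, "hub_demo_pkg.py"), "w") as f:
        f.write(f"VERSION = {version!r}\n")


with tempfile.TemporaryDirectory() as branch_a, tempfile.TemporaryDirectory() as branch_b:
    write_module(branch_a, "branch-a")
    write_module(branch_b, "branch-b")

    sys.path.insert(0, branch_a)
    importlib.invalidate_caches()
    first = importlib.import_module("hub_demo_pkg")
    sys.path.remove(branch_a)

    sys.path.insert(0, branch_b)
    importlib.invalidate_caches()
    second = importlib.import_module("hub_demo_pkg")   # served from sys.modules
    sys.path.remove(branch_b)

    cached_in_sys_modules = "hub_demo_pkg" in sys.modules
    first_version, second_version = first.VERSION, second.VERSION

# Clean up the demo module so later imports are unaffected.
sys.modules.pop("hub_demo_pkg", None)
print(first_version, second_version)
```

The second import returns the module already cached under that name, so the copy in the second directory is never executed. Running each load in its own process avoids the clash.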