
implement_for

class torchrl._utils.implement_for(module_name: Union[str, Callable], from_version: Optional[str] = None, to_version: Optional[str] = None, *, class_method: bool = False, compilable: bool = False)

A version decorator that checks which version of a module is installed in the environment and dispatches to the implementation registered for that version.

If the specified module is missing, or no registered implementation matches its version, calling the decorated function raises an explicit error. When version ranges overlap, the last matching implementation is used.
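To make the selection rule concrete, here is a minimal, stdlib-only sketch of the dispatch logic, assuming versions are already parsed into integer tuples and ranges are half-open. The names `register` and `dispatch` are hypothetical and not part of TorchRL; the sketch only mirrors the two rules stated above: a missing module or an unmatched version raises, and among overlapping ranges the last registered match wins.

```python
from typing import Callable, List, Optional, Tuple

Version = Tuple[int, ...]

# Hypothetical registry: (from_version, to_version, implementation), in registration order.
_registry: List[Tuple[Optional[Version], Optional[Version], Callable]] = []


def register(from_v: Optional[Version], to_v: Optional[Version]) -> Callable:
    """Hypothetical decorator that records an implementation for [from_v, to_v)."""
    def decorator(fn: Callable) -> Callable:
        _registry.append((from_v, to_v, fn))
        return fn
    return decorator


def dispatch(installed: Optional[Version]) -> Callable:
    """Return the last registered implementation whose range contains `installed`."""
    if installed is None:
        # Module missing: fail loudly instead of silently picking an implementation.
        raise ModuleNotFoundError("required module is not installed")
    chosen = None
    for from_v, to_v, fn in _registry:
        lower_ok = from_v is None or installed >= from_v  # lower bound is inclusive
        upper_ok = to_v is None or installed < to_v       # upper bound is exclusive
        if lower_ok and upper_ok:
            chosen = fn  # a later match overwrites an earlier one
    if chosen is None:
        raise RuntimeError(f"no implementation supports version {installed}")
    return chosen


@register((0, 13), (0, 20))
def fun_a(x):
    return x + 1


@register((0, 14), None)
def fun_b(x):
    return x + 2
```

Here `dispatch((0, 15))` falls in both ranges and returns `fun_b` because it was registered last, whereas `dispatch((0, 13))` only matches `fun_a`.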

The same wrapper can also provide different backends for the same function (e.g. gym vs. gymnasium, numpy vs. jax-numpy).

Parameters:
  • module_name (str or callable) – the version is checked for the module with this name (e.g. “gym”). If a callable is provided, it should return the module.

  • from_version (str, optional) – first version (inclusive) with which the implementation is compatible. Can be left open (None).

  • to_version (str, optional) – first version (exclusive) with which the implementation is no longer compatible. Can be left open (None).
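Read together, the two bounds form a half-open interval: `from_version` is the first supported version and `to_version` the first unsupported one. A hedged sketch of that check follows. `parse_version` and `in_range` are hypothetical helpers, and the parser keeps only the numeric release segment (dropping suffixes such as `a1`), which is a simplification rather than full PEP 440 ordering.

```python
import re
from typing import Optional, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn '0.21.0' or '1.0.0a1' into a comparable tuple of release numbers."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"cannot parse version {version!r}")
    parts = [int(p) for p in match.group(0).split(".")]
    # Drop trailing zeros so that "0.13" and "0.13.0" compare equal.
    while len(parts) > 1 and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


def in_range(version: str, from_version: Optional[str], to_version: Optional[str]) -> bool:
    """True if from_version <= version < to_version; None leaves that side open."""
    v = parse_version(version)
    if from_version is not None and v < parse_version(from_version):
        return False
    if to_version is not None and v >= parse_version(to_version):
        return False
    return True
```

With these helpers, gym 0.13.2 falls in `("0.13", "0.14")`, while gym 0.14.0 does not, because the upper bound is exclusive.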

Keyword Arguments:
  • class_method (bool, optional) – if True, the function is registered as a class method. Defaults to False.

  • compilable (bool, optional) – if False, the module is imported only on the first call to the wrapped function. If True, the module is imported when the wrapped function is created, which lets the wrapped function work well with torch.compile. Defaults to False.
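The difference between `compilable=False` and `compilable=True` amounts to lazy versus eager import. The sketch below uses a hypothetical decorator, `with_module` (not a TorchRL API), to illustrate only that timing difference; the `cache` attribute exists purely so the behaviour can be inspected. That importing up front helps `torch.compile` because the call path no longer triggers an import is an assumption here, not something this sketch demonstrates.

```python
import functools
import importlib
from typing import Any, Callable, Dict


def with_module(module_name: str, compilable: bool = False) -> Callable:
    """Hypothetical decorator that passes the imported module as the first argument.

    compilable=True imports at decoration time; compilable=False defers the
    import to the first call of the wrapped function.
    """

    def decorator(fn: Callable) -> Callable:
        cache: Dict[str, Any] = {}
        if compilable:
            # Eager: import now, while the function is being decorated.
            cache["module"] = importlib.import_module(module_name)

        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if "module" not in cache:
                # Lazy: the first call pays the import cost.
                cache["module"] = importlib.import_module(module_name)
            return fn(cache["module"], *args, **kwargs)

        wrapper.cache = cache  # exposed only so the import timing can be inspected
        return wrapper

    return decorator


@with_module("json")  # lazy, analogous to compilable=False
def dumps_lazy(json_mod, obj):
    return json_mod.dumps(obj)


@with_module("json", compilable=True)  # eager, analogous to compilable=True
def dumps_eager(json_mod, obj):
    return json_mod.dumps(obj)
```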

Examples

>>> from importlib import import_module
>>> from torchrl._utils import implement_for
>>> @implement_for("gym", "0.13", "0.14")
... def fun(self, x):
...     # Older gym versions will return x + 1
...     return x + 1
...
>>> @implement_for("gym", "0.14", "0.23")
... def fun(self, x):
...     # More recent gym versions will return x + 2
...     return x + 2
...
>>> @implement_for(lambda: import_module("gym"), "0.23", None)
... def fun(self, x):
...     # gym 0.23 and later also return x + 2
...     return x + 2
...
>>> @implement_for("gymnasium", None, "1.0.0")
... def fun(self, x):
...     # If gymnasium is used instead of gym, x + 3 is returned
...     return x + 3
...

The first decorator, for example, indicates that the implementation is compatible with gym 0.13 and later, but not with gym 0.14 and later.

static get_class_that_defined_method(f)

Returns the class in which the method f is defined, or None if there is no such class.
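One way to recover the defining class, sketched here as an assumption rather than as TorchRL's code, is to split the function's `__qualname__` and look up the leading names in the function's `__globals__`. `get_defining_class` is a hypothetical name. Because a class name is only bound once its class body has finished executing, a lookup like this returns `None` when run from a decorator inside that same class body.

```python
from typing import Callable, Optional


def get_defining_class(f: Callable) -> Optional[type]:
    """Sketch: resolve the owning class from f.__qualname__ and f.__globals__.

    Works for methods of classes defined at module level; returns None for plain
    functions, functions defined inside other functions, or names not yet bound.
    """
    parts = f.__qualname__.split(".")[:-1]  # drop the function's own name
    if not parts or "<locals>" in parts:
        return None
    owner = f.__globals__.get(parts[0])
    for name in parts[1:]:  # walk nested classes, e.g. Outer.Inner.method
        owner = getattr(owner, name, None)
    return owner if isinstance(owner, type) else None


class Outer:
    def describe(self):
        return "outer"

    class Inner:
        def method(self):
            return "inner"


def plain_function():
    return "plain"
```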

classmethod import_module(module_name: Union[Callable, str]) → str

Imports the module and returns its version.
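A minimal standard-library sketch of the same idea: `module_version` is a hypothetical helper that accepts either a module name or a callable returning the module (mirroring the `module_name` parameter above) and reads the conventional `__version__` attribute, so modules without one raise `AttributeError`. The demo registers a throwaway module object so no third-party package is needed.

```python
import importlib
import sys
from types import ModuleType
from typing import Callable, Union


def module_version(module_name: Union[str, Callable[[], ModuleType]]) -> str:
    """Sketch: import a module by name (or via a callable) and return its __version__."""
    if callable(module_name):
        module = module_name()  # the callable is expected to return the module
    else:
        module = importlib.import_module(module_name)  # reuses sys.modules if already loaded
    return module.__version__


# Demo: a fake module placed in sys.modules, so import_module finds it by name.
fake_pkg = ModuleType("fake_versioned_pkg")
fake_pkg.__version__ = "1.2.3"
sys.modules["fake_versioned_pkg"] = fake_pkg
```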

module_set()

Sets the function in its module (or class), if that module or class already exists.

classmethod reset(setters_dict: Optional[Dict[str, implement_for]] = None)

Resets the setters in setters_dict.

setters_dict is a copy of the registered implementations; resetting iterates through its values and calls module_set() on each.
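The interplay between `module_set()` and `reset()` can be sketched with a hypothetical `Setter` class and a module-level `_implementations` registry (both illustrative names, not TorchRL internals). The sketch assumes, as a guess prompted by the `None` default in the signature, that calling `reset()` without arguments falls back to a copy of the full registry.

```python
import copy
import types
from typing import Callable, Dict, Optional


class Setter:
    """Hypothetical stand-in for one registered implementation."""

    def __init__(self, owner: Optional[object], name: str, fn: Callable) -> None:
        self.owner = owner  # module or class that should expose `fn`; None if not defined yet
        self.name = name
        self.fn = fn

    def module_set(self) -> None:
        # Install the function only if its owner already exists.
        if self.owner is not None:
            setattr(self.owner, self.name, self.fn)


# Hypothetical registry of implementations, keyed by a qualified function name.
_implementations: Dict[str, Setter] = {}


def reset(setters_dict: Optional[Dict[str, Setter]] = None) -> None:
    """Call module_set() on each stored setter (defaults to a copy of the registry)."""
    if setters_dict is None:
        setters_dict = copy.copy(_implementations)
    for setter in setters_dict.values():
        setter.module_set()


# Demo: register an implementation, let something overwrite it, then restore it.
fake_module = types.ModuleType("fake_module")
_implementations["fake_module.fun"] = Setter(fake_module, "fun", lambda x: x + 2)
fake_module.fun = lambda x: -x  # e.g. a test monkeypatching the function
reset()
```

After `reset()`, `fake_module.fun` is the registered `x + 2` implementation again.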
