implement_for
- class torchrl._utils.implement_for(module_name: Union[str, Callable], from_version: Optional[str] = None, to_version: Optional[str] = None, *, class_method: bool = False)
A version decorator that checks the version of a module in the current environment and dispatches the decorated function to the implementation that fits it.
If the specified module is missing or no implementation fits its version, calling the decorated function raises an explicit error. When version ranges overlap, the last fitting implementation is used.
The decorator can also provide different backends for the same function (e.g. gym vs gymnasium, numpy vs jax.numpy).
- Parameters:
module_name (str or callable) – name of the module whose version is checked (e.g. “gym”). If a callable is provided, it should return the module.
from_version (str, optional) – first version with which the implementation is compatible. Can be left open (None).
to_version (str, optional) – first version with which the implementation is no longer compatible. Can be left open (None).
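Read literally, the two bounds describe a half-open interval: from_version is the first compatible version and to_version the first incompatible one. Writing f for from_version, t for to_version and v for the installed version, an implementation fits when the condition below holds, with an open (None) bound imposing no constraint. For instance, with f = 0.13 and t = 0.14, gym 0.13.x fits while gym 0.14 does not.

```latex
\mathrm{fits}(v) \iff \bigl(f = \mathrm{None} \;\lor\; v \ge f\bigr) \;\land\; \bigl(t = \mathrm{None} \;\lor\; v < t\bigr)
```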
- Keyword Arguments:
class_method (bool, optional) – if True, the function will be written as a class method. Defaults to False.
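The class_method flag can be pictured with the following sketch, which isolates only that flag. `versioned` and `GymLikeEnv` are hypothetical names, and the wrapper does no version dispatch at all; it simply assumes that "written as a class method" means the returned callable is wrapped in Python's built-in `classmethod`, so that it receives the class as its first argument.

```python
import functools


def versioned(class_method: bool = False):
    """Hypothetical decorator factory illustrating only the class_method flag."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # A real version dispatcher would select an implementation here.
            return fn(*args, **kwargs)

        # With class_method=True the wrapper is exposed as a classmethod,
        # so its first argument is the class rather than an instance.
        return classmethod(wrapper) if class_method else wrapper

    return decorator


class GymLikeEnv:
    backend = "gym"

    @versioned(class_method=True)
    def describe(cls):
        return f"{cls.__name__} uses {cls.backend}"

    @versioned()
    def step(self, x):
        return x + 1


print(GymLikeEnv.describe())  # called on the class, no instance needed
print(GymLikeEnv().step(1))
```

Because the class is passed in, a subclass that overrides backend gets its own value when calling describe().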
Examples
>>> from importlib import import_module
>>> from torchrl._utils import implement_for
>>> @implement_for("gym", "0.13", "0.14")
... def fun(self, x):
...     # Older gym versions will return x + 1
...     return x + 1
...
>>> @implement_for("gym", "0.14", "0.23")
... def fun(self, x):
...     # More recent gym versions will return x + 2
...     return x + 2
...
>>> @implement_for(lambda: import_module("gym"), "0.23", None)
... def fun(self, x):
...     # gym 0.23 and later (module resolved through a callable) also return x + 2
...     return x + 2
...
>>> @implement_for("gymnasium")
... def fun(self, x):
...     # If gymnasium is to be used instead of gym, x + 3 will be returned
...     return x + 3
...
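Here, the first implementation is compatible with gym 0.13 and later, but not with gym 0.14 and later. The remaining implementations cover gym 0.14 up to (but excluding) 0.23, gym 0.23 onwards, and gymnasium of any version.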
- static get_class_that_defined_method(f)
Returns the class that defines the method f, if there is one, and None otherwise.
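One plausible way to perform this lookup is sketched below; torchrl's own implementation may differ. The sketch reads the first component of the function's `__qualname__` (e.g. "Env" in "Env.step") and looks that name up in the function's module globals, so it assumes the owning class is defined at module level in the same module. `Env` and `helper` are illustrative names.

```python
from typing import Callable, Optional


def get_class_that_defined_method(f: Callable) -> Optional[type]:
    """Sketch: resolve the owning class from a function's qualified name."""
    owner_name = f.__qualname__.split(".")[0]
    candidate = f.__globals__.get(owner_name)
    # For a plain module-level function, the first qualname component is the
    # function's own name, which refers to the function, not to a class.
    return candidate if isinstance(candidate, type) else None


class Env:
    def step(self, x):
        return x


def helper(x):
    return x


print(get_class_that_defined_method(Env.step))
print(get_class_that_defined_method(helper))  # None
```

Classes nested inside functions (qualnames containing "<locals>") are not found by this approach and yield None.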
- classmethod import_module(module_name: Union[Callable, str]) → str
Imports the module and returns its version.
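A minimal sketch of this behaviour, assuming the version is read from the module's `__version__` attribute: a string is passed to the standard library's `importlib.import_module`, while a callable (such as the `lambda: import_module("gym")` used in the Examples) is simply called. `import_module_version` and the stand-in module `fake_gym` are hypothetical, and the sketch omits any module caching torchrl may perform.

```python
import importlib
import sys
import types
from typing import Callable, Union


def import_module_version(module_name: Union[str, Callable[[], types.ModuleType]]) -> str:
    """Sketch: import a module (by name or via a callable) and return its __version__."""
    if callable(module_name):
        # e.g. lambda: importlib.import_module("gym") returns the module directly
        module = module_name()
    else:
        module = importlib.import_module(module_name)
    return module.__version__


# Hypothetical stand-in module so the example runs without gym installed.
fake_gym = types.ModuleType("fake_gym")
fake_gym.__version__ = "0.21.0"
sys.modules["fake_gym"] = fake_gym  # import_module returns entries already in sys.modules

print(import_module_version("fake_gym"))
print(import_module_version(lambda: importlib.import_module("fake_gym")))
```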
- classmethod reset(setters_dict: Optional[Dict[str, implement_for]] = None)
Resets the setters in setters_dict. setters_dict is a copy of the implementations; the method iterates through its values and calls module_set() for each.
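The reset loop can be sketched as follows. Everything except the module_set() call is hypothetical: `Setter` stands in for a registered implementation, `Registry.implementations` mirrors the "implementations" mentioned above, and the sketch assumes that passing None falls back to a copy of all registered implementations. Re-running module_set() reinstalls each chosen function, which is useful when something (for example a test) has patched it in the meantime.

```python
from typing import Callable, Dict, Optional


class Setter:
    """Hypothetical record of one chosen implementation and where to install it."""

    def __init__(self, namespace: type, name: str, fn: Callable):
        self.namespace = namespace
        self.name = name
        self.fn = fn

    def module_set(self) -> None:
        # (Re)install the chosen implementation under its public name.
        setattr(self.namespace, self.name, self.fn)


class Registry:
    implementations: Dict[str, Setter] = {}

    @classmethod
    def reset(cls, setters_dict: Optional[Dict[str, Setter]] = None) -> None:
        # Iterate over a copy so module_set() cannot invalidate the loop
        # even if it modifies the registry.
        if setters_dict is None:
            setters_dict = dict(cls.implementations)
        for setter in setters_dict.values():
            setter.module_set()


class Env:
    def step(self, x):
        return x + 2


Registry.implementations["Env.step"] = Setter(Env, "step", Env.step)
Env.step = lambda self, x: -1  # e.g. a test monkeypatches the method
Registry.reset()
print(Env().step(1))  # 3: the registered implementation is back
```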