def set_rng_state(new_state: torch.Tensor) -> None:
    r"""Sets the random number generator state.

    .. note:: This function only works for CPU. For CUDA, please use
        :func:`torch.manual_seed`, which works for both CPU and CUDA.

    Args:
        new_state (torch.ByteTensor): The desired state
    """
    default_generator.set_state(new_state)
def get_rng_state() -> torch.Tensor:
    r"""Returns the random number generator state as a `torch.ByteTensor`.

    .. note:: The returned state is for the default generator on CPU only.

    See also: :func:`torch.random.fork_rng`.
    """
    return default_generator.get_state()
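# Hedged usage sketch (not part of the module above): get_rng_state and
# set_rng_state together let you snapshot and rewind the default CPU
# generator, so the exact same draws can be replayed.
def _example_rng_state_roundtrip():
    import torch

    state = torch.get_rng_state()  # snapshot the default CPU generator
    a = torch.rand(3)              # advances the generator
    torch.set_rng_state(state)     # rewind to the snapshot
    b = torch.rand(3)              # replays the exact same draws
    assert torch.equal(a, b)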
def manual_seed(seed) -> torch._C.Generator:
    r"""Sets the seed for generating random numbers on all devices. Returns a
    `torch.Generator` object.

    Args:
        seed (int): The desired seed. Value must be within the inclusive range
            `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a
            RuntimeError is raised. Negative inputs are remapped to positive
            values with the formula `0xffff_ffff_ffff_ffff + seed`.
    """
    seed = int(seed)
    import torch.cuda

    if not torch.cuda._is_in_bad_fork():
        torch.cuda.manual_seed_all(seed)

    import torch.mps

    if not torch.mps._is_in_bad_fork():
        torch.mps.manual_seed(seed)

    import torch.xpu

    if not torch.xpu._is_in_bad_fork():
        torch.xpu.manual_seed_all(seed)

    _seed_custom_device(seed)

    return default_generator.manual_seed(seed)
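# Illustrative sketch: re-seeding with the same value reproduces the same
# stream of random numbers on every device the seed was applied to.
def _example_manual_seed_reproducibility():
    import torch

    torch.manual_seed(42)
    x = torch.rand(2)
    torch.manual_seed(42)
    y = torch.rand(2)
    assert torch.equal(x, y)  # identical draws after identical seeding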
def seed() -> int:
    r"""Sets the seed for generating random numbers to a non-deterministic
    random number on all devices. Returns a 64 bit number used to seed the RNG.
    """
    seed = default_generator.seed()
    import torch.cuda

    if not torch.cuda._is_in_bad_fork():
        torch.cuda.manual_seed_all(seed)

    import torch.mps

    if not torch.mps._is_in_bad_fork():
        torch.mps.manual_seed(seed)

    import torch.xpu

    if not torch.xpu._is_in_bad_fork():
        torch.xpu.manual_seed_all(seed)

    _seed_custom_device(seed)

    return seed
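# Sketch: seed() picks a fresh non-deterministic seed but returns it, so a
# run can still be reproduced later by feeding that value to manual_seed.
def _example_record_and_replay_seed():
    import torch

    s = torch.seed()      # non-deterministic seed, applied to all devices
    a = torch.rand(2)
    torch.manual_seed(s)  # replay the run from the recorded seed
    b = torch.rand(2)
    assert torch.equal(a, b)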
def _seed_custom_device(seed) -> None:
    r"""Sets the seed to generate random numbers for a custom device.

    Args:
        seed (int): The desired seed.

    See [Note: support the custom device with privateuse1]
    """
    seed = int(seed)
    custom_backend_name = torch._C._get_privateuse1_backend_name()
    if hasattr(torch, custom_backend_name):
        custom_device_mod = getattr(torch, custom_backend_name)
        _bad_fork_name = "_is_in_bad_fork"
        _seed_all_name = "manual_seed_all"
        if hasattr(custom_device_mod, _bad_fork_name) and hasattr(
            custom_device_mod, _seed_all_name
        ):
            if not getattr(custom_device_mod, _bad_fork_name)():
                getattr(custom_device_mod, _seed_all_name)(seed)
        else:
            message = (
                f"Seeding the `{custom_backend_name}` device has no effect; please add the APIs "
                f"`{_bad_fork_name}` and `{_seed_all_name}` to the `{custom_backend_name}` device module."
            )
            warnings.warn(message, UserWarning, stacklevel=3)
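# Hedged sketch of the duck-typed protocol _seed_custom_device relies on: a
# privateuse1 backend module only needs `_is_in_bad_fork` and
# `manual_seed_all` for the global seeding functions to reach it. The
# `my_backend` module name and its two stubs here are hypothetical
# placeholders, not a real backend.
def _example_custom_backend_seeding():
    import types

    import torch

    my_backend = types.ModuleType("my_backend")
    my_backend._is_in_bad_fork = lambda: False      # hypothetical stub
    my_backend.manual_seed_all = lambda seed: None  # hypothetical no-op
    torch.utils.rename_privateuse1_backend("my_backend")
    torch._register_device_module("my_backend", my_backend)
    # torch.manual_seed(0) would now call my_backend.manual_seed_all(0).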
def initial_seed() -> int:
    r"""Returns the initial seed for generating random numbers as a
    Python `int`.

    .. note:: The returned seed is for the default generator on CPU only.
    """
    return default_generator.initial_seed()
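# Sketch: initial_seed() reports whatever value last seeded the default CPU
# generator, whether it came from manual_seed or a non-deterministic seed().
def _example_initial_seed():
    import torch

    torch.manual_seed(1234)
    assert torch.initial_seed() == 1234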
_fork_rng_warned_already = False
@contextlib.contextmanager
def fork_rng(
    devices=None,
    enabled=True,
    _caller="fork_rng",
    _devices_kw="devices",
    device_type="cuda",
) -> Generator:
    """
    Forks the RNG, so that when you return, the RNG is reset to the state
    that it was previously in.

    Args:
        devices (iterable of Device IDs): devices for which to fork the RNG.
            CPU RNG state is always forked. By default, :meth:`fork_rng`
            operates on all devices, but will emit a warning if your machine
            has a lot of devices, since this function will run very slowly in
            that case. If you explicitly specify devices, this warning will
            be suppressed.
        enabled (bool): if ``False``, the RNG is not forked. This is a
            convenience argument for easily disabling the context manager
            without having to delete it and unindent your Python code under it.
        device_type (str): device type str, default is `cuda`. As for custom
            device, see details in [Note: support the custom device with
            privateuse1]
    """
    if device_type == "meta":
        yield
        return

    device_type = torch.device(device_type).type
    device_mod = getattr(torch, device_type, None)
    if device_mod is None:
        raise RuntimeError(
            f"torch has no module of `{device_type}`, you should register "
            "a module by `torch._register_device_module`."
        )
    global _fork_rng_warned_already

    # Internal arguments:
    #   _caller: the function which called fork_rng, which the user used
    #   _devices_kw: the devices keyword of _caller

    if not enabled:
        yield
        return

    if devices is None:
        num_devices = device_mod.device_count()
        if num_devices > 1 and not _fork_rng_warned_already:
            message = (
                f"{device_type.upper()} reports that you have {num_devices} available devices, and "
                f"you have used {_caller} without explicitly specifying which devices are being used. "
                f"For safety, we initialize *every* {device_type.upper()} device by default, which can "
                f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only"
                f" making use of a few {device_type.upper()} devices, set the environment variable "
                f"{device_type.upper()}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} "
                "with the set of devices you are actually using. For example, if you are using CPU only, "
                f"set {device_type.upper()}_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
                f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices "
                f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
                f"`range(torch.{device_type}.device_count())`."
            )
            warnings.warn(message)
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against the user passing us a generator; we need to
        # traverse this multiple times, but a generator will be exhausted
        # upon its first traversal.
        devices = list(devices)

    cpu_rng_state = torch.get_rng_state()
    device_rng_states = [device_mod.get_rng_state(device) for device in devices]

    try:
        yield
    finally:
        torch.set_rng_state(cpu_rng_state)
        for device, device_rng_state in zip(devices, device_rng_states):
            device_mod.set_rng_state(device_rng_state, device)
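# Usage sketch: draws made inside a fork_rng block do not disturb the RNG
# stream outside it, so the two sequences below match. Passing devices=[]
# skips capturing any accelerator state and forks only the CPU generator.
def _example_fork_rng():
    import torch

    torch.manual_seed(0)
    expected = torch.rand(3)

    torch.manual_seed(0)
    with torch.random.fork_rng(devices=[]):
        _ = torch.rand(100)  # perturbs the RNG only inside the fork
    actual = torch.rand(3)   # resumes the outer stream untouched

    assert torch.equal(expected, actual)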