# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import random
from typing import Optional, Union

import numpy as np
import torch

from torchtune.training._distributed import _broadcast_tensor, get_world_size_and_rank
from torchtune.utils import get_logger

# Module-level logger obtained from torchtune's logging helper.
_log: logging.Logger = get_logger()
def set_seed(
    seed: Optional[int] = None, debug_mode: Optional[Union[str, int]] = None
) -> int:
    """Seed the PyTorch, NumPy and Python ``random`` generators for this process.

    Each process seeds its generators with ``seed + rank``, so the processes
    of a distributed job use distinct seeds derived from one shared base seed.
    For more details, see https://pytorch.org/docs/stable/notes/randomness.html.

    Args:
        seed (Optional[int]): the base integer seed. If `None`, a seed is drawn
            with ``torch.randint`` and synced across ranks via
            ``_broadcast_tensor`` before use.
        debug_mode (Optional[Union[str, int]]): value forwarded to
            :func:`torch.set_deterministic_debug_mode`.

            * If `None`, no PyTorch global values are touched.
            * If "default" or 0, don't error or warn on nondeterministic
              operations and additionally enable PyTorch CuDNN benchmark.
            * If "warn" or 1, warn on nondeterministic operations and disable
              PyTorch CuDNN benchmark.
            * If "error" or 2, error on nondeterministic operations and disable
              PyTorch CuDNN benchmark.
            * For more details, see
              https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms.

            Whenever the resulting debug mode is not 0, the
            ``CUBLAS_WORKSPACE_CONFIG`` environment variable is also set to
            ``":4096:8"``.

    Returns:
        int: the base seed (the value before the rank offset is added).

    Raises:
        ValueError: If the seed is below the uint32 minimum or above
            ``uint32 max - world_size + 1``.
    """
    world_size, rank = get_world_size_and_rank()

    uint32_limits = np.iinfo(np.uint32)
    lowest = uint32_limits.min
    # Reserve world_size - 1 values at the top of the uint32 range for the
    # rank offset that is added to the seed below.
    highest = uint32_limits.max - world_size + 1

    if seed is None:
        # sync seed across ranks
        seed = _broadcast_tensor(torch.randint(lowest, highest, (1,)), 0).item()

    if seed < lowest or seed > highest:
        raise ValueError(
            f"Invalid seed value provided: {seed}. Value must be in the range [{lowest}, {highest}]"
        )

    rank_seed = seed + rank
    if rank == 0:
        _log.debug(
            f"Setting manual seed to local seed {rank_seed}. "
            f"Local seed is seed + rank = {seed} + {rank}"
        )

    # Same seeding order as before: PyTorch, then NumPy, then python.random.
    for apply_seed in (torch.manual_seed, np.random.seed, random.seed):
        apply_seed(rank_seed)

    # Without a debug mode there are no global PyTorch settings to change.
    if debug_mode is None:
        return seed

    _log.debug(f"Setting deterministic debug mode to {debug_mode}")
    torch.set_deterministic_debug_mode(debug_mode)

    # Read the mode back as an integer; 0 means nondeterministic ops are allowed.
    allow_nondeterminism = torch.get_deterministic_debug_mode() == 0
    _log.debug(
        "Disabling cuDNN deterministic mode"
        if allow_nondeterminism
        else "Enabling cuDNN deterministic mode"
    )
    torch.backends.cudnn.deterministic = not allow_nondeterminism
    torch.backends.cudnn.benchmark = allow_nondeterminism
    if not allow_nondeterminism:
        # reference: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    return seed
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.