# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import Optional

import torch


def _get_local_rank() -> Optional[int]:
    """Function that gets the local rank from the environment.

    Reads the ``LOCAL_RANK`` environment variable (set by distributed
    launchers such as torchrun).

    Returns:
        local_rank int or None if not set.
    """
    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is not None:
        local_rank = int(local_rank)
    return local_rank


def _setup_cuda_device(device: torch.device) -> torch.device:
    """Function that sets the CUDA device and infers the cuda index if not set.

    Args:
        device (torch.device): The device to set.

    Raises:
        RuntimeError: If device index is not available.

    Returns:
        device
    """
    # Fall back to index 0 when LOCAL_RANK is unset (single-process runs).
    local_rank = _get_local_rank() or 0
    if device.index is None:
        device = torch.device(type="cuda", index=local_rank)

    # Ensure index is available before setting device
    if device.index >= torch.cuda.device_count():
        raise RuntimeError(
            "The local rank is larger than the number of available GPUs."
        )

    torch.cuda.set_device(device)
    return device


def _get_device_type_from_env() -> str:
    """Function that gets the torch.device based on the current machine.

    This currently only supports CPU, CUDA.

    Returns:
        device
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    return device


def _validate_device_from_env(device: torch.device) -> None:
    """Function that validates the device is correct given the current machine.

    This will raise an error if the device is not available or doesn't match the
    assigned process device on distributed runs.

    Args:
        device (torch.device): The device to validate.

    Raises:
        RuntimeError: If the device is not available or doesn't match the
            assigned process device.

    Returns:
        None
    """
    local_rank = _get_local_rank()

    # Check if the device index is correct
    if device.type == "cuda" and local_rank is not None:
        # Ensure device index matches assigned index when distributed training
        if device.index != local_rank:
            # Implicit concatenation instead of a backslash continuation so no
            # stray indentation whitespace leaks into the error message.
            raise RuntimeError(
                "You can't specify a device index when using distributed training. "
                f"Device specified is {device} but was assigned cuda:{local_rank}"
            )

    # Check if the device is available on this machine
    try:
        # Allocating a zero-element tensor is a cheap availability probe.
        torch.empty(0, device=device)
    except RuntimeError as e:
        raise RuntimeError(
            f"The device {device} is not available on this machine."
        ) from e
def get_device(device: Optional[str] = None) -> torch.device:
    """Function that takes an optional device string, verifies it's correct and
    available given the machine and distributed settings, and returns a
    torch.device. If device string is not provided, this function will infer
    the device based on the environment.

    If CUDA is available and being used, this function also sets the CUDA
    device.

    Args:
        device (Optional[str]): The name of the device to use.

    Returns:
        torch.device: device.
    """
    # Infer a device type from the machine when the caller gave none.
    resolved = torch.device(
        device if device is not None else _get_device_type_from_env()
    )
    # CUDA devices additionally need an index bound and set for this process.
    if resolved.type == "cuda":
        resolved = _setup_cuda_device(resolved)
    _validate_device_from_env(resolved)
    return resolved
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.