from typing import Iterable, Optional

import torch


def parameters_to_vector(parameters: Iterable[torch.Tensor]) -> torch.Tensor:
    r"""Flatten an iterable of parameters into a single vector.

    Args:
        parameters (Iterable[Tensor]): an iterable of Tensors that are the
            parameters of a model.

    Returns:
        The parameters represented by a single vector.
    """
    # Flag for the device where the parameters are located
    param_device = None

    vec = []
    for param in parameters:
        # Ensure the parameters are located on the same device
        param_device = _check_param_device(param, param_device)

        vec.append(param.view(-1))
    return torch.cat(vec)
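Example (a minimal usage sketch, not part of the module above; ``nn.Linear(3, 2)`` is just an arbitrary toy model)::

    import torch
    import torch.nn as nn
    from torch.nn.utils import parameters_to_vector

    model = nn.Linear(3, 2)  # 2*3 weights + 2 biases = 8 scalars in total
    vec = parameters_to_vector(model.parameters())
    print(vec.shape)  # torch.Size([8])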
def vector_to_parameters(vec: torch.Tensor, parameters: Iterable[torch.Tensor]) -> None:
    r"""Copy slices of a vector into an iterable of parameters.

    Args:
        vec (Tensor): a single vector representing the parameters of a model.
        parameters (Iterable[Tensor]): an iterable of Tensors that are the
            parameters of a model.
    """
    # Ensure vec is of type Tensor
    if not isinstance(vec, torch.Tensor):
        raise TypeError(f'expected torch.Tensor, but got: {torch.typename(vec)}')
    # Flag for the device where the parameters are located
    param_device = None

    # Pointer for slicing the vector for each parameter
    pointer = 0
    for param in parameters:
        # Ensure the parameters are located on the same device
        param_device = _check_param_device(param, param_device)

        # The length of the parameter
        num_param = param.numel()
        # Slice the vector, reshape it, and replace the old data of the parameter
        param.data = vec[pointer:pointer + num_param].view_as(param).data

        # Increment the pointer
        pointer += num_param
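Example (a round-trip sketch on the same toy ``nn.Linear`` model as above; the ``0.1`` perturbation is arbitrary)::

    import torch
    import torch.nn as nn
    from torch.nn.utils import parameters_to_vector, vector_to_parameters

    model = nn.Linear(3, 2)
    # Detach so the in-place edit below does not interact with autograd.
    vec = parameters_to_vector(model.parameters()).detach()
    vec += 0.1  # e.g. a gradient-free perturbation of all parameters at once
    vector_to_parameters(vec, model.parameters())

    # The round trip reproduces the modified vector exactly.
    assert torch.equal(parameters_to_vector(model.parameters()), vec)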
def _check_param_device(param: torch.Tensor, old_param_device: Optional[int]) -> int:
    r"""Check that the parameters are located on the same device.

    Currently, the conversion between model parameters and a single vector
    is not supported for multiple allocations, e.g. parameters on different
    GPUs/PrivateUse1 devices, or a mixture of CPU/GPU/PrivateUse1.

    Args:
        param (Tensor): a Tensor that is a parameter of a model.
        old_param_device (int): the device where the first parameter of the
            model is allocated.

    Returns:
        old_param_device (int): the device reported for the first parameter.
    """
    # Meet the first parameter
    support_device_types = ["cuda", torch._C._get_privateuse1_backend_name()]
    if old_param_device is None:
        old_param_device = (
            param.get_device() if param.device.type in support_device_types else -1
        )
    else:
        warn = False
        if param.device.type in support_device_types:
            # Check if on the same GPU/PrivateUse1 device
            warn = param.get_device() != old_param_device
        else:
            # Check if on CPU
            warn = old_param_device != -1
        if warn:
            raise TypeError('Found two parameters on different devices, '
                            'this is currently not supported.')
    return old_param_device
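Example (a sketch of the failure mode this helper guards against, triggered through the public API; it assumes a CUDA device is available and does nothing otherwise)::

    import torch
    import torch.nn as nn
    from torch.nn.utils import parameters_to_vector

    if torch.cuda.is_available():
        cpu_param = nn.Parameter(torch.zeros(2))                 # device index -1 (CPU)
        gpu_param = nn.Parameter(torch.zeros(2, device="cuda"))  # device index 0
        try:
            parameters_to_vector([cpu_param, gpu_param])
        except TypeError as err:
            print(err)  # Found two parameters on different devices, ...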