Shortcuts

Source code for torch.export.passes

from typing import Dict, Union

import torch
import torch.utils._pytree as pytree
from torch.export.exported_program import ExportedProgram


__all__ = ["move_to_device_pass"]


[docs]def move_to_device_pass( ep: ExportedProgram, location: Union[torch.device, str, Dict[str, str]] ) -> ExportedProgram: """ Move the exported program to the given device. Args: ep (ExportedProgram): The exported program to move. location (Union[torch.device, str, Dict[str, str]]): The device to move the exported program to. If a string, it is interpreted as a device name. If a dict, it is interpreted as a mapping from the existing device to the intended one Returns: ExportedProgram: The moved exported program. """ def _get_new_device( curr_device: torch.device, location: Union[torch.device, str, Dict[str, str]], ) -> str: if isinstance(location, dict): if str(curr_device) in location.keys(): return location[str(curr_device)] else: return str(curr_device) else: return str(location) # move all the state_dict for k, v in ep.state_dict.items(): if isinstance(v, torch.nn.Parameter): ep._state_dict[k] = torch.nn.Parameter( v.to(_get_new_device(v.device, location)) ) else: ep._state_dict[k] = v.to(_get_new_device(v.device, location)) # move all the constants for k, v in ep.constants.items(): if isinstance(v, torch.Tensor): ep._constants[k] = v.to(_get_new_device(v.device, location)) for node in ep.graph.nodes: # move all the nodes kwargs with burnt-in device if "device" in node.kwargs: kwargs = node.kwargs.copy() kwargs["device"] = _get_new_device(kwargs["device"], location) node.kwargs = kwargs # move all the tensor metadata node.meta["val"] = pytree.tree_map( lambda v: v.to(_get_new_device(v.device, location)) if isinstance(v, torch.Tensor) else v, node.meta.get("val"), ) ep.validate() return ep

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources