Source code for torch.export.passes
from typing import Dict, Union
import torch
import torch.utils._pytree as pytree
from torch.export.exported_program import ExportedProgram
__all__ = ["move_to_device_pass"]
[docs]def move_to_device_pass(
ep: ExportedProgram, location: Union[torch.device, str, Dict[str, str]]
) -> ExportedProgram:
"""
Move the exported program to the given device.
Args:
ep (ExportedProgram): The exported program to move.
location (Union[torch.device, str, Dict[str, str]]): The device to move the exported program to.
If a string, it is interpreted as a device name.
If a dict, it is interpreted as a mapping from
the existing device to the intended one
Returns:
ExportedProgram: The moved exported program.
"""
def _get_new_device(
curr_device: torch.device,
location: Union[torch.device, str, Dict[str, str]],
) -> str:
if isinstance(location, dict):
if str(curr_device) in location.keys():
return location[str(curr_device)]
else:
return str(curr_device)
else:
return str(location)
# move all the state_dict
for k, v in ep.state_dict.items():
if isinstance(v, torch.nn.Parameter):
ep._state_dict[k] = torch.nn.Parameter(
v.to(_get_new_device(v.device, location))
)
else:
ep._state_dict[k] = v.to(_get_new_device(v.device, location))
# move all the constants
for k, v in ep.constants.items():
if isinstance(v, torch.Tensor):
ep._constants[k] = v.to(_get_new_device(v.device, location))
for node in ep.graph.nodes:
# move all the nodes kwargs with burnt-in device
if "device" in node.kwargs:
kwargs = node.kwargs.copy()
kwargs["device"] = _get_new_device(kwargs["device"], location)
node.kwargs = kwargs
# move all the tensor metadata
node.meta["val"] = pytree.tree_map(
lambda v: v.to(_get_new_device(v.device, location))
if isinstance(v, torch.Tensor)
else v,
node.meta.get("val"),
)
ep.validate()
return ep