# Source code for torch.ao.ns.fx.utils
# mypy: allow-untyped-defs
import enum
import operator
import torch
import torch.nn as nn
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized as nnq
toq = torch.ops.quantized
from typing import Tuple, Callable, Dict, Set, List, Optional, Union
from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.ao.quantization import (
ObserverBase,
FakeQuantizeBase,
)
from torch.ao.quantization.utils import getattr_from_fqn
from torch.ao.quantization.observer import _is_activation_post_process
from .ns_types import NSNodeTargetType, NSResultsType
# TODO(future PR): consider deleting this enum and using the torch types
# directly. This might be tricky because it is not a one to one mapping.
class NodeInputOrOutputType(enum.Enum):
    """
    Coarse dtype category assigned to the first input or the output of a
    graph node. Member values come from `enum.auto()`, so they follow the
    declaration order below.
    """

    FP32 = enum.auto()  # torch.float
    INT8 = enum.auto()  # torch.qint8 or torch.quint8
    FP16 = enum.auto()  # torch.float16
    UNKNOWN = enum.auto()  # we cannot determine input/output dtype
    # TODO(future PR): while these functions can support multiple dtypes,
    # for the purposes of numerical debugging we want to get the actual
    # dtype used in the model. We will likely need some kind of dtype
    # propagation to estimate this.
    FP32_OR_INT8 = enum.auto()  # either torch.float or torch.quint8 or torch.qint8
    # TODO(future PRs): dynamic quant, fake quant, etc
def get_node_first_input_and_output_type(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Tuple[NodeInputOrOutputType, NodeInputOrOutputType]:
    """
    Returns a `(first_input_type, output_type)` pair for `node`.

    The answer comes from the lookup tables in `node_type_to_io_type_map`.
    For ops which accept more than one dtype (pass-through functions,
    modules and methods, loggers, observers, fake-quants, `dequantize`,
    `to`), the input type is taken from the output type of the node which
    produces this node's first input, found by recursing up the graph.

    Args:
        node: the node to classify
        gm: the graph module owning `node`, used to resolve module targets
        logger_cls: logger class; instances are treated as pass-through
        node_type_to_io_type_map: maps category names (for example
            "funs_io_type_fp32") to sets of functions, module types or
            method names

    Returns:
        Tuple of two `NodeInputOrOutputType` values; both are `UNKNOWN`
        when the node is not recognized.
    """
    # TODO(future PR): clean this up
    # All tables are looked up eagerly so that a missing key fails up front.
    fp32_funs = node_type_to_io_type_map["funs_io_type_fp32"]
    fp16_funs = node_type_to_io_type_map["funs_io_type_fp16"]
    int8_funs = node_type_to_io_type_map["funs_io_type_int8"]
    fp32_or_int8_funs = node_type_to_io_type_map["funs_io_type_fp32_or_int8"]
    fp32_mods = node_type_to_io_type_map["mods_io_type_fp32"]
    int8_mods = node_type_to_io_type_map["mods_io_type_int8"]
    fp32_or_int8_mods = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
    fp32_or_int8_meths = node_type_to_io_type_map["meths_io_type_fp32_or_int8"]

    unknown_pair = (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)

    def _output_type_of_first_input() -> NodeInputOrOutputType:
        # Output type of the node feeding the first input of `node`.
        producer = get_normalized_nth_input(node, gm, 0)
        assert isinstance(producer, Node)
        _producer_input_type, producer_output_type = (
            get_node_first_input_and_output_type(
                producer, gm, logger_cls, node_type_to_io_type_map
            )
        )
        return producer_output_type

    if node.op == "call_function":
        if node.target in fp32_funs:
            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
        if node.target in fp16_funs:
            return (NodeInputOrOutputType.FP16, NodeInputOrOutputType.FP16)
        if node.target in int8_funs:
            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
        if node.target in fp32_or_int8_funs:
            inherited = _output_type_of_first_input()
            return (inherited, inherited)
        return unknown_pair

    if node.op == "call_module":
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        matches_fp32_or_int8 = any(
            isinstance(mod, candidate) for candidate in fp32_or_int8_mods  # type: ignore[arg-type]
        )
        if (
            isinstance(mod, (logger_cls, ObserverBase, FakeQuantizeBase))  # type: ignore[arg-type]
            or matches_fp32_or_int8
        ):
            # A logger or observer's input and output type is the output
            # type of the preceding node.
            inherited = _output_type_of_first_input()
            return (inherited, inherited)
        matches_fp32 = any(
            isinstance(mod, candidate) for candidate in fp32_mods  # type: ignore[arg-type]
        )
        matches_int8 = any(
            isinstance(mod, candidate) for candidate in int8_mods  # type: ignore[arg-type]
        )
        if matches_fp32:
            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
        if matches_int8:
            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
        return unknown_pair

    if node.op == "call_method":
        if node.target == "dequantize":
            # Dequantize accepts multiple input types, so the input type is
            # whatever the previous node outputs; the output is always fp32.
            return (_output_type_of_first_input(), NodeInputOrOutputType.FP32)
        if node.target == "to":
            # `to` also accepts multiple input types. The input type is the
            # previous node's output type and the output type is derived
            # from the requested dtype (only float16 is handled so far).
            input_type = _output_type_of_first_input()
            cur_node_dtype_target = get_normalized_nth_input(node, gm, 1)
            assert (
                cur_node_dtype_target is torch.float16
            ), f"{cur_node_dtype_target} handling needs to be added"
            return (input_type, NodeInputOrOutputType.FP16)
        if node.target in fp32_or_int8_meths:
            inherited = _output_type_of_first_input()
            return (inherited, inherited)
        return unknown_pair

    return unknown_pair
def get_node_input_qparams(
    node: Node,
    gm: GraphModule,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Optional[Tuple[Union[torch.Tensor, float], Union[torch.Tensor, int]]]:
    """
    Looks up the quantization parameters of the first input of `node`.

    The node producing the first input is inspected:
      * `torch.quantize_per_tensor` and the quantized add/mul functions
        carry scale and zero_point as arguments which point at attributes
        of `gm`
      * a fixed list of quantized modules expose `scale` and `zero_point`
        attributes
      * modules listed under "mods_io_type_fp32_or_int8" are looked
        through by recursing on the producer

    Returns:
        `(scale, zero_point)` when they can be inferred, otherwise None.
    """
    producer = get_normalized_nth_input(node, gm, 0)
    if not isinstance(producer, Node):
        return None

    passthrough_module_types = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]

    def _read_scale_and_zero_point(fn_node, scale_pos, zp_pos):
        # scale and zero_point args are nodes whose targets name attributes of gm
        scale_arg = get_normalized_nth_input(fn_node, gm, scale_pos)
        zp_arg = get_normalized_nth_input(fn_node, gm, zp_pos)
        assert isinstance(scale_arg, Node) and isinstance(scale_arg.target, str)
        assert isinstance(zp_arg, Node) and isinstance(zp_arg.target, str)
        scale_value = getattr_from_fqn(gm, scale_arg.target)
        zp_value = getattr_from_fqn(gm, zp_arg.target)
        return (scale_value, zp_value)

    if producer.op == "call_function":
        # quantize - read the args directly
        if producer.target == torch.quantize_per_tensor:
            return _read_scale_and_zero_point(producer, 1, 2)
        if producer.target in (toq.add, toq.add_relu, toq.mul, toq.mul_relu):
            return _read_scale_and_zero_point(producer, 2, 3)
        # TODO(future PR): handle more functionals
        # TODO(future PR): handle functional ops which inherit qparams from input
        return None

    if producer.op == "call_module":
        # resolve the module instance behind the producer node
        assert isinstance(producer.target, str)
        producer_mod = getattr_from_fqn(gm, producer.target)

        modules_with_qparams = (
            nnq.Linear,
            nnq.Conv1d,
            nnq.Conv2d,
            nnq.Conv3d,
            nnq.BatchNorm2d,
            nnq.BatchNorm3d,
            nnq.ConvTranspose1d,
            nnq.ConvTranspose2d,
            nnq.ELU,
            nnq.GroupNorm,
            nnq.InstanceNorm1d,
            nnq.InstanceNorm2d,
            nnq.InstanceNorm3d,
            nnq.LayerNorm,
            nnq.Hardswish,
            nnq.LeakyReLU,
            nnq.ReLU6,
            nniq.BNReLU2d,
            nniq.BNReLU3d,
            nniq.ConvReLU1d,
            nniq.ConvReLU2d,
            nniq.ConvReLU3d,
            nniq.LinearReLU,
        )
        if isinstance(producer_mod, modules_with_qparams):
            return (producer_mod.scale, producer_mod.zero_point)  # type: ignore[return-value]

        for passthrough_type in passthrough_module_types:
            if isinstance(producer_mod, passthrough_type):  # type: ignore[arg-type]
                # the module does not change qparams, look at its own input
                return get_node_input_qparams(producer, gm, node_type_to_io_type_map)

    return None
def return_first_non_observer_node(
    node: Node,
    gm: GraphModule,
) -> Node:
    """
    If node is not an observer, returns it. If node is an observer,
    navigates up the graph and returns the first parent which is not an
    observer. At most two consecutive observers are skipped. For example,

    graph: (node_non_obs), node = node_non_obs : returns node_non_obs
    graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs
    graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs

    The parent may be any kind of node (function, method, placeholder,
    attribute or module); only `call_module` nodes are resolved against
    `gm` to check whether they are observers.
    """
    # Two hops cover the `obs -> fake_quant` chains described above.
    for _ in range(2):
        if node.op != "call_module":
            # Only modules can be observers. Other node kinds have targets
            # which are not module paths and must not be looked up on gm.
            break
        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
        if not _is_activation_post_process(node_obj):
            break
        assert len(node.args) == 1
        assert isinstance(node.args[0], Node)
        node = node.args[0]
    return node
def get_number_of_non_param_args(
    node: Node,
    gm: GraphModule,
) -> int:
    """
    Returns how many leading args of `node` are non-param inputs, assuming
    that all non-param args come first.

    Examples:
      * `F.linear(x, weight, bias)` -> 1 (only x is a non-param arg)
      * `lstm_mod(x, hid)` -> 2 (both x and hid are non-param args)
    """
    if node.op != "call_module":
        # default is 1
        return 1
    target_mod = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
    # nn.LSTM is the only module special-cased with two non-param args
    return 2 if isinstance(target_mod, nn.LSTM) else 1
def get_arg_indices_of_inputs_to_log(node: Node) -> List[int]:
    """
    Returns the indices of args of the node which we should attach
    loggers to, if input logging is enabled.

    For example,
    * for (x + y), returns [0, 1]
    * for (1 + y), returns [1]
    * for (x + 1), returns [0]
    * for (linear(x, w, b)) returns [0]
    * by default, returns [0]

    Only positional args are considered. A node without positional args
    yields an empty list, and a binary op with a single positional arg
    (the other operand passed by keyword) yields at most [0].
    """
    if len(node.args) == 0:
        return []
    if node.op == "call_function" and (
        # TODO(future PR): use relationship map instead of hardcoding
        node.target in (torch.add, torch.ops.quantized.add, operator.add)
        or node.target in (torch.mul, torch.ops.quantized.mul, operator.mul)
    ):
        # Slice instead of indexing 0 and 1 directly, so that a binary op
        # with only one positional arg does not raise IndexError.
        return [
            arg_idx
            for arg_idx, arg in enumerate(node.args[:2])
            if isinstance(arg, Node)
        ]
    return [0]
def get_target_type_str(node: Node, gm: GraphModule) -> str:
    """
    Returns `torch.typename` of what `node` points to: the target itself
    for `call_function` and `call_method` nodes, or the module instance
    looked up on `gm` for `call_module` nodes. Any other node type gives ''.
    """
    if node.op == "call_module":
        assert isinstance(node.target, str)
        return torch.typename(getattr_from_fqn(gm, node.target))
    if node.op in ("call_function", "call_method"):
        return torch.typename(node.target)
    return ""
def rekey_logger_info_on_node_name_of_model(
    results: NSResultsType,
    model_name: str,
) -> NSResultsType:
    """
    Builds a copy of the top level of `results` whose layer keys are the
    node names used by `model_name`.

    For example, transforms

      {'base_op_1_0': {'node_output': {'model_a':
        [{'ref_node_name': 'linear1', ...}]}}}

    into

      {'linear1': {'node_output': {'model_a':
        [{'ref_node_name': 'linear1', ...}]}}}

    A layer which has no results for `model_name` keeps its original key.
    The inner dictionaries are reused, not copied.

    Note: we cannot use these node names directly because they are not
    guaranteed to be consistent across models. This is why we extract
    the results first and rekey afterwards.
    """
    rekeyed = {}
    for layer_key, per_result_type in results.items():
        ref_name = None
        for per_model in per_result_type.values():
            if model_name in per_model:
                entries = per_model[model_name]
                assert len(entries)
                # the last result type containing the model decides the name
                ref_name = entries[0]["ref_node_name"]
        target_key = layer_key if ref_name is None else ref_name
        rekeyed[target_key] = per_result_type
    return rekeyed
def maybe_add_missing_fqns(results: NSResultsType) -> None:
    """
    Copies `fqn` entries from a model which has them to the models which
    do not, modifying `results` in place.

    Only the first result type of the first layer is inspected to find a
    model whose first result has a non-None `fqn`. If one is found, its
    `fqn` values overwrite, position by position, those of every other
    model in every layer and result type.

    A common use case benefitting from this is comparing a model prepared by
    quantization to a quantized model. In this case, the model prepared by
    quantization would have `fqn` entries, and the quantized model would not.
    """
    # Check in the first result to find any model with fqn entries defined.
    source_model = None
    for first_layer in list(results.values())[:1]:
        for first_result_type in list(first_layer.values())[:1]:
            for candidate_name, candidate_entries in first_result_type.items():
                if len(candidate_entries) > 0 and candidate_entries[0]["fqn"] is not None:
                    source_model = candidate_name
                    break

    if not source_model:
        return

    for per_result_type in results.values():
        for per_model in per_result_type.values():
            source_entries = per_model[source_model]
            for other_name, other_entries in per_model.items():
                if other_name == source_model:
                    continue
                for position, entry in enumerate(other_entries):
                    entry["fqn"] = source_entries[position]["fqn"]
def maybe_dequantize_first_two_tensor_args_and_handle_tuples(f):
    """
    Decorator for comparison functions taking two tensors as their first
    two positional arguments.

    The wrapped function gains the following behavior:
      * if both leading args are tuples, or both are lists, `f` is applied
        pairwise to their elements (recursively) and a list of results is
        returned
      * if both leading args are tensors, quantized ones are dequantized
        first; if either is then not `torch.float`, None is returned
        without calling `f`
      * anything else is passed to `f` unchanged

    The wrapper requires at least two positional arguments. The name and
    docstring of `f` are preserved on the wrapper.
    """
    # local import keeps this block self-contained
    import functools

    @functools.wraps(f)
    def inner(*args, **kwargs):
        a0, a1, *a_other = args

        if (isinstance(a0, tuple) and isinstance(a1, tuple)) or (
            isinstance(a0, list) and isinstance(a1, list)
        ):
            # zip stops at the shorter sequence; elements may themselves
            # be sequences, hence the recursive call
            return [inner(el0, el1, *a_other, **kwargs) for el0, el1 in zip(a0, a1)]

        if isinstance(a0, torch.Tensor) and isinstance(a1, torch.Tensor):
            if a0.is_quantized:
                a0 = a0.dequantize()
            if a1.is_quantized:
                a1 = a1.dequantize()

            # for the purposes of this util, only handle floats
            if a0.dtype != torch.float or a1.dtype != torch.float:
                return None

        return f(a0, a1, *a_other, **kwargs)

    return inner
@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the SQNR between `x` and `y`, in decibels:
    `20 * log10(norm(x) / norm(x - y))`.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        A tensor holding the SQNR. Through the decorator, tuple or list
        inputs give a list of per-element results, and tensor inputs
        which are not `torch.float` after dequantization give None.
    """
    signal_norm = torch.norm(x)
    noise_norm = torch.norm(x - y)
    return 20 * torch.log10(signal_norm / noise_norm)
@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the normalized L2 error between `x` and `y`:
    `sqrt(sum((x - y) ** 2) / sum(x ** 2))`.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        A tensor holding the error. Through the decorator, tuple or list
        inputs give a list of per-element results, and tensor inputs
        which are not `torch.float` after dequantization give None.
    """
    return torch.sqrt(((x - y) ** 2).sum() / (x ** 2).sum())
@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the cosine similarity between `x` and `y`, after flattening
    each of them into a single row.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        A one-element tensor holding the similarity. Through the
        decorator, tuple or list inputs give a list of per-element
        results, and tensor inputs which are not `torch.float` after
        dequantization give None.
    """
    # For convolutions, the shape of the quantized weight has one additional
    # dimension compared to the shape of the fp32 weight. Match the shapes
    # to enable cosine similarity comparison.
    x = x.reshape(1, -1)
    y = y.reshape(1, -1)
    return torch.nn.functional.cosine_similarity(x, y)
def op_type_supports_shadowing(node: Node) -> bool:
    """
    Returns False for `call_function` nodes whose target is one of the
    listed multi-tensor-input ops (add, mul, cat, stack), True otherwise.
    """
    if node.op != "call_function":
        return True
    # shadowing for ops with multiple tensor inputs is not implemented yet
    multi_input_targets = (
        torch.add,
        torch.mul,
        operator.add,
        operator.mul,
        torch.cat,
        torch.stack,
    )
    return node.target not in multi_input_targets
def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
    """
    Given a node, gets the n'th input to that node, normalizing
    args and kwargs to the best of its ability.

    Normalized arguments from `node.normalized_arguments` are used when
    available. If normalization returns None or raises RuntimeError, the
    node's raw `args` and `kwargs` are used instead. In both cases
    positional args are counted first, followed by keyword args in
    insertion order.

    Raises:
        AssertionError: if the node has `idx` or fewer inputs in total.
    """

    def _nth_of(args, kwargs):
        # Selects input number `idx` from positional args followed by kwargs.
        assert len(args) + len(kwargs) > idx
        if idx < len(args):
            return args[idx]
        # note: in Python 3.7+ dicts are ordered. The kwargs position is
        # the overall index minus the number of positional args.
        return list(kwargs.values())[idx - len(args)]

    try:
        norm_args_and_kwargs = node.normalized_arguments(
            gm, normalize_to_only_use_kwargs=True)
    except RuntimeError:
        # this RuntimeError happens when node argument normalization
        # requires typehints to proceed, such as for torch.add where
        # either the first, second or both arguments could be tensors
        norm_args_and_kwargs = None

    if norm_args_and_kwargs is not None:
        norm_args, norm_kwargs = norm_args_and_kwargs
        return _nth_of(norm_args, norm_kwargs)
    return _nth_of(node.args, node.kwargs)  # type: ignore[return-value]