
PyTorch/XLA API

torch_xla

torch_xla.device(index: Optional[int] = None) device[source]

Returns a given instance of an XLA device.

If SPMD is enabled, returns a virtual device that wraps all devices available to this process.

Parameters

index – index of the XLA device to be returned. Corresponds to index in torch_xla.devices().

Returns

An XLA torch.device.

torch_xla.devices() List[device][source]

Returns all devices available in the current process.

Returns

A list of XLA torch.devices.

torch_xla.device_count() int[source]

Returns number of addressable devices in the current process.
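
Example (a minimal sketch of the three device helpers above; the outputs assume a host with four addressable XLA devices):

>>> import torch_xla
>>> torch_xla.device_count()
4
>>> torch_xla.device(0)
device(type='xla', index=0)
>>> len(torch_xla.devices())
4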

torch_xla.sync(wait: bool = False)[source]

Launches all pending graph operations.

Parameters

wait (bool) – whether to block the current process until the execution finished.
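
Example (a minimal sketch; assumes an XLA device is available; operations are traced lazily and only executed on the device once sync() runs):

>>> import torch
>>> import torch_xla
>>> t = torch.randn(2, 2, device=torch_xla.device())
>>> t = t + 1  # traced lazily, not yet executed
>>> torch_xla.sync()  # launches the pending graph on the device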

torch_xla.compile(f: Optional[Callable] = None, full_graph: Optional[bool] = False, name: Optional[str] = None, num_different_graphs_allowed: Optional[int] = None)[source]

Optimizes the given model/function using torch_xla’s LazyTensor tracing mode. PyTorch/XLA will trace the given function with the given inputs and then generate graphs to represent the PyTorch operations that happen within this function. These graphs will be compiled by XLA and executed on the accelerator (determined by the tensor’s device). Eager mode will be disabled for the compiled region of the function.

Parameters
  • f (Callable) – Module/function to optimize. If not passed, this function will act as a context manager.

  • full_graph (Optional[bool]) – Whether this compile should generate a single graph. If set to True and multiple graphs are generated, torch_xla will throw an error with debug info and exit.

  • name (Optional[str]) – Name of the compiled program. The name of the function f will be used if not specified. This name is used in PT_XLA_DEBUG messages as well as in HLO/IR dump files.

  • num_different_graphs_allowed (Optional[int]) – Number of different traced graphs of the given model/function that are allowed. An error will be raised if this limit is exceeded.

Example:

# usage 1
@torch_xla.compile()
def foo(x):
  return torch.sin(x) + torch.cos(x)

def foo2(x):
  return torch.sin(x) + torch.cos(x)
# usage 2
compiled_foo2 = torch_xla.compile(foo2)

# usage 3
with torch_xla.compile():
  res = foo2(x)
torch_xla.manual_seed(seed, device=None)[source]

Set the seed for generating random numbers for the current XLA device.

Parameters
  • seed (int) – The state to be set.

  • device (torch.device, optional) – The device where the RNG state needs to be set. If missing, the default device seed will be set.
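
Example (a minimal sketch; assumes an XLA device is available, so random tensors created afterwards are reproducible given the seed):

>>> import torch
>>> import torch_xla
>>> torch_xla.manual_seed(42)
>>> t = torch.randn(2, 2, device=torch_xla.device())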

runtime

torch_xla.runtime.device_type() Optional[str][source]

Returns the current PjRt device type.

Selects a default device if none has been configured

Returns

A string representation of the device.

torch_xla.runtime.local_process_count() int[source]

Returns the number of processes running on this host.

torch_xla.runtime.local_device_count() int[source]

Returns the total number of devices on this host.

Assumes each process has the same number of addressable devices.

torch_xla.runtime.addressable_device_count() int[source]

Returns the number of devices visible to this process.

torch_xla.runtime.global_device_count() int[source]

Returns the total number of devices across all processes/hosts.

torch_xla.runtime.global_runtime_device_count() int[source]

Returns the total number of runtime devices across all processes/hosts, especially useful for SPMD.

torch_xla.runtime.world_size() int[source]

Returns the total number of processes participating in the job.

torch_xla.runtime.global_ordinal() int[source]

Returns global ordinal of this thread within all processes.

Global ordinal is in range [0, global_device_count). Global ordinals are not guaranteed to have any predictable relationship to the TPU worker ID nor are they guaranteed to be contiguous on each host.

torch_xla.runtime.local_ordinal() int[source]

Returns local ordinal of this thread within this host.

Local ordinal is in range [0, local_device_count).
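
Example (a minimal sketch of per-process logging; assumes this code runs inside a multi-process launch, e.g. via torch_xla.distributed.xla_multiprocessing.spawn):

>>> import torch_xla.runtime as xr
>>> print(f"process {xr.global_ordinal()} of {xr.world_size()}, local ordinal {xr.local_ordinal()}")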

torch_xla.runtime.get_master_ip() str[source]

Retrieve the master worker IP for the runtime. This calls into backend-specific discovery APIs.

Returns

master worker’s IP address as a string.

torch_xla.runtime.use_spmd(auto: Optional[bool] = False)[source]

API to enable SPMD mode. This is a recommended way to enable SPMD.

This forces SPMD mode if some tensors are already initialized on non-SPMD devices. This means that those tensors would be replicated across the devices.

Parameters

auto (bool) – Whether to enable auto-sharding. See https://github.com/pytorch/xla/blob/master/docs/spmd_advanced.md#auto-sharding for more detail.

torch_xla.runtime.is_spmd()[source]

Returns True if SPMD is enabled for execution, False otherwise.
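
Example (a minimal sketch; enables SPMD for the current process and verifies it):

>>> import torch_xla.runtime as xr
>>> xr.use_spmd()
>>> xr.is_spmd()
True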

torch_xla.runtime.initialize_cache(path: str, readonly: bool = False)[source]

Initializes the persistent compilation cache. This API must be called before any computations have been performed.

Parameters
  • path (str) – The path at which to store the persistent cache.

  • readonly (bool) – Whether or not this worker should have write access to the cache.
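
Example (a minimal sketch; the cache path is hypothetical, and the call must happen before any computation is performed):

>>> import torch_xla.runtime as xr
>>> xr.initialize_cache('/tmp/xla_compile_cache', readonly=False)  # hypothetical path
>>> # ... build and run models as usual; compiled programs are persisted under that path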

xla_model

torch_xla.core.xla_model.xla_device(n: Optional[int] = None, devkind: Optional[str] = None) device[source]

Returns a given instance of an XLA device.

Parameters
  • n (int, optional) – The specific instance (ordinal) to be returned. If specified, the specific XLA device instance will be returned. Otherwise the first device of devkind will be returned.

  • devkind (string, optional) – If specified, the device type, such as TPU, CUDA, CPU, or a custom PJRT device. Deprecated.

Returns

A torch.device with the requested instance.

torch_xla.core.xla_model.xla_device_hw(device: Union[str, device]) str[source]

Returns the hardware type of the given device.

Parameters

device (string or torch.device) – The xla device that will be mapped to the real device.

Returns

A string representation of the hardware type of the given device.

torch_xla.core.xla_model.is_master_ordinal(local: bool = True) bool[source]

Checks whether the current process is the master ordinal (0).

Parameters

local (bool) – Whether the local or global master ordinal should be checked. In case of multi-host replication, there is only one global master ordinal (host 0, device 0), while there are NUM_HOSTS local master ordinals. Default: True

Returns

A boolean indicating whether the current process is the master ordinal.

torch_xla.core.xla_model.all_reduce(reduce_type: str, inputs: Union[Tensor, List[Tensor]], scale: float = 1.0, groups: Optional[List[List[int]]] = None, pin_layout: bool = True) Union[Tensor, List[Tensor]][source]

Performs an inplace reduce operation on the input tensor(s).

Parameters
  • reduce_type (string) – One of xm.REDUCE_SUM, xm.REDUCE_MUL, xm.REDUCE_AND, xm.REDUCE_OR, xm.REDUCE_MIN and xm.REDUCE_MAX.

  • inputs – Either a single torch.Tensor or a list of torch.Tensor to perform the all reduce op to.

  • scale (float) – A default scaling value to be applied after the reduce. Default: 1.0

  • groups (list, optional) – A list of lists representing the replica groups for the all_reduce() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]] defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.

  • pin_layout (bool, optional) – Whether to pin the layout for this communication op. Layout pinning can prevent potential data corruption when each process that participates in the communication has a slightly different program, but it might cause some XLA compilations to fail. Unpin the layout when you see an error message like “HloModule has a mix of layout constrained”.

Returns

If a single torch.Tensor is passed, the return value is a torch.Tensor holding the reduced value (across the replicas). If a list/tuple is passed, this function performs an inplace all-reduce op on the input tensors, and returns the list/tuple itself.
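
Example (a minimal sketch; assumes a multi-process replicated setup, e.g. launched via torch_xla.distributed.xla_multiprocessing.spawn):

>>> import torch
>>> import torch_xla
>>> import torch_xla.core.xla_model as xm
>>> t = torch.ones(2, 2, device=torch_xla.device())
>>> reduced = xm.all_reduce(xm.REDUCE_SUM, t)  # sums t across all replicas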

torch_xla.core.xla_model.all_gather(value: Tensor, dim: int = 0, groups: Optional[List[List[int]]] = None, output: Optional[Tensor] = None, pin_layout: bool = True) Tensor[source]

Performs an all-gather operation along a given dimension.

Parameters
  • value (torch.Tensor) – The input tensor.

  • dim (int) – The gather dimension. Default: 0

  • groups (list, optional) – A list of lists representing the replica groups for the all_gather() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]] defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.

  • output (torch.Tensor) – Optional output tensor.

  • pin_layout (bool, optional) – Whether to pin the layout for this communication op. Layout pinning can prevent potential data corruption when each process that participates in the communication has a slightly different program, but it might cause some XLA compilations to fail. Unpin the layout when you see an error message like “HloModule has a mix of layout constrained”.

Returns

A tensor which has, in the dim dimension, all the values from the participating replicas.
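
Example (a minimal sketch; assumes a multi-process replicated setup, so the gathered dimension grows with the number of participating replicas):

>>> import torch
>>> import torch_xla
>>> import torch_xla.core.xla_model as xm
>>> t = torch.ones(2, 2, device=torch_xla.device())
>>> gathered = xm.all_gather(t, dim=0)  # shape: (2 * num_replicas, 2)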

torch_xla.core.xla_model.all_to_all(value: Tensor, split_dimension: int, concat_dimension: int, split_count: int, groups: Optional[List[List[int]]] = None, pin_layout: bool = True) Tensor[source]

Performs an XLA AllToAll() operation on the input tensor.

See: https://www.tensorflow.org/xla/operation_semantics#alltoall

Parameters
  • value (torch.Tensor) – The input tensor.

  • split_dimension (int) – The dimension upon which the split should happen.

  • concat_dimension (int) – The dimension upon which the concat should happen.

  • split_count (int) – The split count.

  • groups (list, optional) – A list of lists representing the replica groups for the all_to_all() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]] defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.

  • pin_layout (bool, optional) – Whether to pin the layout for this communication op. Layout pinning can prevent potential data corruption when each process that participates in the communication has a slightly different program, but it might cause some XLA compilations to fail. Unpin the layout when you see an error message like “HloModule has a mix of layout constrained”.

Returns

The result torch.Tensor of the all_to_all() operation.

torch_xla.core.xla_model.add_step_closure(closure: Callable[[...], Any], args: Tuple[Any, ...] = (), run_async: bool = False)[source]

Adds a closure to the list of the ones to be run at the end of the step.

Many times during model training there is a need to print/report (print to console, post to tensorboard, etc.) information which requires the content of intermediary tensors to be inspected. Inspecting the content of different tensors at different points in the model code requires many executions and typically causes performance issues. Adding a step closure ensures that it will be run after the barrier, when all the live tensors (including the ones captured by the closure arguments) have already been materialized to device data. Using add_step_closure() therefore ensures that a single execution is performed, even when multiple closures are queued that require multiple tensors to be inspected. Step closures are run sequentially in the order they have been queued. Note that even though execution is optimized with this API, it is advised to throttle printing/reporting events to once every N steps (see the example after the parameter list).

Parameters
  • closure (callable) – The function to be called.

  • args (tuple) – The arguments to be passed to the closure.

  • run_async – If True, run the closure asynchronously.
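
Example (a minimal sketch; `loss` and `step` are hypothetical training-loop variables):

>>> import torch_xla.core.xla_model as xm
>>> def _log_loss(loss, step):
...   if step % 10 == 0:  # throttle reporting to once every N steps
...     print(f"step {step}: loss {loss.item()}")
>>> xm.add_step_closure(_log_loss, args=(loss, step))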

torch_xla.core.xla_model.wait_device_ops(devices: List[str] = [])[source]

Waits for all the async operations on the given devices to complete.

Parameters

devices (list of string, optional) – The devices whose async ops need to be waited for. If empty, all the local devices will be waited for.

torch_xla.core.xla_model.optimizer_step(optimizer: Optimizer, barrier: bool = False, optimizer_args: Dict = {}, groups: Optional[List[List[int]]] = None, pin_layout: bool = True)[source]

Runs the provided optimizer step and syncs gradients across all devices.

Parameters
  • optimizer (torch.Optimizer) – The torch.Optimizer instance whose step() function needs to be called. The step() function will be called with the optimizer_args named arguments.

  • barrier (bool, optional) – Whether the XLA tensor barrier should be issued in this API. If using the PyTorch XLA ParallelLoader or DataParallel support, this is not necessary as the barrier will be issued by the XLA data loader iterator next() call. Default: False

  • optimizer_args (dict, optional) – Named arguments dictionary for the optimizer.step() call.

  • groups (list, optional) – A list of lists representing the replica groups for the all_reduce() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]] defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.

  • pin_layout (bool, optional) – whether to pin the layout when reducing gradients. See xm.all_reduce for details.

Returns

The same value returned by the optimizer.step() call.

Example

>>> import torch_xla.core.xla_model as xm
>>> xm.optimizer_step(self.optimizer)
torch_xla.core.xla_model.save(data: Any, file_or_path: Union[str, TextIO], master_only: bool = True, global_master: bool = False)[source]

Saves the input data into a file.

The saved data is transferred to PyTorch CPU device before being saved, so a following torch.load() will load CPU data. Care must be taken when working with views. Instead of saving views it’s recommended that you recreate them after the tensors have been loaded and moved to their destination device(s).

Parameters
  • data – The input data to be saved. Any nested combination of Python objects (list, tuples, sets, dicts, …).

  • file_or_path – The destination for the data saving operation. Either a file path or a Python file object. If master_only is False, the paths or file objects must point to different destinations, as otherwise all the writes from the same host will overwrite each other.

  • master_only (bool, optional) – Whether only the master device should save the data. If False, the file_or_path argument should be a different file or path for each of the ordinals taking part to the replication, otherwise all the replicas on the same host will be writing to the same location. Default: True

  • global_master (bool, optional) – When master_only is True this flag controls whether every host’s master (if global_master is False) saves the content, or only the global master (ordinal 0). Default: False

Example

>>> import torch_xla.core.xla_model as xm
>>> xm.wait_device_ops() # wait for all pending operations to finish.
>>> xm.save(obj_to_save, path_to_save)
>>> xm.rendezvous('torch_xla.core.xla_model.save') # multi process context only
torch_xla.core.xla_model.rendezvous(tag: str, payload: bytes = b'', replicas: List[int] = []) List[bytes][source]

Waits for all the mesh clients to reach the named rendezvous.

Note: PJRT does not support the XRT mesh server, so this is effectively an alias to xla_rendezvous.

Parameters
  • tag (string) – The name of the rendezvous to join.

  • payload (bytes, optional) – The payload to be sent to the rendezvous.

  • replicas (list of int, optional) – The replica ordinals taking part in the rendezvous. Empty means all replicas in the mesh. Default: []

Returns

The payloads exchanged by all the other cores, with the payload of core ordinal i at position i in the returned list.

Example

>>> import torch_xla.core.xla_model as xm
>>> xm.rendezvous('example')
torch_xla.core.xla_model.mesh_reduce(tag: str, data, reduce_fn: Callable[[...], Any]) Union[Any, ToXlaTensorArena][source]

Performs an out-of-graph client mesh reduction.

Parameters
  • tag (string) – The name of the rendezvous to join.

  • data – The data to be reduced. The reduce_fn callable will receive a list with the copies of the same data coming from all the mesh client processes (one per core).

  • reduce_fn (callable) – A function which receives a list of data-like objects and returns the reduced result.

Returns

The reduced value.

Example

>>> import torch_xla.core.xla_model as xm
>>> import numpy as np
>>> accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
torch_xla.core.xla_model.set_rng_state(seed: int, device: Optional[str] = None)[source]

Sets the random number generator state.

Parameters
  • seed (int) – The state to be set.

  • device (string, optional) – The device where the RNG state needs to be set. If missing, the default device seed will be set.

torch_xla.core.xla_model.get_rng_state(device: Optional[str] = None) int[source]

Gets the current running random number generator state.

Parameters

device (string, optional) – The device whose RNG state needs to be retrieved. If missing, the default device state will be returned.

Returns

The RNG state, as integer.
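
Example (a minimal sketch; captures and later restores the default device’s RNG state):

>>> import torch_xla.core.xla_model as xm
>>> state = xm.get_rng_state()
>>> xm.set_rng_state(state)  # restore the previously captured state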

torch_xla.core.xla_model.get_memory_info(device: Optional[device] = None) MemoryInfo[source]

Retrieves the device memory usage.

Parameters
device (torch.device, optional) – The device whose memory information is requested. If not passed, the default device will be used.

Returns

MemoryInfo dict with memory usage for the given device.

Example

>>> xm.get_memory_info()
{'bytes_used': 290816, 'bytes_limit': 34088157184, 'peak_bytes_used': 500816}
torch_xla.core.xla_model.get_stablehlo(tensors: Optional[List[Tensor]] = None) str[source]

Get StableHLO for the computation graph in string format.

If tensors is not empty, the graph with tensors as outputs will be dumped. If tensors is empty, the whole computation graph will be dumped.

For an inference graph, it is recommended to pass the model outputs as tensors. For a training graph, it is not straightforward to identify the “outputs”, so passing an empty tensors argument is recommended.

To enable source line info in StableHLO, set the environment variable XLA_HLO_DEBUG=1.

Parameters

tensors (list[torch.Tensor], optional) – Tensors that represent the output/root of the StableHLO graph.

Returns

StableHLO Module in string format.
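
Example (a minimal sketch; `model` and `inputs` are hypothetical, and the model is assumed to already live on an XLA device):

>>> import torch_xla.core.xla_model as xm
>>> output = model(inputs)                 # trace the forward pass lazily
>>> hlo_text = xm.get_stablehlo([output])  # StableHLO for the graph rooted at `output`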

torch_xla.core.xla_model.get_stablehlo_bytecode(tensors: Optional[Tensor] = None) bytes[source]

Get StableHLO for the computation graph in bytecode format.

If tensors is not empty, the graph with tensors as outputs will be dumped. If tensors is empty, the whole computation graph will be dumped.

For an inference graph, it is recommended to pass the model outputs as tensors. For a training graph, it is not straightforward to identify the “outputs”, so passing an empty tensors argument is recommended.

Parameters

tensors (list[torch.Tensor], optional) – Tensors that represent the output/root of the StableHLO graph.

Returns

StableHLO Module in bytecode format.

distributed

class torch_xla.distributed.parallel_loader.MpDeviceLoader(loader, device, **kwargs)[source]

Wraps an existing PyTorch DataLoader with background data upload.

This class should only be used with multi-processing data parallelism. It wraps the dataloader passed in with ParallelLoader and returns the per_device_loader for the current device.

Parameters
  • loader (torch.utils.data.DataLoader) – The PyTorch DataLoader to be wrapped.

  • device (torch.device) – The device where the data has to be sent.

  • kwargs – Named arguments for the ParallelLoader constructor.

Example

>>> device = torch_xla.device()
>>> train_device_loader = MpDeviceLoader(train_loader, device)
torch_xla.distributed.xla_multiprocessing.spawn(fn, args=(), nprocs=None, join=True, daemon=False, start_method='spawn')[source]

Enables multi-processing based replication.

Parameters
  • fn (callable) – The function to be called for each device which takes part in the replication. The function will be called with a first argument being the global index of the process within the replication, followed by the arguments passed in args.

  • args (tuple) – The arguments for fn. Default: Empty tuple

  • nprocs (int) – The number of processes/devices for the replication. At the moment, if specified, it can be either 1 or None (which would be automatically converted to the maximum number of devices). Other numbers will result in a ValueError.

  • join (bool) – Whether the call should block waiting for the completion of the processes which have been spawned. Default: True

  • daemon (bool) – Whether the processes being spawned should have the daemon flag set (see Python multi-processing API). Default: False

  • start_method (string) – The Python multiprocessing process creation method. Default: spawn

Returns

The same object returned by the torch.multiprocessing.spawn API. If nprocs is 1 the fn function will be called directly, and the API will return None.
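
Example (a minimal sketch; `_mp_fn` is a hypothetical per-process entry point):

>>> import torch_xla
>>> import torch_xla.distributed.xla_multiprocessing as xmp
>>> def _mp_fn(index):
...   device = torch_xla.device()
...   print(f"process {index} using {device}")
>>> xmp.spawn(_mp_fn, args=())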

spmd

torch_xla.distributed.spmd.mark_sharding(t: Union[Tensor, XLAShardedTensor], mesh: Mesh, partition_spec: Tuple[Optional[Union[Tuple, int, str]], ...]) XLAShardedTensor[source]

Annotates the tensor provided with XLA partition spec. Internally, it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.

Parameters
  • t (Union[torch.Tensor, XLAShardedTensor]) – input tensor to be annotated with partition_spec.

  • mesh (Mesh) – describes the logical XLA device topology and the underlying device IDs.

  • partition_spec (Tuple[Tuple, int, str, None]) – A tuple of device_mesh dimension indices or None. Each index is an int, a str if the mesh axis is named, or a tuple of ints or strs. This specifies how each input rank is sharded (index to mesh_shape) or replicated (None). When a tuple is specified, the corresponding input tensor axis will be sharded along all logical axes in the tuple. Note that the order in which the mesh axes are specified in the tuple will impact the resulting sharding.

  • dynamo_custom_op (bool) – If set to True, it calls the dynamo custom op variant of mark_sharding to make itself recognizable and traceable by dynamo.

Example

>>> import torch_xla.runtime as xr
>>> import torch_xla.distributed.spmd as xs
>>> mesh_shape = (4, 2)
>>> num_devices = xr.global_runtime_device_count()
>>> device_ids = np.array(range(num_devices))
>>> mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
>>> input = torch.randn(8, 32).to(xm.xla_device())
>>> xs.mark_sharding(input, mesh, (0, None)) # 4-way data parallel
>>> linear = nn.Linear(32, 10).to(xm.xla_device())
>>> xs.mark_sharding(linear.weight, mesh, (None, 1)) # 2-way model parallel
torch_xla.distributed.spmd.clear_sharding(t: Union[Tensor, XLAShardedTensor]) Tensor[source]

Clears the sharding annotation from the input tensor and returns a CPU-cast tensor. This is an in-place operation, but it will also return the same torch.Tensor back.

Parameters

t (Union[torch.Tensor, XLAShardedTensor]) – Tensor that we want to clear the sharding

Returns

The tensor without the sharding annotation.

Return type

t (torch.Tensor)

Example

>>> import torch_xla.distributed.spmd as xs
>>> torch_xla.runtime.use_spmd()
>>> t1 = torch.randn(8,8).to(torch_xla.device())
>>> mesh = xs.get_1d_mesh()
>>> xs.mark_sharding(t1, mesh, (0, None))
>>> xs.clear_sharding(t1)
torch_xla.distributed.spmd.set_global_mesh(mesh: Mesh)[source]

Set the global mesh that can be used for the current process.

Parameters

mesh – (Mesh) Mesh object that will be the global mesh.

Example

>>> import torch_xla.distributed.spmd as xs
>>> mesh = xs.get_1d_mesh("data")
>>> xs.set_global_mesh(mesh)
torch_xla.distributed.spmd.get_global_mesh() Optional[Mesh][source]

Get the global mesh for the current process.

Returns

(Optional[Mesh]) Mesh object if global mesh is set, otherwise return None.

Return type

mesh

Example

>>> import torch_xla.distributed.spmd as xs
>>> xs.get_global_mesh()
torch_xla.distributed.spmd.get_1d_mesh(axis_name: Optional[str] = None) Mesh[source]

Helper function to return the mesh with all devices in one dimension.

Parameters

axis_name – (Optional[str]) optional string to represent the axis name of the mesh

Returns

Mesh object

Return type

Mesh

Example

>>> # This example is assuming 1 TPU v4-8
>>> import torch_xla.distributed.spmd as xs
>>> mesh = xs.get_1d_mesh("data")
>>> print(mesh.mesh_shape)
(4,)
>>> print(mesh.axis_names)
('data',)
class torch_xla.distributed.spmd.Mesh(device_ids: Union[ndarray, List], mesh_shape: Tuple[int, ...], axis_names: Optional[Tuple[str, ...]] = None)[source]

Describes the logical XLA device topology mesh and the underlying resources.

Parameters
  • device_ids (Union[np.ndarray, List]) – A raveled list of devices (IDs) in a custom order. The list is reshaped to a mesh_shape array, filling the elements using C-like index order.

  • mesh_shape (Tuple[int, ...]) – An int tuple describing the logical topology shape of the device mesh; each element describes the number of devices in the corresponding axis.

  • axis_names (Tuple[str, ...]) – A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.

Example

>>> mesh_shape = (4, 2)
>>> num_devices = len(xm.get_xla_supported_devices())
>>> device_ids = np.array(range(num_devices))
>>> mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
>>> mesh.get_logical_mesh()
array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7]])
>>> mesh.shape()
OrderedDict([('x', 4), ('y', 2)])
class torch_xla.distributed.spmd.HybridMesh(*, ici_mesh_shape: Tuple[int, ...], dcn_mesh_shape: Optional[Tuple[int, ...]] = None, axis_names: Optional[Tuple[str, ...]] = None)[source]

Creates a hybrid device mesh of devices connected with ICI and DCN networks.

The shape of the logical mesh should be ordered by increasing network intensity, e.g. [replica, data, model], where model has the most network communication requirements.

Parameters
  • ici_mesh_shape – Shape of the logical mesh for inner (ICI-connected) devices.

  • dcn_mesh_shape – Shape of the logical mesh for outer (DCN-connected) devices.

Example

>>> # This example is assuming 2 slices of v4-8.
>>> ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor)
>>> dcn_mesh_shape = (2, 1, 1)
>>> mesh = HybridMesh(ici_mesh_shape=ici_mesh_shape, dcn_mesh_shape=dcn_mesh_shape, axis_names=('data', 'fsdp', 'tensor'))
>>> print(mesh.shape())
OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])

experimental

torch_xla.experimental.eager_mode(enable: bool)[source]

Configures torch_xla’s default execution mode.

Under eager mode, only functions that were wrapped with `torch_xla.compile` will be traced and compiled. Other torch ops will be executed eagerly.
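
Example (a minimal sketch; once eager mode is enabled, ops outside the compiled function run eagerly):

>>> import torch
>>> import torch_xla
>>> torch_xla.experimental.eager_mode(True)
>>> @torch_xla.compile()
... def step(x):
...   return torch.sin(x) + torch.cos(x)
>>> y = step(torch.randn(4, device=torch_xla.device()))  # traced and compiled region
>>> z = y * 2  # executed eagerly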

debug

torch_xla.debug.metrics.metrics_report()[source]

Retrieves a string containing the full metrics and counters report.

torch_xla.debug.metrics.short_metrics_report(counter_names: Optional[list] = None, metric_names: Optional[list] = None)[source]

Retrieves a string containing a report for the selected metrics and counters.

Parameters
  • counter_names (list) – The list of counter names whose data needs to be printed.

  • metric_names (list) – The list of metric names whose data needs to be printed.

torch_xla.debug.metrics.counter_names()[source]

Retrieves all the currently active counter names.

torch_xla.debug.metrics.counter_value(name)[source]

Returns the value of an active counter.

Parameters

name (string) – The name of the counter whose value needs to be retrieved.

Returns

The counter value, as an integer.

torch_xla.debug.metrics.metric_names()[source]

Retrieves all the currently active metric names.

torch_xla.debug.metrics.metric_data(name)[source]

Returns the data of an active metric.

Parameters

name (string) – The name of the metric whose data needs to be retrieved.

Returns

The metric data, which is a tuple of (TOTAL_SAMPLES, ACCUMULATOR, SAMPLES). The TOTAL_SAMPLES is the total number of samples which have been posted to the metric. A metric retains only a given number of samples (in a circular buffer). The ACCUMULATOR is the sum of the samples over TOTAL_SAMPLES. The SAMPLES is a list of (TIME, VALUE) tuples.
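
Example (a minimal sketch; 'ExecuteTime' is a commonly reported metric name, but the available metric and counter names depend on the workload):

>>> import torch_xla.debug.metrics as met
>>> print(met.metrics_report())     # full metrics and counters report
>>> met.counter_names()             # list of currently active counter names
>>> met.metric_data('ExecuteTime')  # (TOTAL_SAMPLES, ACCUMULATOR, SAMPLES)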
