torcheval.metrics.Cat
- class torcheval.metrics.Cat(*, dim: int = 0, device: device | None = None)
Concatenate all input tensors along dimension dim. Its functional version is torch.cat(input).
All input tensors to Cat.update() must either be empty or have the same shape, except in the concatenating dimension. A zero-dimensional tensor is not a valid input to Cat.update(). Use torch.flatten() to turn it into a one-dimensional tensor before passing it in.
Examples:
>>> import torch
>>> from torcheval.metrics import Cat
>>> metric = Cat(dim=1)
>>> metric.update(torch.tensor([[1, 2], [3, 4]]))
>>> metric.compute()
tensor([[1, 2],
        [3, 4]])
>>> metric.update(torch.tensor([[5, 6], [7, 8]])).compute()
tensor([[1, 2, 5, 6],
        [3, 4, 7, 8]])
>>> metric.reset()
>>> metric.update(torch.tensor([[0]])).compute()
tensor([[0]])
- __init__(*, dim: int = 0, device: device | None = None) -> None
Initialize a Cat metric object.
- Parameters:
dim – The dimension along which to concatenate, as in torch.cat().
Methods
__init__(*[, dim, device]) – Initialize a Cat metric object.
compute() – Return the concatenated inputs.
load_state_dict(state_dict[, strict]) – Load metric state variables from state_dict.
merge_state(metrics) – Update the current metric's state variables to be the merged states of the current metric and the input metrics.
reset() – Reset the metric state variables to their default value.
state_dict() – Save metric state variables in state_dict.
to(device, *args, **kwargs) – Move tensors in metric state variables to device.
update(input) – Update the metric state variables with a new input tensor.
Attributes
device – The last input device of Metric.to().