Collator

class torchdata.datapipes.iter.Collator(datapipe: IterDataPipe, collate_fn: Callable = <function default_collate>)

Collates samples from a DataPipe into Tensor(s) using a custom collate function (functional name: collate). By default, it uses torch.utils.data.default_collate().

Note

When writing a custom collate function, you can import torch.utils.data.default_collate() to reuse the default behavior, and use functools.partial to bind any additional arguments.

Parameters:
  • datapipe – Iterable DataPipe being collated

  • collate_fn – Custom collate function that collects and combines a single sample or a batch of samples. The default function collates to Tensor(s) based on the data type.

Example: Convert integer data to float Tensor
>>> import torch
>>> from torchdata.datapipes.iter import Collator
>>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
...     def __init__(self, start, end):
...         super().__init__()
...         assert end > start, "this example code only works with end > start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
...     def __len__(self):
...         return self.end - self.start
...
>>> ds = MyIterDataPipe(start=3, end=7)
>>> print(list(ds))
[3, 4, 5, 6]
>>> def collate_fn(batch):
...     return torch.tensor(batch, dtype=torch.float)
...
>>> collated_ds = Collator(ds, collate_fn=collate_fn)
>>> print(list(collated_ds))
[tensor(3.), tensor(4.), tensor(5.), tensor(6.)]