Collator
- class torchdata.datapipes.iter.Collator(datapipe: IterDataPipe, collate_fn: Callable = <function default_collate>)
Collates samples from a DataPipe into Tensor(s) with a custom collate function (functional name: collate). By default, it uses torch.utils.data.default_collate().

Note
When writing a custom collate function, you can import torch.utils.data.default_collate() to reuse the default behavior, and use functools.partial to bind any additional arguments.

- Parameters:
datapipe – Iterable DataPipe being collated
collate_fn – Custom function applied to each element of the DataPipe (a single sample or a batch of samples) to combine it into Tensor(s). The default function converts data to Tensor(s) based on its type.
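The note above suggests a pattern for custom collate functions: reuse default_collate, then bind extra arguments with functools.partial. Here is a minimal sketch of that pattern. The function scaled_collate and its scale and dtype parameters are hypothetical, chosen only for illustration. The sketch also assumes that each batch collates to a single tensor. It calls the bound function directly so that it runs with just PyTorch installed.

```python
import functools

import torch
from torch.utils.data import default_collate


def scaled_collate(batch, scale=1.0, dtype=torch.float32):
    """Hypothetical collate function: default collation, then cast and rescale.

    Assumes `batch` collates to a single tensor (e.g. a list of numbers).
    """
    collated = default_collate(batch)  # a list of ints becomes a 1-D int64 tensor
    return collated.to(dtype) * scale


# Bind the extra argument so the resulting callable takes only the element,
# which is how collate_fn is invoked for each element of the DataPipe.
halve = functools.partial(scaled_collate, scale=0.5)

out = halve([2, 4, 6])
print(out)  # tensor([1., 2., 3.])
```

In a pipeline, the bound callable would be passed as collate_fn=halve to Collator, or to the functional form .collate().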
- Example: Convert integer data to float Tensor
>>> import torch
>>> from torchdata.datapipes.iter import Collator
>>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
...     def __init__(self, start, end):
...         super().__init__()
...         assert end > start, "this example code only works with end > start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
...     def __len__(self):
...         return self.end - self.start
...
>>> ds = MyIterDataPipe(start=3, end=7)
>>> print(list(ds))
[3, 4, 5, 6]
>>> def collate_fn(batch):
...     return torch.tensor(batch, dtype=torch.float)
...
>>> collated_ds = Collator(ds, collate_fn=collate_fn)
>>> print(list(collated_ds))
[tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
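As the example output shows, collate_fn is applied to each element separately, so each integer becomes its own scalar tensor. For this reason, the default collation is usually applied after a batching step. The sketch below models that behavior in plain Python, assuming Collator simply maps collate_fn over the upstream elements. The helpers batched and collate_each are hypothetical stand-ins, not torchdata APIs.

```python
from torch.utils.data import default_collate


def batched(iterable, batch_size):
    """Hypothetical helper: group items into lists of `batch_size` (the last may be shorter)."""
    batch = []
    for item in iterable:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def collate_each(source, collate_fn=default_collate):
    """Simplified stand-in for Collator: apply `collate_fn` to every element of `source`."""
    for element in source:
        yield collate_fn(element)


result = list(collate_each(batched(range(10), 4)))
print(result)  # [tensor([0, 1, 2, 3]), tensor([4, 5, 6, 7]), tensor([8, 9])]
```

With torchdata installed, a comparable pipeline would typically use the functional forms, for example IterableWrapper(range(10)).batch(4).collate().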