BatchAsyncMapper

class torchdata.datapipes.iter.BatchAsyncMapper(source_datapipe, async_fn: Callable, batch_size: int, input_col=None, output_col=None, max_concurrency: int = 32, flatten: bool = True)

Combines elements from the source DataPipe into batches, applies a coroutine function to each element within a batch concurrently, then flattens the outputs into a single, unnested IterDataPipe (functional name: async_map_batches).
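The batching, concurrent mapping, and flattening described above can be sketched with plain asyncio. This is a minimal illustration under stated assumptions, not torchdata's implementation: the names batched, _run_batch, and async_map_batches_sketch are hypothetical, and the use of a semaphore to cap concurrency is an assumption about how such a limit might be enforced.

```python
import asyncio
from typing import Any, Awaitable, Callable, Iterable, Iterator, List


def batched(items: Iterable[Any], batch_size: int) -> Iterator[List[Any]]:
    """Yield consecutive lists of at most batch_size items (hypothetical helper)."""
    batch: List[Any] = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # the final batch may be smaller than batch_size
        yield batch


async def _run_batch(
    async_fn: Callable[[Any], Awaitable[Any]],
    batch: List[Any],
    max_concurrency: int,
) -> List[Any]:
    """Apply async_fn to every element of one batch, at most max_concurrency at a time."""
    sem = asyncio.Semaphore(max_concurrency)

    async def guarded(x: Any) -> Any:
        async with sem:
            return await async_fn(x)

    # gather returns results in the same order as the inputs
    return await asyncio.gather(*(guarded(x) for x in batch))


def async_map_batches_sketch(items, async_fn, batch_size, max_concurrency=32, flatten=True):
    """Illustrative stand-in for the behavior described above (not the torchdata class)."""
    for batch in batched(items, batch_size):
        results = asyncio.run(_run_batch(async_fn, batch, max_concurrency))
        if flatten:
            yield from results  # one output element per input element
        else:
            yield results       # one list per batch


async def mul_ten(x):
    await asyncio.sleep(0.01)
    return x * 10


flat = list(async_map_batches_sketch(range(5), mul_ten, batch_size=2))
nested = list(async_map_batches_sketch(range(5), mul_ten, batch_size=2, flatten=False))
print(flat)    # [0, 10, 20, 30, 40]
print(nested)  # [[0, 10], [20, 30], [40]]
```

Because the sketch calls asyncio.run once per batch, it only works where no event loop is already running (for example, in a plain script rather than a notebook).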

Parameters:
  • source_datapipe – Source IterDataPipe

  • async_fn – The coroutine function to be applied to each element of a batch

  • batch_size – The size of batch to be aggregated from source_datapipe

  • input_col

    Index or indices of the data to which async_fn is applied, such as:

    • None as default to apply async_fn to the data directly.

    • Integer(s) are used for list/tuple.

    • Key(s) are used for dict.

  • output_col

    Index of the data where the result of async_fn is placed. output_col can be specified only when input_col is not None.

    • None as default to replace the index that input_col specifies. For input_col with multiple indices, the left-most one is used, and the other indices are removed.

    • Integer is used for list/tuple. -1 appends the result at the end.

    • Key is used for dict. A new key is acceptable.

  • max_concurrency – Maximum number of async_fn calls that run concurrently. (Default: 32)

  • flatten – Whether the batches are flattened at the end (Default: True). If False, outputs are yielded in batches of size batch_size.
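The input_col and output_col rules above can be sketched for the single-index case. This is an illustrative helper only: place_result is a hypothetical name, not part of torchdata, and it assumes one input index or key (the multi-index case, where the other indices are removed, is not covered).

```python
from typing import Any, Optional


def place_result(row: Any, result: Any, input_col: Any, output_col: Optional[Any] = None) -> Any:
    """Return a copy of row with result placed per the output_col rules (hypothetical helper)."""
    if isinstance(row, dict):
        new_row = dict(row)
        # None replaces the value at input_col; otherwise write to output_col (a new key is fine)
        new_row[input_col if output_col is None else output_col] = result
        return new_row

    items = list(row)
    if output_col is None:
        items[input_col] = result   # replace the selected column
    elif output_col == -1:
        items.append(result)        # -1 appends the result at the end
    else:
        items[output_col] = result  # overwrite an existing position
    return tuple(items) if isinstance(row, tuple) else items


replaced = place_result((3, 3), 30, input_col=1)
appended = place_result((3, 3), 30, input_col=1, output_col=-1)
keyed = place_result({"a": 3}, 30, input_col="a", output_col="b")
print(replaced)  # (3, 30)
print(appended)  # (3, 3, 30)
print(keyed)     # {'a': 3, 'b': 30}
```

The first two results mirror the rows shown for input_col=1 and for input_col=1, output_col=-1 in the example that follows.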

Example

>>> import asyncio
>>> from torchdata.datapipes.iter import IterableWrapper
>>> async def mul_ten(x):
...     await asyncio.sleep(1)
...     return x * 10
>>> dp = IterableWrapper(range(50))
>>> dp = dp.async_map_batches(mul_ten, 16)
>>> list(dp)
[0, 10, 20, 30, ...]
>>> dp = IterableWrapper([(i, i) for i in range(50)])
>>> dp = dp.async_map_batches(mul_ten, 16, input_col=1)
>>> list(dp)
[(0, 0), (1, 10), (2, 20), (3, 30), ...]
>>> dp = IterableWrapper([(i, i) for i in range(50)])
>>> dp = dp.async_map_batches(mul_ten, 16, input_col=1, output_col=-1)
>>> list(dp)
[(0, 0, 0), (1, 1, 10), (2, 2, 20), (3, 3, 30), ...]
>>> # Fetch HTML asynchronously from remote URLs (urls is assumed to be an iterable of URL strings)
>>> from aiohttp import ClientSession
>>> async def fetch_html(url: str, **kwargs):
...     async with ClientSession() as session:
...         resp = await session.request(method="GET", url=url, **kwargs)
...         resp.raise_for_status()
...         html = await resp.text()
...     return html
>>> dp = IterableWrapper(urls)
>>> dp = dp.async_map_batches(fetch_html, 16)
