BatchAsyncMapper¶
- class torchdata.datapipes.iter.BatchAsyncMapper(source_datapipe, async_fn: Callable, batch_size: int, input_col=None, output_col=None, max_concurrency: int = 32, flatten: bool = True)¶
Combines elements from the source DataPipe into batches and applies a coroutine function to each element within the batch concurrently, then flattens the outputs into a single, unnested IterDataPipe (functional name: async_map_batches).
- Parameters:
source_datapipe – Source IterDataPipe
async_fn – The coroutine function to be applied to each batch of data
batch_size – The size of each batch to be aggregated from source_datapipe
input_col –
Index or indices of the data to which async_fn is applied, such as:
- None as default to apply async_fn to the data directly.
- Integer(s) for list/tuple input.
- Key(s) for dict input.
output_col –
Index of the data where the result of async_fn is placed. output_col can be specified only when input_col is not None.
- None as default to replace the index that input_col specified; for input_col with multiple indices, the left-most one is used, and the other indices are removed.
- Integer for list/tuple. -1 appends the result at the end.
- Key for dict. A new key is acceptable.
max_concurrency – Maximum number of concurrent calls to the async function. (Default: 32)
flatten – Determines whether the batches are flattened at the end. (Default: True) If False, outputs will be yielded in batches of size batch_size.
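To make the batching, concurrency-limited mapping, and flattening behavior above concrete, here is a minimal sketch in plain Python, assuming only the standard library. This is a simplified model for illustration, not the actual torchdata implementation; the function name async_map_batches_sketch is hypothetical.

```python
import asyncio
from itertools import islice


def async_map_batches_sketch(source, async_fn, batch_size,
                             max_concurrency=32, flatten=True):
    """Simplified model of BatchAsyncMapper: batch the source,
    run async_fn concurrently over each batch, optionally flatten."""

    async def process(batch):
        # Cap the number of in-flight coroutines per batch.
        sem = asyncio.Semaphore(max_concurrency)

        async def guarded(x):
            async with sem:
                return await async_fn(x)

        # gather preserves input order in its result list.
        return await asyncio.gather(*(guarded(x) for x in batch))

    it = iter(source)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            break
        results = asyncio.run(process(batch))
        if flatten:
            yield from results    # single unnested stream
        else:
            yield results         # keep batch structure


async def mul_ten(x):
    return x * 10


print(list(async_map_batches_sketch(range(5), mul_ten, 2)))
# → [0, 10, 20, 30, 40]
print(list(async_map_batches_sketch(range(5), mul_ten, 2, flatten=False)))
# → [[0, 10], [20, 30], [40]]
```

Note how flatten=False keeps the batch boundaries (here, two batches of 2 and a trailing batch of 1), which matches the documented behavior of outputs arriving in batches of size batch_size.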
Example
>>> import asyncio
>>> from torchdata.datapipes.iter import IterableWrapper
>>> async def mul_ten(x):
...     await asyncio.sleep(1)
...     return x * 10
>>> dp = IterableWrapper(range(50))
>>> dp = dp.async_map_batches(mul_ten, 16)
>>> list(dp)
[0, 10, 20, 30, ...]
>>> dp = IterableWrapper([(i, i) for i in range(50)])
>>> dp = dp.async_map_batches(mul_ten, 16, input_col=1)
>>> list(dp)
[(0, 0), (1, 10), (2, 20), (3, 30), ...]
>>> dp = IterableWrapper([(i, i) for i in range(50)])
>>> dp = dp.async_map_batches(mul_ten, 16, input_col=1, output_col=-1)
>>> list(dp)
[(0, 0, 0), (1, 1, 10), (2, 2, 20), (3, 3, 30), ...]
>>> # Async fetching html from remote
>>> from aiohttp import ClientSession
>>> async def fetch_html(url: str, **kwargs):
...     async with ClientSession() as session:
...         resp = await session.request(method="GET", url=url, **kwargs)
...         resp.raise_for_status()
...         html = await resp.text()
...         return html
>>> dp = IterableWrapper(urls)
>>> dp = dp.async_map_batches(fetch_html, 16)
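The input_col/output_col semantics shown in the tuple examples above can be illustrated without torchdata at all. The helper below is a hypothetical sketch (apply_with_cols is not part of the library) that mirrors the documented rules for list/tuple rows: replace the input index by default, or append when output_col is -1.

```python
def apply_with_cols(row, fn, input_col=None, output_col=None):
    """Hypothetical helper mirroring BatchAsyncMapper's input_col /
    output_col rules for list/tuple rows (not part of torchdata)."""
    if input_col is None:
        # No column selected: apply fn to the whole element.
        return fn(row)
    result = fn(row[input_col])
    if output_col is None:
        out = list(row)
        out[input_col] = result       # default: replace the input index
    elif output_col == -1:
        out = list(row) + [result]    # -1: append the result at the end
    else:
        out = list(row)
        out[output_col] = result      # explicit index: overwrite it
    return tuple(out)


def mul_ten(x):
    return x * 10


print(apply_with_cols((3, 3), mul_ten, input_col=1))                # (3, 30)
print(apply_with_cols((3, 3), mul_ten, input_col=1, output_col=-1)) # (3, 3, 30)
```

This matches the doctest output above: with input_col=1 the second field is replaced, and adding output_col=-1 grows each tuple by one appended result.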