Skip to content

Commit

Permalink
[Data] Support async callable classes in map_batches() (#46129)
Browse files Browse the repository at this point in the history
Add support for passing CallableClass with asynchronous generator
`__call__` method to `Dataset.map_batches()` API. This is useful for
streaming outputs from asynchronous generators as they become available
to maximize throughput.

---------

Signed-off-by: Scott Lee <sjl@anyscale.com>
  • Loading branch information
scottjlee committed Jun 25, 2024
1 parent a709b8f commit f75ad5d
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 50 deletions.
3 changes: 1 addition & 2 deletions python/ray/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@

# Module-level cached global functions for callable classes. It needs to be defined here
# since it has to be process-global across cloudpickled funcs.
_cached_fn = None
_cached_cls = None
_map_actor_context = None

configure_logging()

Expand Down
214 changes: 166 additions & 48 deletions python/ray/data/_internal/planner/plan_udf_map_op.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import asyncio
import collections
import inspect
import queue
from threading import Thread
from types import GeneratorType
from typing import Any, Callable, Iterable, Iterator, List, Optional

Expand Down Expand Up @@ -43,6 +47,36 @@
from ray.util.rpdb import _is_ray_debugger_enabled


class _MapActorContext:
def __init__(
self,
udf_map_cls: UserDefinedFunction,
udf_map_fn: Callable[[Any], Any],
is_async: bool,
):
self.udf_map_cls = udf_map_cls
self.udf_map_fn = udf_map_fn
self.is_async = is_async
self.udf_map_asyncio_loop = None
self.udf_map_asyncio_thread = None

if is_async:
self._init_async()

def _init_async(self):
# Only used for callable class with async generator `__call__` method.
loop = asyncio.new_event_loop()

def run_loop():
asyncio.set_event_loop(loop)
loop.run_forever()

thread = Thread(target=run_loop)
thread.start()
self.udf_map_asyncio_loop = loop
self.udf_map_asyncio_thread = thread


def plan_udf_map_op(
op: AbstractUDFMap, physical_children: List[PhysicalOperator]
) -> MapOperator:
Expand Down Expand Up @@ -104,23 +138,53 @@ def _parse_op_fn(op: AbstractUDFMap):
fn_constructor_args = op._fn_constructor_args or ()
fn_constructor_kwargs = op._fn_constructor_kwargs or {}

op_fn = make_callable_class_concurrent(op_fn)
is_async_gen = inspect.isasyncgenfunction(op._fn.__call__)

def fn(item: Any) -> Any:
assert ray.data._cached_fn is not None
assert ray.data._cached_cls == op_fn
try:
return ray.data._cached_fn(item, *fn_args, **fn_kwargs)
except Exception as e:
_handle_debugger_exception(e)
# TODO(scottjlee): (1) support non-generator async functions
# (2) make the map actor async
if not is_async_gen:
op_fn = make_callable_class_concurrent(op_fn)

def init_fn():
if ray.data._cached_fn is None:
ray.data._cached_cls = op_fn
ray.data._cached_fn = op_fn(
*fn_constructor_args, **fn_constructor_kwargs
if ray.data._map_actor_context is None:
ray.data._map_actor_context = _MapActorContext(
udf_map_cls=op_fn,
udf_map_fn=op_fn(
*fn_constructor_args,
**fn_constructor_kwargs,
),
is_async=is_async_gen,
)

if is_async_gen:

async def fn(item: Any) -> Any:
assert ray.data._map_actor_context is not None
assert ray.data._map_actor_context.is_async

try:
return ray.data._map_actor_context.udf_map_fn(
item,
*fn_args,
**fn_kwargs,
)
except Exception as e:
_handle_debugger_exception(e)

else:

def fn(item: Any) -> Any:
assert ray.data._map_actor_context is not None
assert not ray.data._map_actor_context.is_async
try:
return ray.data._map_actor_context.udf_map_fn(
item,
*fn_args,
**fn_kwargs,
)
except Exception as e:
_handle_debugger_exception(e)

else:

def fn(item: Any) -> Any:
Expand Down Expand Up @@ -158,6 +222,7 @@ def _validate_batch_output(batch: Block) -> None:
np.ndarray,
collections.abc.Mapping,
pd.core.frame.DataFrame,
dict,
),
):
raise ValueError(
Expand Down Expand Up @@ -192,46 +257,99 @@ def _validate_batch_output(batch: Block) -> None:

def _generate_transform_fn_for_map_batches(
    fn: UserDefinedFunction,
) -> MapTransformCallable[DataBatch, DataBatch]:
    """Build the batch transform for `map_batches`.

    Dispatches on the UDF kind: coroutine functions (callable classes with
    an async generator `__call__`) are routed to the async transform;
    everything else gets the synchronous batch loop below.
    """
    if inspect.iscoroutinefunction(fn):
        # UDF is a callable class with async generator `__call__` method.
        transform_fn = _generate_transform_fn_for_async_map_batches(fn)

    else:

        def transform_fn(
            batches: Iterable[DataBatch], _: TaskContext
        ) -> Iterable[DataBatch]:
            for batch in batches:
                try:
                    if (
                        not isinstance(batch, collections.abc.Mapping)
                        and BlockAccessor.for_block(batch).num_rows() == 0
                    ):
                        # For empty input blocks, we directly output them without
                        # calling the UDF.
                        # TODO(hchen): This workaround is because some all-to-all
                        # operators output empty blocks with no schema.
                        res = [batch]
                    else:
                        res = fn(batch)
                        # Normalize a plain (non-generator) return value to a
                        # one-element list so the yield loop below is uniform.
                        if not isinstance(res, GeneratorType):
                            res = [res]
                except ValueError as e:
                    # Distinguish mutation-of-read-only-batch errors (a common
                    # zero-copy pitfall) from other ValueErrors, and give an
                    # actionable message for the former.
                    read_only_msgs = [
                        "assignment destination is read-only",
                        "buffer source array is read-only",
                    ]
                    err_msg = str(e)
                    if any(msg in err_msg for msg in read_only_msgs):
                        raise ValueError(
                            f"Batch mapper function {fn.__name__} tried to mutate a "
                            "zero-copy read-only batch. To be able to mutate the "
                            "batch, pass zero_copy_batch=False to map_batches(); "
                            "this will create a writable copy of the batch before "
                            "giving it to fn. To elide this copy, modify your mapper "
                            "function so it doesn't try to mutate its input."
                        ) from e
                    else:
                        raise e from None
                else:
                    # Success path: validate and stream out each output batch.
                    for out_batch in res:
                        _validate_batch_output(out_batch)
                        yield out_batch

    return transform_fn


def _generate_transform_fn_for_async_map_batches(
    fn: UserDefinedFunction,
) -> MapTransformCallable[DataBatch, DataBatch]:
    """Build the batch transform for a callable class whose `__call__` is an
    async generator.

    Each input batch is processed as a task on the actor's dedicated asyncio
    loop thread (see `_MapActorContext`); output batches are streamed back to
    this synchronous generator through a queue as soon as they are produced.
    """

    def transform_fn(
        input_iterable: Iterable[DataBatch], _: TaskContext
    ) -> Iterable[DataBatch]:
        # Use a queue to store outputs from async generator calls.
        # We will put output batches into this queue from async
        # generators, and in the main (synchronous) thread, yield them from
        # the queue as they become available.
        output_batch_queue = queue.Queue()
        # Sentinel enqueued once all tasks are finished, so the consumer loop
        # below terminates deterministically. Polling `future.done()` instead
        # would race the queue: it can hang on an empty queue after the future
        # finishes, or drop batches still queued when it flips to done.
        sentinel = object()

        async def process_batch(batch: DataBatch):
            # `fn` is a coroutine function wrapping the async generator
            # `__call__`; awaiting it produces the async iterator of outputs.
            output_batch_iterator = await fn(batch)
            # As soon as results become available from the async generator,
            # put them into the result queue so they can be yielded.
            async for output_batch in output_batch_iterator:
                output_batch_queue.put(output_batch)

        async def process_all_batches():
            loop = ray.data._map_actor_context.udf_map_asyncio_loop
            tasks = [loop.create_task(process_batch(x)) for x in input_iterable]

            ctx = ray.data.DataContext.get_current()
            if ctx.execution_options.preserve_order:
                for task in tasks:
                    # NOTE: `await task`, not `await task()` -- asyncio Tasks
                    # are awaitable, not callable.
                    await task
            else:
                for task in asyncio.as_completed(tasks):
                    await task

        # Use the existing event loop to create and run Tasks to process
        # each batch.
        loop = ray.data._map_actor_context.udf_map_asyncio_loop
        future = asyncio.run_coroutine_threadsafe(process_all_batches(), loop)
        # The callback runs after every task has finished putting its outputs,
        # so the sentinel is guaranteed to be the last queue item.
        future.add_done_callback(lambda _: output_batch_queue.put(sentinel))

        # Yield results as they become available.
        while True:
            # Here, `out_batch` is a one-row output batch
            # from the async generator, corresponding to a
            # single row from the input batch.
            out_batch = output_batch_queue.get()
            if out_batch is sentinel:
                # Re-raise any exception raised inside the async tasks instead
                # of silently ending the stream.
                future.result()
                break
            _validate_batch_output(out_batch)
            yield out_batch

    return transform_fn

Expand Down
34 changes: 34 additions & 0 deletions python/ray/data/tests/test_map.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import itertools
import math
import os
Expand Down Expand Up @@ -1057,6 +1058,39 @@ def test_nonserializable_map_batches(shutdown_only):
x.map_batches(lambda _: lock).take(1)


def test_map_batches_async_generator(shutdown_only):
    """Verify map_batches() supports callable classes with an async generator
    `__call__`, and that batches are processed concurrently."""
    ray.shutdown()
    ray.init(num_cpus=10)

    async def sleep_and_yield(i):
        print("sleep", i)
        # Stagger the sleeps so concurrent execution finishes well under the
        # serial runtime checked below.
        await asyncio.sleep(i % 5)
        print("yield", i)
        return {"input": [i], "output": [2**i]}

    class AsyncActor:
        def __init__(self):
            pass

        async def __call__(self, batch):
            # Fan out one task per row, then yield results in row order.
            tasks = [asyncio.create_task(sleep_and_yield(i)) for i in batch["id"]]
            for task in tasks:
                yield await task

    n = 10
    ds = ray.data.range(n, override_num_blocks=2)
    ds = ds.map(lambda x: x)
    ds = ds.map_batches(AsyncActor, batch_size=1, concurrency=1, max_concurrency=2)

    start_t = time.time()
    output = ds.take_all()
    runtime = time.time() - start_t
    # If rows ran serially the total would be sum(range(n)) seconds; concurrent
    # execution must beat that.
    assert runtime < sum(range(n)), runtime

    # NOTE(review): this equality assumes take_all() preserves input row order
    # even though batches complete out of order -- confirm, otherwise sort
    # both sides before comparing.
    expected_output = [{"input": i, "output": 2**i} for i in range(n)]
    assert output == expected_output, (output, expected_output)


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit f75ad5d

Please sign in to comment.