[Reland][HigherOrderOp] make MapHigherOrder create map_impl (pytorch#115561)

This is a reland of pytorch#115205, which was reverted due to an internal test failure.

Pull Request resolved: pytorch#115561
Approved by: https://github.com/angelayi
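
For context, a minimal usage sketch of the map control-flow op this change lowers. This is a hedged example assuming functorch.experimental.control_flow (the module the new test imports as control_flow); shapes and names are illustrative:

import torch
from functorch.experimental import control_flow

def body(x_slice, offset):
    # Runs once per slice of xs along dim 0.
    return torch.sin(x_slice + offset)

xs = torch.randn(3, 1)
out = control_flow.map(body, xs, xs.size(0))  # stacks the 3 per-slice results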
ydwu4 authored and guilhermeleobas committed Dec 18, 2023
1 parent 2c818ec commit b806631
Showing 2 changed files with 36 additions and 1 deletion.
34 changes: 34 additions & 0 deletions test/dynamo/test_higher_order_ops.py
@@ -1186,6 +1186,40 @@ def inner(x, y):
# get_item call created by the flatten/unflatten logic in HOP speculation.
self.assertEqual(cnt.op_count, ifdynstaticdefault(3, 4))

def test_map_lowers_to_graph(self):
backend = EagerAndRecordGraphs()
cnt = CompileCounterWithBackend(backend)

def fn(x, y):
def inner(x, y):
return torch.sin(x + y)

return control_flow.map(inner, x, y.size(0))

x = torch.randn(3, 1)
y = torch.randn(3, 1)
compiled_fn = torch.compile(fn, backend=backend)(x, y)
self.assertEqual(len(backend.graphs), 1)
graph = backend.graphs[0]
# Dynamic shapes produce a slightly different graph.
if check_dynamic_shape_capture():
return

# TODO(yidi): remove the getitem = l_x_.__getitem__(0) call. It's
# created accidentally when we create sample inputs based on the 0-th slice
# before speculating f in MapHigherOrder.
self.assertExpectedInline(
graph.code.strip(),
"""\
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
getitem = l_x_.__getitem__(0)
map_body_0 = self.map_body_0
map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, l_x_, 3); map_body_0 = l_x_ = None
getitem_1 = map_impl[0]; map_impl = None
return (getitem_1,)""",
)

def test_cond_subgraph_name_is_valid(self):
backend = EagerAndRecordGraphs()
cnt = CompileCounterWithBackend(backend)
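The expected graph above ends in torch.ops.higher_order.map_impl(map_body_0, 1, l_x_, 3). Below is a hedged reference sketch of the semantics that call encodes (loop over dim 0 of the mapped operand, then stack), written for the single-output case; this is not the real implementation, which lives under torch._higher_order_ops:

import torch

def map_impl_reference(body, num_mapped, *operands):
    # Split operands into the tensors mapped over dim 0 and the rest
    # (extra positional args plus lifted free variables).
    mapped, rest = operands[:num_mapped], operands[num_mapped:]
    outs = [body(*(m[i] for m in mapped), *rest)
            for i in range(mapped[0].size(0))]
    # Outputs come back as a tuple, which is why the traced graph
    # reads the result with map_impl[0].
    return (torch.stack(outs),)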
3 changes: 2 additions & 1 deletion torch/_dynamo/variables/higher_order_ops.py
@@ -765,11 +765,12 @@ def call_function(
body_node = make_attr(tx, body_name)
p_args = (
body_node,
1, # right now we only support num_mapped = 1
*(arg.as_proxy() for arg in args[1:]),
*(arg for arg in body_lifted_freevars.keys()),
)
return _call_function_and_unflatten_output(
tx, self.value, p_args, {}, body_r, body_spec
tx, torch.ops.higher_order.map_impl, p_args, {}, body_r, body_spec
)


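For reference, the p_args tuple assembled above follows this layout (annotations are mine, mirroring the diff and the expected graph in the test):

# Layout of p_args handed to torch.ops.higher_order.map_impl:
#   (body_node,   # speculated body GraphModule (map_body_0 in the test graph)
#    1,           # num_mapped: only one mapped operand is supported right now
#    *args[1:],   # remaining user operands (l_x_ and the int 3 in the test)
#    *freevars)   # closure tensors lifted into graph inputs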
