[export][reland] Remove runtime assertion pass (pytorch#115597)
Summary:
Reland of pytorch#115196 / D52054112 to fix internal failures.

Test Plan: CI

Differential Revision: D52054110

Pull Request resolved: pytorch#115597
Approved by: https://github.com/ydwu4, https://github.com/zhxchen17
angelayi authored and pytorchmergebot committed Dec 15, 2023
1 parent 7d4ccd7 commit 8e2d63c
Showing 7 changed files with 108 additions and 219 deletions.
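
For context, a minimal sketch (not part of the commit; the module, shapes, and exact message text are illustrative) of the behavior change the updated tests below check for: with the runtime assertion pass removed, calling an exported program with a mismatched input now raises errors worded as "Expected input ... to be equal to N, but got M" instead of the old "... is specialized at N" messages. The sketch calls the ExportedProgram directly, as the tests in this diff do.

import torch

class Foo(torch.nn.Module):  # illustrative module, not from the PR
    def forward(self, x):
        return x + 1

# No dynamic_shapes, so every input dimension is specialized to its traced size.
ep = torch.export.export(Foo(), (torch.randn(4, 4),))

ep(torch.randn(4, 4))        # matches the traced shape: runs fine
try:
    ep(torch.randn(7, 4))    # dim 0 differs from the traced size
except RuntimeError as e:
    # Roughly: "Expected input ... .shape[0] to be equal to 4, but got 7"
    print(e)
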
2 changes: 1 addition & 1 deletion test/dynamo/test_export.py
@@ -2399,7 +2399,7 @@ def foo(x, y):

example_inputs = (copy(x), y)
ep = torch._export._export(foo, example_inputs, constraints=constraints)
with self.assertRaisesRegex(RuntimeError, "Input.*shape.*specialized at 2"):
with self.assertRaisesRegex(RuntimeError, "input.*shape.*to be equal to 2"):
ep(torch.randn(3), y)

dim0_x, dim0_y = torch.export.dims("dim0_x", "dim0_y")
23 changes: 11 additions & 12 deletions test/export/test_export.py
@@ -245,7 +245,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
torch.export.export(m, (a,), dynamic_shapes=dynamic_shapes)
em = torch.export.export(m, (a,))
x = torch.randn(3, 5)
with self.assertRaisesRegex(RuntimeError, "\\[1\\] is specialized at 4"):
with self.assertRaisesRegex(RuntimeError, "shape\[1\] to be equal to 4, but got 5"):
em(x)

def test_not_correct_dim(self):
@@ -1206,13 +1206,13 @@ def f(x, y):
torch.allclose(exported(torch.ones(8, 5), 5), f(torch.ones(8, 5), 5))
)
with self.assertRaisesRegex(
RuntimeError, "is specialized to be 5 at tracing time"
RuntimeError, "Expected input arg1 to be equal to 5, but got 6"
):
_ = exported(torch.ones(8, 5), 6)

exported = torch.export.export(f, (tensor_inp, 5.0), dynamic_shapes=dynamic_shapes)
with self.assertRaisesRegex(
RuntimeError, "is specialized to be 5.0 at tracing time"
RuntimeError, "Expected input arg1 to be equal to 5.0, but got 6.0"
):
_ = exported(torch.ones(7, 5), 6.0)

@@ -1225,7 +1225,7 @@ def g(a, b, mode):

inps = (torch.randn(4, 4), torch.randn(4), "trunc")
exported = export(g, inps)
with self.assertRaisesRegex(RuntimeError, "is specialized to be trunc at"):
with self.assertRaisesRegex(RuntimeError, "to be equal to trunc, but got floor"):
_ = exported(torch.randn(4, 4), torch.randn(4), "floor")
self.assertTrue(torch.allclose(exported(*inps), g(*inps)))

@@ -1306,7 +1306,7 @@ def forward(self, x):
dim0_x = torch.export.Dim("dim0_x")
exported = torch.export.export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x}})
reexported = torch.export.export(exported, (inp,))
with self.assertRaisesRegex(RuntimeError, "shape\[0\] is specialized at 5"):
with self.assertRaisesRegex(RuntimeError, "shape\[0\] to be equal to 5, but got 7"):
reexported(torch.ones(7, 5))

reexported = torch.export.export(exported, (inp,), dynamic_shapes=({0: dim0_x},))
@@ -1315,7 +1315,7 @@ def forward(self, x):
# can't retrace with invalid inputs with respect to the original ExportedProgram
dim0_x_v2 = torch.export.Dim("dim0_x_v2", min=3)
exported_v2 = torch.export.export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x_v2}})
with self.assertRaisesRegex(RuntimeError, "shape\[1\] is specialized at 5"):
with self.assertRaisesRegex(RuntimeError, "Expected input l_x_.shape\[0\] to be >= 3, but got 2"):
torch.export.export(exported_v2, (torch.randn(2, 2),))

@testing.expectedFailureSerDer
@@ -1453,7 +1453,7 @@ def forward(self, x):
self.assertEqual(len(ep.state_dict), 1)
self.assertEqual(len(ep.tensor_constants), 2)

inp = (torch.randn(1),)
inp = (torch.tensor(5),)
self.assertTrue(torch.allclose(ep(*inp), Foo()(*inp)))

transform = ep.run_decompositions()
@@ -1620,7 +1620,7 @@ def foo(a, b):
self.assertEqual(ep(*test_inp), foo(*test_inp))

ep_v2 = torch.export.export(foo, (torch.randn(4, 4), torch.randn(4, 4)), dynamic_shapes=(None, None))
with self.assertRaisesRegex(RuntimeError, "shape\[0\] is specialized at 4"):
with self.assertRaisesRegex(RuntimeError, "shape\[0\] to be equal to 4, but got 7"):
ep_v2(*test_inp)

def test_constant_output(self):
@@ -1693,8 +1693,7 @@ def dynamify_inp(x):

test_inp = ((torch.randn(4, 4), torch.randn(2, 4)), torch.randn(4, 4))
with self.assertRaisesRegex(
RuntimeError,
"shape\[0\] is outside of specified dynamic range \[3, inf\]"
RuntimeError, "shape\[0\] to be >= 3, but got 2"
):
ep(*test_inp)

@@ -1724,10 +1723,10 @@ def forward(self, x):
inp = torch.randn(4, 4)
gm = capture_pre_autograd_graph(Foo(), (inp,), constraints=[dynamic_dim(inp, 0) >= 3])

with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
with self.assertRaisesRegex(RuntimeError, "Expected input arg0_1.shape\[0\] to be >= 3, but got 2"):
gm(torch.randn(2, 2))

with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
with self.assertRaisesRegex(RuntimeError, "Expected input arg0_1.shape\[0\] to be >= 3, but got 2"):
torch.export.export(gm, (torch.randn(2, 2),))

ep = torch.export.export(gm, (torch.randn(5, 4),), dynamic_shapes=({0: torch.export.Dim("dim", min=3)},))
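
One of the test_export.py changes above covers non-tensor arguments, which follow the same specialization rule. A short hedged sketch (the function and values are illustrative, not from the PR) of how a scalar argument is burned into the traced value and rejected at run time with the new wording:

import torch

def f(x, mult):  # illustrative function, not from the PR
    return x * mult

# The int argument is specialized to the traced value 5.
ep = torch.export.export(f, (torch.ones(8, 5), 5))

ep(torch.ones(8, 5), 5)      # ok: same value as at trace time
try:
    ep(torch.ones(8, 5), 6)  # different scalar at run time
except RuntimeError as e:
    print(e)  # e.g. "Expected input arg1 to be equal to 5, but got 6"
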
14 changes: 7 additions & 7 deletions test/export/test_passes.py
@@ -76,7 +76,7 @@ def forward(self, x):
dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
ep = torch.export.export(M(), (x,), dynamic_shapes={"x": {1: dim1_x}})

with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"):
with self.assertRaisesRegex(RuntimeError, r"Expected input l_x_.shape\[1\] to be <= 6, but got 7"):
ep(torch.zeros(2, 7, 3))

self.assertEqual(ep(torch.ones(2, 4, 3)), M().forward(torch.ones(2, 4, 3)))
@@ -99,10 +99,10 @@ def forward(self, x, y):
M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": {0: dim0_y}}
)

with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"):
with self.assertRaisesRegex(RuntimeError, r"Expected input l_x_.shape\[1\] to be <= 6, but got 7"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"):
with self.assertRaisesRegex(RuntimeError, r"Expected input l_y_.shape\[0\] to be >= 3, but got 2"):
ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

def test_runtime_assert_some_dims_not_specified(self) -> None:
@@ -123,12 +123,12 @@ def forward(self, x, y):
M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": None}
)

with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"):
with self.assertRaisesRegex(RuntimeError, r"Expected input l_x_.shape\[1\] to be <= 6, but got 7"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

# y is specialized to 5
with self.assertRaisesRegex(
RuntimeError, r"shape\[0\] is specialized at 5"
RuntimeError, r"Expected input l_y_.shape\[0\] to be equal to 5, but got 2"
):
ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

@@ -152,12 +152,12 @@ def forward(self, x, y):
dim1_y = torch.export.Dim("dim1_y", min=3, max=6)
ep = torch.export.export(M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}})

with self.assertRaisesRegex(RuntimeError, r"shape\[1\] is specialized at 2"):
with self.assertRaisesRegex(RuntimeError, r"shape\[1\] to be equal to 2"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

# y is specialized to 5
with self.assertRaisesRegex(
RuntimeError, r"shape\[0\] is specialized at 5"
RuntimeError, r"Expected input l_y_.shape\[0\] to be equal to 5, but got 2"
):
ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

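
To make the range-constraint checks in test_passes.py above concrete, a hedged sketch (module and sizes are illustrative, mirroring the test) of how a Dim with min/max bounds now surfaces as "to be <= max" / "to be >= min" runtime errors:

import torch

class M(torch.nn.Module):  # illustrative module, mirroring the test above
    def forward(self, x):
        return x.sum()

x = torch.zeros(2, 4, 3)
dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
ep = torch.export.export(M(), (x,), dynamic_shapes={"x": {1: dim1_x}})

ep(torch.zeros(2, 5, 3))      # dim 1 within [2, 6]: runs fine
try:
    ep(torch.zeros(2, 7, 3))  # dim 1 above the declared max
except RuntimeError as e:
    print(e)  # e.g. "Expected input ... .shape[1] to be <= 6, but got 7"
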
4 changes: 2 additions & 2 deletions test/export/test_unflatten.py
@@ -283,11 +283,11 @@ def forward(self, x):
return a

export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
with self.assertRaisesRegex(RuntimeError, ".shape\[1\] is specialized at 3"):
with self.assertRaisesRegex(RuntimeError, "Expected input l_x_.shape\[0\] to be equal to 2, but got 6"):
export_module(torch.randn(6, 6))

unflattened = export_module.module(flat=False)
with self.assertRaisesRegex(RuntimeError, ".shape\[1\] is specialized at 3"):
with self.assertRaisesRegex(RuntimeError, "Expected input l_x_.shape\[0\] to be equal to 2, but got 6"):
unflattened(torch.randn(6, 6))

def test_unflatten_with_inplace_compile(self):
168 changes: 2 additions & 166 deletions torch/_export/passes/add_runtime_assertions_for_constraints_pass.py
@@ -1,22 +1,18 @@
import copy
import math
import operator
import traceback
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, NamedTuple, Set, Tuple
from typing import Callable, Dict, List, NamedTuple, Set, Tuple

import sympy

import torch
import torch.fx
from torch.fx.experimental.symbolic_shapes import SymInt
from torch._export.pass_base import _ExportPassBase, ProxyValue, PassResult
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._sympy.value_ranges import ValueRanges


__all__ = ["_AddRuntimeAssertionsForConstraintsPass", "InputDim"]
__all__ = ["InputDim"]


class InputDim(NamedTuple):
@@ -150,163 +146,3 @@ def call(self, graph_module):
node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1))

return PassResult(val.graph_module, val.modified)


class _AddRuntimeAssertionsForConstraintsPass(_AddRuntimeAssertionsForInlineConstraintsPass):
def __init__(
self,
range_constraints: Dict[sympy.Symbol, ValueRanges],
equality_constraints: List[Tuple[InputDim, InputDim]],
):
super().__init__(range_constraints, equality_constraints)

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph_module = copy.deepcopy(graph_module)
graph = graph_module.graph

insert_loc = None
for node in graph.nodes:
if node.op != "placeholder":
continue
insert_loc = node
if insert_loc is None:
return super().call(graph_module)

# Add runtime asserts for input shape constraints. We do this after all
# placeholder nodes so that we can handle both (unary) predicates and
# (binary) relations.
inputdim_to_node: Dict[InputDim, torch.fx.Node] = OrderedDict()
for node in graph.nodes:
if node.op != "placeholder":
continue

if (
"val" not in node.meta or node.meta["val"] is None
):
continue

if not isinstance(node.meta["val"], FakeTensor):
# it has to be a prim value
self._insert_prim_assert_inplace(graph, node, node.meta["val"])
else:
fake_tensor_shape = node.meta["val"].shape
for dim, shape in enumerate(fake_tensor_shape):
with graph.inserting_after(insert_loc):
dim_node = graph.call_function(
torch.ops.aten.sym_size.int, (node, dim)
)
input_dim = InputDim(node.name, dim)
inputdim_to_node[input_dim] = dim_node
insert_loc = dim_node

if isinstance(shape, SymInt):
# If the shape is dynamic, add range assertions
symbol = shape.node._expr
if symbol in self.range_constraints:
self._insert_range_assert_inplace(
graph, input_dim, dim_node, self.range_constraints[symbol]
)
else:
# If no dynamism is specified, we assume all dimensions #
# are specialized
assert isinstance(shape, int)
self._insert_specialized_shape_assert_inplace(
graph, input_dim, dim_node, shape,
)

# Add runtime assertions on equality constraints on the inputs
if len(inputdim_to_node) > 0:
with graph.inserting_after(
list(inputdim_to_node.values())[-1]
):
self._insert_equality_assert_inplace(graph, inputdim_to_node)

return super().call(graph_module)

def _insert_specialized_shape_assert_inplace(
self, graph: torch.fx.Graph, input_dim: InputDim, dim_node: torch.fx.Node, shape: int,
):
assert_msg = f"Input {input_dim.input_name}.shape[{input_dim.dim}] is specialized at {shape}"
with graph.inserting_after(dim_node):
eq_node = graph.call_function(operator.eq, (dim_node, shape))
with graph.inserting_after(eq_node):
tensor_eq_node = graph.call_function(torch.ops.aten.scalar_tensor.default, (eq_node,))
with graph.inserting_after(tensor_eq_node):
_ = graph.call_function(torch.ops.aten._assert_async.msg, (tensor_eq_node, assert_msg))

def _insert_prim_assert_inplace(self, graph, node: torch.fx.Node, value: Any):
assert_msg = (
f"Input {node.name} is specialized to be {value} at tracing time,"
f"it is not supported to pass in a different value at run time."
)
with graph.inserting_after(node):
eq_node = graph.call_function(operator.eq, (node, value))
with graph.inserting_after(eq_node):
tensor_eq_node = graph.call_function(torch.ops.aten.scalar_tensor.default, (eq_node,))
with graph.inserting_after(tensor_eq_node):
_ = graph.call_function(torch.ops.aten._assert_async.msg, (tensor_eq_node, assert_msg))

def _insert_range_assert_inplace(
self, graph: torch.fx.Graph, input_dim: InputDim, dim_node: torch.fx.Node, range: ValueRanges
):
"""
Add runtime asserts for user-specified range constraints for
each placeholder's dynamic dimension.
"""

min_val, max_val = _convert_range_to_int(range)
assert_msg = (
f"Input {input_dim.input_name}.shape[{input_dim.dim}] is "
f"outside of specified dynamic range [{min_val}, {max_val}]"
)
# TODO (tmanlaibaatar) we are making an assumption that graph generated for
# input dim N >=2 generalizes to N < 2. Ideally we should check that:
# 1. if we can generalize to N < 2, not add any assertion saying N >= 2
# 2. If we can't generalize to N < 2, add an assertion saying N >= 2
# Above can be achieved via a separate pass.
with graph.inserting_after(dim_node):
if min_val > 2:
self._insert_assert_async_inplace(
graph, operator.ge, (dim_node, min_val), assert_msg,
)

if max_val < math.inf:
self._insert_assert_async_inplace(
graph, operator.le, (dim_node, max_val), assert_msg,
)

def _insert_equality_assert_inplace(
self,
graph: torch.fx.Graph,
inputdim_to_node: Dict[InputDim, torch.fx.Node],
):
for input_dim, other_input_dim in self.equality_constraints:
dim_node = inputdim_to_node[input_dim]
assert_msg = (
f"Input {input_dim.input_name}.shape[{input_dim.dim}] is "
f"not equal to input {other_input_dim.input_name}.shape[{other_input_dim.dim}]"
)

other_dim_node = inputdim_to_node[other_input_dim]
self._insert_assert_async_inplace(
graph,
operator.eq,
(dim_node, other_dim_node),
assert_msg
)

def _insert_assert_async_inplace(self, graph, operator, args, assert_msg):
"""
Inserts assert_async call_function nodes in the graph. This function is
called before we run the interpreter-based pass and does an inplace
insertion.
"""
cmp_node = graph.call_function(operator, args)
with graph.inserting_after(cmp_node):
cmp_tensor_node = graph.call_function(
torch.ops.aten.scalar_tensor.default, (cmp_node,)
)
with graph.inserting_after(cmp_tensor_node):
_ = graph.call_function(
torch.ops.aten._assert_async.msg, (cmp_tensor_node, assert_msg)
)
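
For reference, a rough eager-mode sketch (not from the commit; the helper name and message are illustrative) of the compare -> scalar_tensor -> _assert_async.msg chain that the removed pass emitted as fx nodes after each placeholder:

import torch

def assert_dim_specialized(x: torch.Tensor, dim: int, expected: int) -> None:
    # The removed pass built this chain as graph nodes (using an
    # aten.sym_size.int node for the dimension); here it is spelled out eagerly.
    ok = torch.ops.aten.scalar_tensor.default(x.shape[dim] == expected)
    torch.ops.aten._assert_async.msg(
        ok, f"Input x.shape[{dim}] is specialized at {expected}"
    )

assert_dim_specialized(torch.randn(4, 3), 0, 4)    # passes silently
# assert_dim_specialized(torch.randn(5, 3), 0, 4)  # would raise with the message above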