[export] Turn off output value from sources for export. (pytorch#115442)
Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#115442
Approved by: https://github.com/tugsbayasgalan
zhxchen17 authored and pytorchmergebot committed Dec 12, 2023
1 parent af09fe2 commit f78f23d
Showing 5 changed files with 51 additions and 33 deletions.
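
The change can be summarized with a small usage sketch, adapted from the new test in test/export/test_export.py below (the module name and shapes are the test's own, and calling the ExportedProgram directly mirrors the test's ep() call): a module whose forward takes no inputs and returns a plain tensor attribute can now be exported and invoked.

import torch


class ModuleConstant(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A plain tensor attribute (not a Parameter or buffer); export lifts it
        # to a constant, and its value becomes the program's only output.
        self.b = torch.randn(3, 2)

    def forward(self):
        return self.b


mod = ModuleConstant()
ep = torch.export.export(mod, ())  # no user inputs at all
assert torch.equal(ep(), mod.b)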
26 changes: 26 additions & 0 deletions test/export/test_export.py
@@ -1562,6 +1562,32 @@ def foo(a, b):
         with self.assertRaisesRegex(RuntimeError, "shape\[0\] is specialized at 4"):
             ep_v2(*test_inp)
 
+    def test_constant_output(self):
+        class ModuleConstant(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.b = torch.randn(3, 2)
+
+            def forward(self):
+                return self.b
+
+        class ModuleNestedConstant(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.bff = torch.randn(3, 2)
+
+            def forward(self, x, y):
+                return {"prediction": (x + y, self.bff)}
+
+        mod = ModuleConstant()
+        ep = torch.export.export(mod, ())
+        self.assertEqual(ep(), mod())
+
+        args = (torch.randn(3, 2), torch.randn(3, 2))
+        mod = ModuleNestedConstant()
+        ep = torch.export.export(mod, args)
+        self.assertEqual(ep(*args), mod(*args))
+
     def test_non_arg_name_dynamic_shapes_api_with_kwarg(self):
         def foo(a, b, kw1, kw2):
             return a.sum() + b.sum() + kw1.sum() - kw2.sum()
8 changes: 6 additions & 2 deletions torch/_dynamo/codegen.py
@@ -60,14 +60,18 @@ def __init__(
         self.cell_and_freevars = self.tx.cell_and_freevars
         self.new_var = self.tx.output.new_var
         self.mutable_side_effects_from_source = False
+        self.value_from_source: bool = True
 
-    def restore_stack(self, stack_values):
+    def restore_stack(self, stack_values, *, value_from_source=True):
         prior = self.mutable_side_effects_from_source
         self.mutable_side_effects_from_source = True
+        prev = self.value_from_source
+        self.value_from_source &= value_from_source
         try:
             self.foreach(stack_values)
         finally:
             self.mutable_side_effects_from_source = prior
+            self.value_from_source = prev
 
     def graph_output_vars(self):
         return [x.variable for x in self.graph_outputs.values()]
@@ -108,7 +112,7 @@ def __call__(self, value, allow_cache=True):
             self.top_of_stack = value
             return
 
-        if value.source is not None and allow_cache:
+        if value.source is not None and allow_cache and self.value_from_source:
             output.extend(value.source.reconstruct(self))
         elif value.is_python_constant() and is_safe_constant(
             value.as_python_constant()
4 changes: 2 additions & 2 deletions torch/_dynamo/output_graph.py
@@ -911,7 +911,7 @@ def append_prefix_insts():
         pass1 = PyCodegen(tx, root, graph_output_var)
         self.side_effects.codegen_hooks(pass1)
         self.side_effects.codegen_save_tempvars(pass1)
-        pass1.restore_stack(stack_values)
+        pass1.restore_stack(stack_values, value_from_source=not tx.export)
         self.side_effects.codegen_update_mutated(pass1)
 
         # one more time now that we have established tempvars
@@ -923,7 +923,7 @@ def append_prefix_insts():
         )
         self.side_effects.codegen_hooks(pass2)
         self.side_effects.codegen_save_tempvars(pass2)
-        pass2.restore_stack(stack_values)
+        pass2.restore_stack(stack_values, value_from_source=not tx.export)
         self.side_effects.codegen_update_mutated(pass2)
 
         output = []
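
A condensed sketch of the flag threading added in the two files above; CodegenSketch and its emit method are hypothetical stand-ins for PyCodegen and its __call__, kept only to show the save/AND/restore pattern and the gated source reconstruction.

class CodegenSketch:
    """Hypothetical stand-in for PyCodegen, showing only the new flag handling."""

    def __init__(self):
        self.value_from_source = True

    def restore_stack(self, stack_values, *, value_from_source=True):
        prev = self.value_from_source
        self.value_from_source &= value_from_source
        try:
            for value in stack_values:
                self.emit(value)
        finally:
            self.value_from_source = prev

    def emit(self, value):
        # Mirrors the changed condition in PyCodegen.__call__: a value's
        # source is only reused while value_from_source is still True.
        if value.get("source") is not None and self.value_from_source:
            print("reconstruct from source:", value["source"])
        else:
            print("emit from graph output:", value["name"])


cg = CodegenSketch()
# Export passes value_from_source=not tx.export, i.e. False, so the
# source-based path is skipped and the stack value comes from the graph.
cg.restore_stack([{"source": "self.b", "name": "b"}], value_from_source=False)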
7 changes: 3 additions & 4 deletions torch/_export/passes/lift_constant_tensor_pass.py
@@ -23,14 +23,13 @@ def lift_constant_tensor_pass(gm, graph_signature) -> Dict[str, torch.Tensor]:
     )
     assert fake_mode is not None
 
-    first_user_input_loc, first_user_input = None, None
-    for i, node in enumerate(gm.graph.nodes):
+    first_user_input_loc, first_user_input = 0, None
+    for node in gm.graph.nodes:
         if node.op == "placeholder" and node.name in graph_signature.user_inputs:
             first_user_input = node
-            first_user_input_loc = i
             break
+        first_user_input_loc += 1
 
-    assert first_user_input is not None and first_user_input_loc is not None
     tensor_constants = {}
 
     for node in gm.graph.nodes:
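
A minimal sketch of the revised search above, with (op, name) tuples standing in for fx.Node objects: the location counter starts at 0 and advances until the first user-input placeholder, so a graph with no user inputs (the constant-output case) yields an insertion point past the existing nodes instead of tripping the removed assert.

# Hypothetical node list for a constant-only graph: no user-input placeholders.
nodes = [("placeholder", "lifted_const_0"), ("call_function", "add"), ("output", "out")]
user_inputs = set()

first_user_input_loc, first_user_input = 0, None
for op, name in nodes:
    if op == "placeholder" and name in user_inputs:
        first_user_input = (op, name)
        break
    first_user_input_loc += 1

print(first_user_input_loc, first_user_input)  # 3 None, now tolerated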
39 changes: 14 additions & 25 deletions torch/export/_trace.py
@@ -19,10 +19,11 @@
 from torch._export.wrappers import _wrap_submodules
 from torch._functorch.aot_autograd import aot_export_module, GraphSignature
 from torch._guards import detect_fake_mode
-from torch._subclasses.fake_tensor import FakeTensor
+from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
 from torch.fx.experimental.symbolic_shapes import (
     ConstraintViolationError,
     GuardOnDataDependentSymNode,
+    ShapeEnv,
 )
 from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
 from torch.utils._sympy.value_ranges import ValueRangeError
@@ -57,16 +58,8 @@ class ExportDynamoConfig:
 
 
 def _convert_input_to_fake(gm, args, kwargs):
-    if (
-        len(args) == 0
-        and len(kwargs) == 0
-        and len(dict(gm.named_parameters())) == 0
-        and len(dict(gm.named_buffers())) == 0
-    ):
-        return [], {}, {}, None
-
+    params_buffers = _get_params_buffers(gm)
     fake_inps: List[torch.Tensor] = []
-    fake_mode = None
     for node in gm.graph.nodes:
         if node.op == "placeholder" and "val" in node.meta:
             fake_val = node.meta["val"]
@@ -75,10 +68,11 @@ def _convert_input_to_fake(gm, args, kwargs):
 
     if detected_fake_mode := detect_fake_mode(fake_inps):
         fake_mode = detected_fake_mode
+    else:
+        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
 
-    assert (
-        fake_mode is not None
-    ), "Cannot find fake_mode attatched to the graph's placeholders."
+    if len(args) == 0 and len(kwargs) == 0:
+        return (), {}, params_buffers, fake_mode
 
     count = 0
 
@@ -94,10 +88,7 @@ def convert_to_fake(x):
     fake_params_buffers = pytree.tree_map_only(
         torch.Tensor,
         functools.partial(fake_mode.from_tensor, static_shapes=True),
-        {
-            **dict(gm.named_parameters(remove_duplicate=False)),
-            **dict(gm.named_buffers(remove_duplicate=False)),
-        },
+        params_buffers,
     )
     return fake_args, fake_kwargs, fake_params_buffers, fake_mode
 
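
A small illustration of the fallback introduced in _convert_input_to_fake above: when detect_fake_mode finds nothing on the placeholders (for example, a module whose forward takes no inputs), a fresh FakeTensorMode with its own ShapeEnv is constructed so export still has a fake mode to return, where the old code would have asserted.

from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

fake_inps = []  # no placeholder carried a fake "val" in its meta
if detected_fake_mode := detect_fake_mode(fake_inps):
    fake_mode = detected_fake_mode
else:
    # New behavior: build a fake mode instead of asserting that one exists.
    fake_mode = FakeTensorMode(shape_env=ShapeEnv())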
@@ -653,21 +644,19 @@ def _aot_export_strict(gm_torch_level: torch.fx.GraphModule, args, **kwargs):
     # The unbacked symint symbols are updated in aot_export
     # so we serialize them here instead of inside dynamo
 
-    # dynamo_fake_mode can be None if there's no placeholder in gm_torch_level
-    if dynamo_fake_mode:
-        gm.meta["inline_constraints"] = {
-            k: v
-            for k, v in dynamo_fake_mode.shape_env.runtime_var_to_range.items()
-            if re.match(r"^[if]\d+$", str(k))
-        }
+    gm.meta["inline_constraints"] = {
+        k: v
+        for k, v in dynamo_fake_mode.shape_env.runtime_var_to_range.items()
+        if re.match(r"^[if]\d+$", str(k))
+    }
 
     num_lifted = next(
         (
             i
             for i, s in enumerate(export_graph_signature.input_specs)
             if s.kind == InputKind.USER_INPUT
         ),
-        0,
+        len(export_graph_signature.input_specs),
     )
     flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs))
     range_constraints, equality_constraints = _process_constraints(
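
The new default for num_lifted can be seen in isolation; Spec and the kind strings below are hypothetical stand-ins for the real InputSpec/InputKind types. With no USER_INPUT entries, the old default of 0 would report nothing lifted, while len(input_specs) correctly treats every input as lifted.

from dataclasses import dataclass


@dataclass
class Spec:
    kind: str


# Example signature containing only lifted constants and no user inputs.
input_specs = [Spec("CONSTANT_TENSOR"), Spec("CONSTANT_TENSOR")]

num_lifted = next(
    (i for i, s in enumerate(input_specs) if s.kind == "USER_INPUT"),
    len(input_specs),  # new default when no user input exists
)
assert num_lifted == 2  # the old default of 0 would have undercounted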
