Consider storage_changed for assigning alias_of_input in aot_autograd when computing differentiable outputs that alias each other (pytorch#115315)

Pull Request resolved: pytorch#115315
Approved by: https://github.com/bdhirsh
voznesenskym authored and pytorchmergebot committed Dec 12, 2023
1 parent 946de1c commit 76ced0d
Showing 3 changed files with 37 additions and 4 deletions.
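
To make the failure mode concrete before the diffs: the hazard is an output whose storage relationship to an input changes mid-function via a `.data =` swap. The following is a hypothetical sketch, not taken from this commit; the function, input values, and the `aot_eager` backend choice are illustrative assumptions.

import torch

# Hypothetical repro sketch: `out` is created as a view of x's original
# storage, then x's storage is swapped out via `.data =`. After the swap,
# `out` no longer aliases x, so the compiled graph must not treat it as an
# alias of the input.
def f(x, y):
    out = x.view(-1)
    x.data = y
    return out

out_eager = f(torch.zeros(4), torch.ones(4))
out_compiled = torch.compile(f, backend="aot_eager")(torch.zeros(4), torch.ones(4))
assert torch.equal(out_eager, out_compiled)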
3 changes: 2 additions & 1 deletion aten/src/ATen/FunctionalTensorWrapper.cpp
@@ -136,7 +136,8 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const
view_value.device()
),
value_(view_value),
- is_multi_output_view_(base->is_multi_output_view_ || meta.is_multi_output)
+ is_multi_output_view_(base->is_multi_output_view_ || meta.is_multi_output),
+ was_storage_changed_(base->was_storage_changed_)
{
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
4 changes: 2 additions & 2 deletions test/dynamo/test_repros.py
@@ -3867,7 +3867,7 @@ def func3(x, y):
z = x
x.data = y
y.data = torch.zeros([0])
- return x is z
+ return torch.tensor(x is z)

for backend in ["eager", "aot_eager", "inductor"]:
for func in [func1, func2, func3]:
@@ -3893,7 +3893,7 @@ def func3(x, y):
out_compiled = compiled_fn(compiled_a, compiled_b)
self.assertEqual(eager_a, compiled_a)
self.assertEqual(eager_b, compiled_b)
- self.assertEqual(out_eager, out_compiled)
+ self.assertTrue(torch.equal(out_eager, out_compiled))

# func1 hits a leaf Variable that requires grad is being used in an in-place operation
if requires_grad:
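
A condensed, standalone version of the harness the test above uses (tensor shapes and values here are assumptions; the backends and assertions mirror the test):

import torch

def func3(x, y):
    z = x
    x.data = y
    y.data = torch.zeros([0])
    # Returning a tensor rather than a bare Python bool keeps the result
    # comparable with torch.equal across eager and compiled runs.
    return torch.tensor(x is z)

for backend in ["eager", "aot_eager", "inductor"]:
    compiled_fn = torch.compile(func3, backend=backend)

    eager_a, eager_b = torch.zeros(4), torch.ones(4)
    compiled_a, compiled_b = torch.zeros(4), torch.ones(4)

    out_eager = func3(eager_a, eager_b)
    out_compiled = compiled_fn(compiled_a, compiled_b)

    # Input mutations and the returned identity check must match eager.
    assert torch.equal(eager_a, compiled_a)
    assert torch.equal(eager_b, compiled_b)
    assert torch.equal(out_eager, out_compiled)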
34 changes: 33 additions & 1 deletion torch/_functorch/_aot_autograd/collect_metadata_analysis.py
@@ -303,7 +303,36 @@ def inner(*flat_args):
# maps the id of an intermediate base to its index in the output of the compiled forward
intermediate_base_tensor_id_to_output_idx: Dict[int, int] = {}
intermediate_bases: List[torch.Tensor] = []
+ # Why Do We Care If Storage Changed?
+ # It's important to understand the implications of storage changes in complex scenarios. Take this example:
+ #
+ # def f(x):
+ #     x_storage = x.untyped_storage()
+ #     non_leaf_tensor = torch.ones(4, requires_grad=True).clone()
+ #
+ #     # Using no_grad() and _unsafe_preserve_version_counter to simulate the .data = operation
+ #     with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(x):
+ #         x.set_(non_leaf_tensor.untyped_storage())
+ #
+ #     out = x.view(-1)
+ #
+ #     # Restoring x to its original storage, again simulating the .data = operation
+ #     with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(x):
+ #         x.set_(x_storage)
+ #
+ #     return out
+ #
+ # In this scenario, 'x' and 'out' have different shapes and are stored at different memory addresses, i.e., they do not alias.
+ # However, because set_() (and, more specifically, the way set_() is functionalized) is defined to preserve eager semantics,
+ # the autograd engine mistakenly assumes that 'x' and 'out' are aliased, treating 'x' as 'out._base'.
+ # This misinterpretation leads to an 'alias_of_input' flag, causing an unnecessary as_strided() call to be generated,
+ # which can lead to issues later in the code.
for o in flat_f_outs:
+ functional_tensor_storage_changed = isinstance(
+     o, FunctionalTensor
+ ) and torch._functionalize_was_storage_changed(  # type: ignore[attr-defined]
+     o.elem
+ )
curr_storage = (
None
if not isinstance(o, torch.Tensor)
@@ -338,7 +367,10 @@ def inner(*flat_args):
):
output_type = OutputType.custom_function_view
base_idx = None
- elif curr_storage in inp_storage_refs:
+ elif (
+     curr_storage in inp_storage_refs
+     and not functional_tensor_storage_changed
+ ):
base_idx = inp_storage_refs[curr_storage]
is_input_tensor = id(o) in inp_tensor_ids
num_aliased_outs = out_tensor_alias_counts[curr_storage]
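
The scenario walked through in the comment above can be exercised end to end. A minimal sketch follows; the input value and the `aot_eager` backend are assumptions, while the body of `f` mirrors the comment:

import torch

def f(x):
    x_storage = x.untyped_storage()
    non_leaf_tensor = torch.ones(4, requires_grad=True).clone()

    # Simulate `.data =` by swapping storages under no_grad while preserving
    # the version counter, exactly as described in the comment above.
    with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(x):
        x.set_(non_leaf_tensor.untyped_storage())

    out = x.view(-1)  # views the swapped-in storage, not x's original one

    with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(x):
        x.set_(x_storage)

    return out

# With was_storage_changed propagated to views, `out` is no longer
# misclassified as an alias of the input, and compiled output matches eager.
out_eager = f(torch.zeros(4))
out_compiled = torch.compile(f, backend="aot_eager")(torch.zeros(4))
assert torch.equal(out_eager, out_compiled)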
