
Throw on data dependent outputs #1163

Merged 21 commits on Sep 13, 2022

Conversation

eellison
Contributor

@eellison eellison commented Sep 8, 2022

We need to flip on pytorch/pytorch#83567 for correctness. This patch allows us to continue to trace through .item() by catching the exception.

It also switches the optimizers to use fake_tensors=False because, as explained in pytorch/pytorch#84597 and pytorch/pytorch#93660, the current wrapping of .item() for those cases only defers the _local_scalar_dense call until the torch dispatch layer, where fake tensors have to throw. Those issues lay out a path toward avoiding the graph breaks, but we should land this patch even with the graph breaks because it's needed for correctness.
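As a rough illustration of the catch-the-exception flow the description refers to (hypothetical class names, not dynamo's actual implementation): a fake tensor has no real data behind it, so a data-dependent call like .item() must throw, and the tracer catches that and graph-breaks instead of crashing.

```python
# Hypothetical sketch of the pattern described above (illustrative names,
# not torchdynamo's real classes).

class DataDependentOutputException(Exception):
    pass

class FakeTensor:
    def item(self):
        # A fake tensor carries no concrete data, so any data-dependent
        # output must raise rather than invent a value.
        raise DataDependentOutputException(".item() on a fake tensor")

def trace_call(fn, arg):
    """Return ('traced', result), or ('graph_break', None) on a
    data-dependent output."""
    try:
        return ("traced", fn(arg))
    except DataDependentOutputException:
        # Fall back to eager execution for this op: a graph break.
        return ("graph_break", None)

status, _ = trace_call(lambda t: t.item(), FakeTensor())
```

The key point is that the exception is caught at trace time, so tracing continues around the data-dependent op rather than failing outright.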

This is ready for review, but it will take another day before the changes in PyTorch are available in the nightly and the tests pass.

@eellison eellison changed the title Prepare for throwing on data dependent outputs Throw on data dependent outputs Sep 8, 2022

Contributor

@voznesenskym voznesenskym left a comment


Seems reasonable

torch.optim.ASGD, exp_frame_cnt=(0 if sys.version_info < (3, 8) else 6)
)

# Fails without fake tensor:
Contributor


Interesting. Can you file an issue for this? Cc @mlazos

@@ -133,6 +133,7 @@ def fn(x):
res2 = opt_fn(x)
self.assertTrue(same(res1, res2))

@patch.object(torchdynamo.config, "fake_tensor_propagation", False)
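The @patch.object decorator in the diff above temporarily flips a module-level config flag off for the duration of one test. A minimal sketch of that pattern, using a stand-in config object rather than torchdynamo's real config:

```python
from types import SimpleNamespace
from unittest.mock import patch

# Stand-in for torchdynamo.config (hypothetical; for illustration only).
config = SimpleNamespace(fake_tensor_propagation=True)

@patch.object(config, "fake_tensor_propagation", False)
def run_test():
    # While the decorated test runs, the flag reads False.
    return config.fake_tensor_propagation

inside = run_test()                      # flag is False inside the patched scope
after = config.fake_tensor_propagation   # automatically restored afterwards
```

Because a replacement value is passed to patch.object, the patcher does not inject an extra mock argument into the test, and the original flag value is restored even if the test raises.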
Contributor


I'm curious why this failed. This is one of the test cases copied from PyTorch CI, where fake tensor is enabled, IIRC.

Contributor Author


>>> x = torch.tensor(0)
>>> torch.randint(x, x, x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: randint() received an invalid combination of arguments - got (Tensor, Tensor, Tensor), but expected one of:
 * (int high, tuple of ints size, *, torch.Generator generator, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (int high, tuple of ints size, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (int low, int high, tuple of ints size, *, torch.Generator generator, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (int low, int high, tuple of ints size, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)

The int -> torch.tensor(0) conversion causes the error: as the traceback shows, randint's overloads only accept Python ints for these arguments, so the wrapped 0-dim tensors match none of them.
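The failure mode above can be sketched without torch (hypothetical names, not dynamo's real code): a 0-dim tensor stand-in carrying a constant, and a randint-like API that, like torch.randint, only accepts Python ints for low/high.

```python
import random

class FakeScalar:
    """Stand-in for a 0-dim tensor wrapping a known constant."""
    def __init__(self, value):
        self.value = value

def randint_like(low, high, n):
    if not (isinstance(low, int) and isinstance(high, int)):
        # Mirrors the TypeError in the traceback: tensor-typed arguments
        # match none of the int-typed overloads.
        raise TypeError("expected int low/high, got wrapped tensors")
    return [random.randrange(low, high) for _ in range(n)]

x = FakeScalar(0)
try:
    randint_like(x, x, 1)
    failed = False
except TypeError:
    failed = True  # the wrapped arguments are rejected

# Unwrapping back to Python ints makes the call valid again.
vals = randint_like(FakeScalar(0).value, FakeScalar(5).value, 3)
```

This is why wrapping ints as tensors during tracing breaks calls whose signatures demand plain Python ints.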
