Throw on data dependent outputs #1163
Conversation
This is being tested here https://github.com/pytorch/torchdynamo/blob/main/test/test_repros.py#L887
Seems reasonable
torch.optim.ASGD, exp_frame_cnt=(0 if sys.version_info < (3, 8) else 6)
)

# Fails without fake tensor:
Interesting. Can you file an issue for this? Cc @mlazos
@@ -133,6 +133,7 @@ def fn(x):
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

    @patch.object(torchdynamo.config, "fake_tensor_propagation", False)
I'm curious why this failed. This is one of the test cases copied from PyTorch CI, where fake tensor is enabled IIRC.
>>> x = torch.tensor(0)
>>> torch.randint(x, x, x)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: randint() received an invalid combination of arguments - got (Tensor, Tensor, Tensor), but expected one of:
* (int high, tuple of ints size, *, torch.Generator generator, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
* (int high, tuple of ints size, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
* (int low, int high, tuple of ints size, *, torch.Generator generator, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
* (int low, int high, tuple of ints size, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
The `int -> torch.tensor(0)` conversion causes the error: `randint` expects Python ints for `low`/`high` and a size tuple, not Tensors.
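For contrast, a minimal sketch of the call signature `randint` does accept (Python ints for `low`/`high` and a tuple for `size`), which is what the traceback above is complaining about:

```python
import torch

# randint wants int low/high and a tuple-of-ints size; passing
# Tensors for these positions raises the TypeError shown above.
x = torch.randint(0, 10, (3,))  # low=0, high=10 (exclusive), size=(3,)

assert x.shape == (3,)
assert bool(((0 <= x) & (x < 10)).all())
```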
We need to flip on pytorch/pytorch#83567 for correctness. This patch allows us to continue to trace through `.item()` by catching the exception. It also switches the optimizers to use `fake_tensors=False` because, as explained in pytorch/pytorch#84597 and pytorch/pytorch#93660, the current wrapping of `.item()` for those cases only defers the `_local_scalar_dense` call until the torch dispatch layer, where fake tensors have to throw. Those issues lay out a path forward for not graph breaking, but we should get this patch in even with the graph breaks, because it's needed for correctness.

This is ready for review, but it's going to take another day before the changes in PyTorch are available in the nightly and the tests will pass.
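To illustrate why `.item()` is the problem case: its result is a data-dependent Python scalar, so a tracer working with fake tensors (which carry shapes and dtypes but no real values) cannot produce it and must either graph-break or throw. A minimal sketch of the pattern (plain eager PyTorch, not the tracing machinery itself):

```python
import torch

def fn(x):
    # .item() pulls a concrete Python scalar out of a 0-dim tensor.
    # Under fake tensors there is no value to pull, so a tracer must
    # graph-break (or raise) at this point rather than fabricate one.
    n = int(x.sum().item())  # data-dependent value feeding a shape
    return torch.zeros(n)

out = fn(torch.ones(4))  # sum is 4.0, so the output has 4 elements
assert out.shape == (4,)
```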