Skip to content

Commit

Permalink
[Export] Support ser/des test on existing cases (pytorch#115413)
Browse files Browse the repository at this point in the history
Summary:
Similar as pytorch#115399

Test Plan:
```
$ python test/export/test_serdes.py
...
Ran 72 tests in 29.097s

OK (expected failures=13)
```
Pull Request resolved: pytorch#115413
Approved by: https://github.com/tugsbayasgalan
ghstack dependencies: pytorch#115399, pytorch#115402
  • Loading branch information
andrewlee302 authored and pytorchmergebot committed Dec 13, 2023
1 parent b0c7dd4 commit 4744359
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 2 deletions.
12 changes: 10 additions & 2 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def kw_func(arg1, arg2, a, b):
kwargs = {"a": {"kw1": torch.ones(2, 3), "kw2": torch.ones(3, 4)}, "b": [torch.ones(2, 3), torch.ones(3, 4)]}
self._test_export_same_as_eager(kw_func, args, kwargs)

@testing.expectedFailureSerDer
@testing.expectedFailureRetraceability
@testing.expectedFailureNonStrict
def test_export_func_with_default_kwargs(self):
Expand Down Expand Up @@ -359,6 +360,7 @@ def kw_func(arg1, arg2, *args, kw1, kw2, **kwargs):
"kw3": (torch.ones(2, 3), torch.ones(3, 4)), "kw4": torch.ones(3, 4)}
self._test_export_same_as_eager(kw_func, args, kwargs)

@testing.expectedFailureSerDer
@testing.expectedFailureNonStrict
def test_linear_conv(self):

Expand Down Expand Up @@ -798,7 +800,8 @@ def forward(self, x):
self.assertEqual(params[1].shape, [1]) # bias

def test_buffer_util(self):
ep = export(torch.nn.BatchNorm2d(100, affine=False), (torch.ones(20, 100, 35, 45), ))
ep = export(torch.nn.BatchNorm2d(100, affine=False), (torch.ones(20, 100, 35, 45)
, ))
num_buffer = 0
buffer = []

Expand Down Expand Up @@ -849,6 +852,7 @@ def _patch_config(kwargs):
):
_ = export(mod, inp)

@testing.expectedFailureSerDer
@testing.expectedFailureNonStrict
def test_module(self):

Expand Down Expand Up @@ -885,6 +889,7 @@ def forward(self, x):
self.assertTrue(torch.allclose(ep(*inp_test)[0], ep_rexported(*inp_test)[0]))
self.assertTrue(torch.allclose(ep(*inp_test)[1], ep_rexported(*inp_test)[1]))

@testing.expectedFailureSerDer
@testing.expectedFailureNonStrict
def test_module_with_dict_container_inp_out(self):

Expand Down Expand Up @@ -1067,6 +1072,7 @@ def case_5(x, y):
)
)

@testing.expectedFailureSerDer
@testing.expectedFailureNonStrict
def test_mixed_input(self):
def func(a, b, alpha: int):
Expand Down Expand Up @@ -1117,7 +1123,6 @@ def f(x):
).run(ep.graph_module.code)

def test_to_module_with_mutated_buffer(self):

class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -1212,6 +1217,7 @@ def f(x, y):
):
_ = exported(torch.ones(7, 5), 6.0)

@testing.expectedFailureSerDer
@testing.expectedFailureNonStrict
def test_runtime_assert_for_prm_str(self):

Expand Down Expand Up @@ -1313,6 +1319,7 @@ def forward(self, x):
with self.assertRaisesRegex(RuntimeError, "shape\[1\] is specialized at 5"):
torch.export.export(exported_v2, (torch.randn(2, 2),))

@testing.expectedFailureSerDer
@testing.expectedFailureNonStrict
def test_retrace_graph_level_meta_preservation(self):
class Foo(torch.nn.Module):
Expand Down Expand Up @@ -1583,6 +1590,7 @@ def f(x):
inp = torch.randn(2)
self.assertTrue(torch.allclose(ep(inp), torch.nonzero(inp)))

@testing.expectedFailureSerDer
@testing.expectedFailureRetraceability
@testing.expectedFailureNonStrict
def test_redundant_asserts(self):
Expand Down
53 changes: 53 additions & 0 deletions test/export/test_serdes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Owner(s): ["module: dynamo"]

import io

import test_export
import testing

from torch.export import export, load, save

test_classes = {}


def mocked_serder_export(*args, **kwargs):
ep = export(*args, **kwargs)
buffer = io.BytesIO()
save(ep, buffer)
buffer.seek(0)
loaded_ep = load(buffer)
return loaded_ep


def make_dynamic_cls(cls):
suffix = "_serdes"

cls_prefix = "SerDesExport"

test_class = testing.make_test_cls_with_mocked_export(
cls,
cls_prefix,
suffix,
mocked_serder_export,
xfail_prop="_expected_failure_serdes",
)

test_classes[test_class.__name__] = test_class
# REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
globals()[test_class.__name__] = test_class
test_class.__module__ = __name__
return test_class


tests = [
test_export.TestDynamismExpression,
test_export.TestExport,
]
for test in tests:
make_dynamic_cls(test)
del test

if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

run_tests()
6 changes: 6 additions & 0 deletions test/export/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,9 @@ def expectedFailureNonStrict(fn):
def expectedFailureRetraceability(fn):
fn._expected_failure_retrace = True
return fn


# Controls tests generated in test/export/test_serdes.py
def expectedFailureSerDer(fn):
fn._expected_failure_serdes = True
return fn

0 comments on commit 4744359

Please sign in to comment.