forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Export] Support ser/des test on existing cases (pytorch#115413)
Summary: Similar as pytorch#115399 Test Plan: ``` $ python test/export/test_serdes.py ... Ran 72 tests in 29.097s OK (expected failures=13) ``` Pull Request resolved: pytorch#115413 Approved by: https://github.com/tugsbayasgalan ghstack dependencies: pytorch#115399, pytorch#115402
- Loading branch information
1 parent
b0c7dd4
commit 4744359
Showing
3 changed files
with
69 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Owner(s): ["module: dynamo"] | ||
|
||
import io | ||
|
||
import test_export | ||
import testing | ||
|
||
from torch.export import export, load, save | ||
|
||
test_classes = {} | ||
|
||
|
||
def mocked_serder_export(*args, **kwargs): | ||
ep = export(*args, **kwargs) | ||
buffer = io.BytesIO() | ||
save(ep, buffer) | ||
buffer.seek(0) | ||
loaded_ep = load(buffer) | ||
return loaded_ep | ||
|
||
|
||
def make_dynamic_cls(cls): | ||
suffix = "_serdes" | ||
|
||
cls_prefix = "SerDesExport" | ||
|
||
test_class = testing.make_test_cls_with_mocked_export( | ||
cls, | ||
cls_prefix, | ||
suffix, | ||
mocked_serder_export, | ||
xfail_prop="_expected_failure_serdes", | ||
) | ||
|
||
test_classes[test_class.__name__] = test_class | ||
# REMOVING THIS LINE WILL STOP TESTS FROM RUNNING | ||
globals()[test_class.__name__] = test_class | ||
test_class.__module__ = __name__ | ||
return test_class | ||
|
||
|
||
tests = [ | ||
test_export.TestDynamismExpression, | ||
test_export.TestExport, | ||
] | ||
for test in tests: | ||
make_dynamic_cls(test) | ||
del test | ||
|
||
if __name__ == "__main__": | ||
from torch._dynamo.test_case import run_tests | ||
|
||
run_tests() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters