Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug/Remove Squeeze Panic for Multiple Dimensions #2035

Merged
merged 10 commits into from
Jul 22, 2024

Conversation

agelas
Copy link
Contributor

@agelas agelas commented Jul 17, 2024

Pull Request Template

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Fixes #2033

Changes

Removes panic if user passes more than 1 axes to a Squeeze node in their ONNX graph, should've been part of #1779.

Copy link
Member

@laggui laggui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing the issue!

If providing multiple axes is supported we should probably add a test to cover the use case 🤔 What do you think?

@agelas
Copy link
Contributor Author

agelas commented Jul 17, 2024

I think that would definitely be helpful for catching things like this. I know burn-import got split up a bit- is there a place now for handling a test like that? Right now, the only test for squeeze_dims is in crates/burn-tensor/src/tests/ops/squeeze.rs, so only on the burn side but not ONNX conversion.

Copy link

codecov bot commented Jul 17, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 84.24%. Comparing base (9804bf8) to head (2d6f51f).
Report is 4 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2035      +/-   ##
==========================================
- Coverage   84.25%   84.24%   -0.01%     
==========================================
  Files         846      852       +6     
  Lines      105456   105640     +184     
==========================================
+ Hits        88853    89000     +147     
- Misses      16603    16640      +37     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@laggui
Copy link
Member

laggui commented Jul 18, 2024

On the ONNX side, we generate the ONNX files to test with this: https://github.com/tracel-ai/burn/blob/main/crates/burn-import/onnx-tests/tests/squeeze/squeeze.py

Right now it only covers a single axis, but we could provide an input with multiple axes to squeeze. For example [3, 4, 1, 5, 1] and torch.squeeze(x, (2, 4)). So the included test would cover both use cases.

@agelas
Copy link
Contributor Author

agelas commented Jul 19, 2024

@laggui Hmm, so that behavior doesn't look supported on the torch->ONNX conversion side. When I use your example with torch.squeeze(x, (2, 4)), I get this error:

Traceback (most recent call last):
  File "/burn/crates/burn-import/onnx-tests/tests/squeeze/squeeze.py", line 46, in <module>
    main()
  File "/burn/crates/burn-import/onnx-tests/tests/squeeze/squeeze.py", line 32, in main
    torch.onnx.export(model, test_input, "squeeze_opset16.onnx", verbose=False, opset_version=16)
  File "/burn/myenv/lib/python3.10/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/burn/myenv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1596, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/burn/myenv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1139, in _model_to_graph
    graph = _optimize_graph(
  File "/burn/myenv/lib/python3.10/site-packages/torch/onnx/utils.py", line 677, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/burn/myenv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1940, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/burn/myenv/lib/python3.10/site-packages/torch/onnx/symbolic_opset11.py", line 930, in squeeze
    dim = symbolic_helper._get_const(dim, "i", "dim")
  File "/burn/myenv/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 178, in _get_const
    return _parse_arg(value, desc)
  File "/burn/myenv/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 83, in _parse_arg
    return int(node_val)
ValueError: only one element tensors can be converted to Python scalars

If we just had an ONNX graph though, I think it might be interpreted correctly on the burn side.

@laggui
Copy link
Member

laggui commented Jul 19, 2024

Oh, well I guess you're right then. If we want to add a test we'll have to construct the graph manually. I could provide a script for that later when I get some time.

/edit: this should do the trick I believe

import onnx
from onnx import helper, TensorProto

input = helper.make_tensor_value_info("input", TensorProto.FLOAT, [3, 4, 1, 5, 1])
output = helper.make_tensor_value_info("output", TensorProto.FLOAT, [3, 4, 5])
squeeze = helper.make_node(op_type="Squeeze", inputs=["input", "axes"], outputs=["output"], name="SqueezeOp")
axes = helper.make_tensor("axes", TensorProto.INT64, dims=[2], vals=[2, 4])
graph = helper.make_graph([squeeze], "SqueezeMultiple", [input], [output], [axes])
opset = helper.make_opsetid("", 13)
m = helper.make_model(graph, opset_imports=[opset])

onnx.checker.check_model(m, full_check=True)
onnx.save(m, "squeeze_multiple.onnx")

@agelas
Copy link
Contributor Author

agelas commented Jul 19, 2024

Thanks for the script. I added that to squeeze.py and got this, so it looks all good:

image

Since it's not using torch at all, do we need to add it to onnx_tests.rs?
EDIT: stupid question

@agelas
Copy link
Contributor Author

agelas commented Jul 20, 2024

Ok, fingers crossed everything should be working now. Testing the model you constructed also lead me to finding another bug with how output dimensions were calculated, so adding that test was a huge plus.

@agelas agelas requested a review from laggui July 20, 2024 00:20
Copy link
Member

@laggui laggui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for adding the multiple dim support and test cases 🙏

Copy link
Collaborator

@antimora antimora left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@antimora antimora merged commit 0bbc1ed into tracel-ai:main Jul 22, 2024
14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Onnx Squeeze to accept a list of axes
3 participants