[inductor] Fix a bmm related bug #532

Closed

desertfire
Contributor

Summary: The bug is exposed by a pytorch_struct training run. We need to
be conservative about the bs=1 bmm optimization when any operand is not contiguous.

        b3 == 1
        and is_contiguous_storage_and_layout(a)
        and is_contiguous_storage_and_layout(b)
    ):
        # convert to normal mm
        data = MatrixMultiply(
            layout=output_layout.as_fixed(),
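
For context, the optimization guarded by this condition relies on a batched matrix multiply with batch size 1 being the same computation as a single 2D matrix multiply. A minimal sketch in plain Python, with nested lists standing in for tensors; `mm` and `bmm` here are illustrative helpers, not the inductor or PyTorch implementations:

```python
def mm(a, b):
    """Multiply two 2D matrices given as nested lists."""
    inner = len(b)
    assert all(len(row) == inner for row in a), "inner dimensions must match"
    cols = len(b[0])
    return [
        [sum(row[k] * b[k][j] for k in range(inner)) for j in range(cols)]
        for row in a
    ]


def bmm(a, b):
    """Batched multiply: apply mm to each pair of matrices along dim 0."""
    assert len(a) == len(b), "batch sizes must match"
    return [mm(x, y) for x, y in zip(a, b)]


# Batch size 1: operand shapes are (1, 2, 3) and (1, 3, 2).
a = [[[1, 2, 3], [4, 5, 6]]]
b = [[[1, 0], [0, 1], [1, 1]]]

# Dropping the batch dimension, multiplying, and re-adding it
# gives the same result as the batched multiply.
batched = bmm(a, b)
unbatched = [mm(a[0], b[0])]
assert batched == unbatched
```

The arithmetic is identical either way; the difficulty discussed in this thread is only about how the batch dimension is removed from the operands' layouts.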
Contributor Author

@jansel I guess another option is to generate a ReinterpretView instead of View when the batch size is 1, but I am not sure if that is safe to do here.

if b3 == 1:
if (
b3 == 1
and is_contiguous_storage_and_layout(a)
Contributor

This doesn't look right; MatrixMultiply should be able to handle transposed inputs.

Contributor Author

The problem is that View.create does not return a ReinterpretView because of the check at

    if is_contiguous_storage_and_layout(x):

and the result therefore later fails the assertion at

    assert isinstance(x, (Buffer, ReinterpretView)), x

I think making a special case in View.create might work. @jansel, any suggestions?

Contributor

Hrm, in this case perhaps we should be treating it as contiguous, and ignore size=1 dims in contiguous checks. The bs=1 doesn't actually matter.
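
To make "ignore size=1 dims in contiguous checks" concrete, here is a rough sketch of a row-major contiguity test that skips size-1 dimensions. The function `is_contiguous` is a hypothetical stand-in written for illustration and is not inductor's `is_contiguous_storage_and_layout`:

```python
def is_contiguous(sizes, strides):
    """Row-major contiguity check that skips dimensions of size 1.

    Hypothetical helper for illustration only.
    """
    expected = 1
    for size, stride in zip(reversed(sizes), reversed(strides)):
        if size == 1:
            # The stride of a size-1 dimension is never used for addressing.
            continue
        if stride != expected:
            return False
        expected *= size
    return True


# Batch size 1 with an arbitrary batch stride is still contiguous.
plain = is_contiguous([1, 4, 3], [999, 3, 1])

# A transposed operand stays non-contiguous even when the batch dim is ignored.
transposed = is_contiguous([1, 4, 3], [12, 1, 4])
```

In this sketch `plain` is true and `transposed` is false, so skipping the batch dimension alone does not make a transposed operand pass the check.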

Contributor

size=1 dims are already skipped in the check, but one (or both) of the 2D matrices that result from removing the batch dimension might not be contiguous, and thus View.create returns a View and not a ReinterpretView. I think we can return a ReinterpretView in more cases by specifying proper strides in the new_layout.
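
One way to picture "specifying proper strides" is that removing a size-1 dimension can keep the remaining sizes and strides untouched, so the 2D result addresses exactly the same storage locations without a copy. A small sketch under that assumption; `drop_dim` and `offsets` are hypothetical helpers, not inductor APIs:

```python
from itertools import product


def drop_dim(sizes, strides, dim):
    """Remove a size-1 dimension, keeping the other sizes and strides unchanged."""
    assert sizes[dim] == 1, "only a size-1 dimension can be dropped"
    keep = [i for i in range(len(sizes)) if i != dim]
    return [sizes[i] for i in keep], [strides[i] for i in keep]


def offsets(sizes, strides):
    """Storage offset of every element, visited in row-major index order."""
    return [
        sum(i * s for i, s in zip(index, strides))
        for index in product(*(range(n) for n in sizes))
    ]


# A (1, 4, 3) operand obtained by transposing a contiguous (1, 3, 4) buffer.
sizes, strides = [1, 4, 3], [12, 1, 4]
new_sizes, new_strides = drop_dim(sizes, strides, dim=0)

assert (new_sizes, new_strides) == ([4, 3], [1, 4])
# Same storage locations in the same order, so no copy is needed.
assert offsets(new_sizes, new_strides) == offsets(sizes, strides)
```

The 2D layout is still non-contiguous, but it is a plain strided reinterpretation of the original buffer.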

Contributor

I think we should extend SqueezeView.create to accept a specific dimension, and use SqueezeView.create instead of View.create below.

Contributor

jansel left a comment

I agree with @ngimel here: we should be able to map bmm to mm in the transposed case. It might be a bit annoying to do that, because not all of the view types work with extern kernels without adding a copy.

Accepting because this fixes a bug, and converting bmm to mm is a minor optimization that is likely OK to miss.

Contributor

ngimel commented Jul 10, 2022

Here's what I had in mind: #548. It still allows the bmm -> mm optimization for discontiguous inputs, and won't cause more copies than bmm itself would have caused.

jansel closed this on Jul 11, 2022