[inductor] Fix a bmm related bug #532
Conversation
Summary: The bug is exposed by a pytorch_struct training run. We need to be conservative with the bs=1 bmm optimization when any operand is not contiguous.
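To make the contiguity concern concrete, here is a small pure-Python sketch (an illustration, not torchinductor code) of why a batch-size-1 bmm operand cannot always be reinterpreted as a 2-D mm input: when the last two dims are transposed, the strides no longer describe a dense row-major layout, so dropping the batch dim does not yield a contiguous matrix.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides for a given shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

def is_contiguous(shape, strides):
    """True when (shape, strides) describe a dense row-major layout."""
    return strides == contiguous_strides(shape)

# A contiguous (1, 4, 8) operand: dropping the batch dim is a free view.
print(is_contiguous([1, 4, 8], contiguous_strides([1, 4, 8])))  # True

# The same buffer with the last two dims transposed, shape (1, 8, 4),
# has strides [32, 1, 8]. Neither the 3-D layout nor the 2-D slice
# ([8, 4] with strides [1, 8]) is row-major, so it cannot be handed
# to a contiguous-only mm path without a copy.
transposed = [32, 1, 8]
print(is_contiguous([1, 8, 4], transposed))   # False
print(is_contiguous([8, 4], transposed[1:]))  # False
```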
b3 == 1
and is_contiguous_storage_and_layout(a)
and is_contiguous_storage_and_layout(b)
):
    # convert to normal mm
    data = MatrixMultiply(
        layout=output_layout.as_fixed(),
@jansel I guess another option is to generate a ReinterpretView instead of View when the batch size is 1, but I am not sure if that is safe to do here.
if b3 == 1:
    if (
        b3 == 1
        and is_contiguous_storage_and_layout(a)
This doesn't look right; MatrixMultiply should be able to handle transposed inputs.
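The point that a matrix-multiply kernel can consume transposed inputs can be illustrated with a toy stride-aware matmul (a sketch only, not inductor's extern kernel): by indexing through strides, the same flat buffer serves as either a matrix or its transpose, with no copy.

```python
def mm_strided(a_data, a_shape, a_strides, b_data, b_shape, b_strides):
    """Naive matmul over flat buffers addressed via explicit strides."""
    m, k = a_shape
    k2, n = b_shape
    assert k == k2, "inner dimensions must match"
    out = [[0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0
            for p in range(k):
                acc += (a_data[i * a_strides[0] + p * a_strides[1]]
                        * b_data[p * b_strides[0] + j * b_strides[1]])
            out[i][j] = acc
    return out

# A is a row-major (2, 3) matrix; B reuses the SAME buffer viewed as
# A's transpose, shape (3, 2) with strides (1, 3) -- i.e. B = A^T.
buf = [1, 2, 3, 4, 5, 6]
print(mm_strided(buf, (2, 3), (3, 1), buf, (3, 2), (1, 3)))
# [[14, 32], [32, 77]]  (A @ A^T)
```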
The problem is that View.create does not return a ReinterpretView, because of the check at torchdynamo/torchinductor/ir.py, line 695 in 044e8bd:

if is_contiguous_storage_and_layout(x):

which then trips the assertion at torchdynamo/torchinductor/ir.py, line 1570 in 044e8bd:

assert isinstance(x, (Buffer, ReinterpretView)), x

I think making a special case in View.create might work. @jansel, any suggestion?
Hrm, in this case perhaps we should be treating it as contiguous, and ignore size=1 dims in contiguity checks. The bs=1 dim doesn't actually matter.
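The suggestion above can be sketched as a contiguity check that simply skips size-1 dimensions, whose strides never affect the memory layout. This is a hypothetical helper for illustration, not the actual torchinductor implementation:

```python
def is_contiguous_ignoring_unit_dims(shape, strides):
    """Contiguity check that ignores size-1 dims (hypothetical sketch).

    A size-1 dimension contributes nothing to addressing, so its stride
    is irrelevant; only the remaining dims must be dense row-major.
    """
    dims = [(s, st) for s, st in zip(shape, strides) if s != 1]
    expected = 1
    for size, stride in reversed(dims):
        if stride != expected:
            return False
        expected *= size
    return True

# A (1, 4, 8) tensor with an arbitrary stride on the size-1 batch dim
# is still effectively contiguous:
print(is_contiguous_ignoring_unit_dims([1, 4, 8], [999, 8, 1]))  # True

# A transposed layout still fails, as it should:
print(is_contiguous_ignoring_unit_dims([1, 8, 4], [32, 1, 8]))  # False
```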
size=1 is already not checked, but one (or both) of the 2d matrices resulting from removing the batch dimension might not be contiguous, and thus View.create returns a View and not a ReinterpretView. I think we can return a ReinterpretView in more cases by specifying proper strides in the new_layout.
I think we should extend SqueezeView.create to accept a specific dimension, and use SqueezeView.create instead of View.create below.
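The idea behind squeezing a specific dimension can be sketched in a few lines (an illustration with a hypothetical helper, not the SqueezeView implementation): dropping a size-1 dim is a pure metadata change that keeps the remaining strides intact, so the result always reinterprets the same storage.

```python
def squeeze_dim(shape, strides, dim):
    """Drop one size-1 dimension from a (shape, strides) pair.

    Hypothetical sketch: removing a size-1 dim changes only metadata,
    never the underlying storage, so no copy is ever needed.
    """
    assert shape[dim] == 1, "can only squeeze a size-1 dimension"
    new_shape = shape[:dim] + shape[dim + 1:]
    new_strides = strides[:dim] + strides[dim + 1:]
    return new_shape, new_strides

# Squeezing the batch dim of a transposed (1, 8, 4) operand keeps the
# transposed 2-D strides [1, 8] intact instead of forcing a copy:
print(squeeze_dim([1, 8, 4], [32, 1, 8], 0))  # ([8, 4], [1, 8])
```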
I agree with @ngimel here: we should be able to map bmm to mm in the transposed case. It might be a bit annoying to do because not all of the view types work with extern kernels without adding a copy.

Accepting because this fixes a bug, and converting bmm to mm is a minor optimization that is likely OK to miss.
Here's what I had in mind: #548. It still allows the bmm -> mm optimization for discontiguous inputs, and won't cause more copies than bmm itself would have caused.