Skip to content

Commit

Permalink
[PyTorch] Add input numel assert for minimal arrayref interface (pyto…
Browse files Browse the repository at this point in the history
…rch#113577)

We currently have no shape checking on CPU IIUC. Now we at least do numel checking for the minimal arrayref interface.

Differential Revision: [D51165703](https://our.internmc.facebook.com/intern/diff/D51165703/)

Pull Request resolved: pytorch#113577
Approved by: https://github.com/chenyang78, https://github.com/jansel
ghstack dependencies: pytorch#112800
  • Loading branch information
swolchok authored and dmenig committed Dec 21, 2023
1 parent 233667c commit 86098c6
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 1 deletion.
8 changes: 8 additions & 0 deletions torch/_inductor/codegen/aoti_runtime/implementation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,5 +236,13 @@ RAIIAtenTensorHandle convert_arrayref_tensor_to_tensor(
return art.expensiveCopyToTensor();
}

template <typename T>
void assert_numel(const ArrayRefTensor<T>& tensor, int64_t numel) {
if (tensor.numel() != numel) {
std::stringstream err;
err << "incorrect numel for input tensor. expected " << numel << ", got " << tensor.numel();
throw std::runtime_error(err.str());
}
}
} // namespace aot_inductor
} // namespace torch
1 change: 1 addition & 0 deletions torch/_inductor/codegen/aoti_runtime/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <torch/csrc/inductor/aoti_runtime/model_container.h>

#include <iostream>
#include <sstream>
#include <stdexcept>
#include <vector>

Expand Down
16 changes: 15 additions & 1 deletion torch/_inductor/codegen/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1639,12 +1639,26 @@ def write_wrapper_decl(self):
self.codegen_inputs(self.prefix, V.graph.graph_inputs)

if V.graph.aot_mode:
if not config.use_minimal_arrayref_interface:
if config.use_minimal_arrayref_interface:
# TODO: input shape checking for regular tensor interface as well?
self.codegen_input_numel_asserts()
else:
self.prefix.writeline("inputs.clear();")
self.prefix.writeline(
"auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());"
)

def codegen_input_numel_asserts(self):
for name, buf in V.graph.graph_inputs.items():
if isinstance(buf, sympy.Expr):
continue

# comparing strides for 0 size tensor is tricky. Ignore them for now.
if sympy_product(buf.get_size()) == 0:
continue
numel = buf.get_numel()
self.prefix.writeline(f"assert_numel({name}, {numel});")

def codegen_input_size_var_decl(self, code: IndentedBuffer, name):
if config.aot_inductor.abi_compatible:
code.writeline(f"int64_t* {name}_size;")
Expand Down

0 comments on commit 86098c6

Please sign in to comment.