Skip to content

Commit

Permalink
enhance slice layer
Browse files Browse the repository at this point in the history
refactor the code for parsing Slice layer
add test for Slice layer
let 'begin' and 'end' resize to dims
add opset message comment
  • Loading branch information
WanliZhong committed Oct 1, 2022
1 parent 04ebedb commit 4557971
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 49 deletions.
118 changes: 69 additions & 49 deletions modules/dnn/src/onnx/onnx_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1326,72 +1326,59 @@ void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::Node

void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
int axis = 0;
std::vector<int> begin;
std::vector<int> end;
MatShape inpShape = outShapes[node_proto.input(0)];
int dims = inpShape.size();
std::vector<int> begin(dims, 0);
std::vector<int> end(dims, INT_MAX);
std::vector<int> steps;
int inp_size = node_proto.input_size();
int axis = 0;
bool has_axes = false;
DictValue starts_, ends_, axes_, steps_;

// opset = 1
if (inp_size == 1)
{
if (layerParams.has("axes")) {
DictValue axes = layerParams.get("axes");
for (int i = 1; i < axes.size(); ++i) {
CV_Assert(axes.get<int>(i - 1) == axes.get<int>(i) - 1);
}
axis = axes.get<int>(0);
}

DictValue starts = layerParams.get("starts");
DictValue ends = layerParams.get("ends");
CV_Assert(starts.size() == ends.size());

if (axis > 0) {
CV_CheckLE(axis, 1024, "Slice layer can't have more than 1024 axes"); // arbitrary limit
begin.resize(axis, 0);
end.resize(axis, INT_MAX);
}
for (int i = 0; i < starts.size(); ++i)
starts_ = layerParams.get("starts");
ends_ = layerParams.get("ends");
CV_Assert(starts_.size() == ends_.size());
if (layerParams.has("axes"))
{
begin.push_back(starts.get<int>(i));
end.push_back(ends.get<int>(i));
axes_ = layerParams.get("axes");
CV_Assert(axes_.size() == starts_.size());
axis = axes_.getIntValue(0) < 0 ? axes_.getIntValue(0) + dims : axes_.getIntValue(0);
has_axes = true;
}
} else { // inp_size > 1
}
// opset > 1
else
{
CV_Assert(inp_size >= 3);
for (int i = 1; i < inp_size; i++) {
for (int i = 1; i < inp_size; ++i)
{
CV_Assert(constBlobs.find(node_proto.input(i)) != constBlobs.end());
}
Mat start_blob = getBlob(node_proto, 1);
Mat end_blob = getBlob(node_proto, 2);
Mat end_blob = getBlob(node_proto, 2);
CV_Assert(start_blob.total() == end_blob.total());
starts_ = DictValue::arrayInt(start_blob.begin<int>(), start_blob.total());
ends_ = DictValue::arrayInt(end_blob.begin<int>(), end_blob.total());

if (inp_size > 3) {
if (inp_size > 3)
{
Mat axes_blob = getBlob(node_proto, 3);
const int* axes = (int*)axes_blob.data;
for (int i = 1; i < axes_blob.total(); ++i) {
CV_Assert(axes[i - 1] == axes[i] - 1);
}
axis = axes[0];
}

const int* starts = start_blob.ptr<int>();
const int* ends = end_blob.ptr<int>();
if (axis > 0) {
begin.resize(axis, 0);
end.resize(axis, INT_MAX);
CV_Assert(axes_blob.total() == start_blob.total());
axes_ = DictValue::arrayInt(axes_blob.begin<int>(), axes_blob.total());
axis = axes_.getIntValue(0) < 0 ? axes_.getIntValue(0) + dims : axes_.getIntValue(0);
has_axes = true;
}
std::copy(starts, starts + start_blob.total(), std::back_inserter(begin));
std::copy(ends, ends + end_blob.total(), std::back_inserter(end));

if (inp_size == 5) {
CV_Assert(constBlobs.find(node_proto.input(4)) != constBlobs.end());
if (inp_size == 5)
{
Mat step_blob = getBlob(node_proto, 4);
const int* steps_ptr = step_blob.ptr<int>();

if (axis > 0)
steps.resize(axis, 1);

std::copy(steps_ptr, steps_ptr + step_blob.total(), std::back_inserter(steps));
CV_Assert(step_blob.total() == start_blob.total());
steps_ = DictValue::arrayInt(step_blob.begin<int>(), step_blob.total());
steps.resize(dims, 1);

// Very strange application for Slice op with tensor reversing.
// We just workaround it for 2d constants.
Expand All @@ -1411,12 +1398,45 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP
}
}
}

if (!has_axes)
{
// make a default axes [0, 1, 2...]
Mat axes_tmp(1, starts_.size(), CV_32S);
std::iota(axes_tmp.begin<int>(), axes_tmp.end<int>(), 0);
axes_ = DictValue::arrayInt(axes_tmp.begin<int>(), axes_tmp.total());
}

int cur_axe;
std::vector<bool> flag(dims, false);
Mat axes(1, starts_.size(), CV_32S);
auto axes_ptr = axes.ptr<int>();
// resize begin and end
for (int i = 0; i < axes_.size(); ++i)
{
// dims should be added to the negative axes
cur_axe = axes_.getIntValue(i) < 0 ? axes_.getIntValue(i) + dims : axes_.getIntValue(i);
CV_CheckGE(cur_axe, 0, "Axes should be grater or equal to '-dims'.");
CV_CheckLT(cur_axe, dims, "Axes should be less than 'dim'.");
CV_CheckEQ(flag[cur_axe], false, "Axes shouldn't have duplicated values.");
flag[cur_axe] = true;
// change axis to the minimum axe
if (cur_axe < axis) axis = cur_axe;
axes_ptr[i] = cur_axe;
begin[cur_axe] = starts_.getIntValue(i);
end[cur_axe] = ends_.getIntValue(i);
}

layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));
layerParams.set("axis", axis);

if (!steps.empty())
{
for (int i = 0; i < axes.total(); ++i)
steps[axes_ptr[i]] = steps_.getIntValue(i);
layerParams.set("steps", DictValue::arrayInt(&steps[0], steps.size()));
}

if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
{
Expand Down
14 changes: 14 additions & 0 deletions modules/dnn/test/test_onnx_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1172,6 +1172,20 @@ TEST_P(Test_ONNX_layers, Slice_Steps_5DInput)
testONNXModels("slice_opset_11_steps_5d");
}

TEST_P(Test_ONNX_layers, Slice_Nonseq_Axes)
{
testONNXModels("slice_nonseq_axes");
testONNXModels("slice_nonseq_axes_steps");
testONNXModels("slice_nonseq_miss_axes_steps");
}

TEST_P(Test_ONNX_layers, Slice_Neg_Axes)
{
testONNXModels("slice_neg_axes");
testONNXModels("slice_neg_axes_steps");
testONNXModels("slice_neg_miss_axes_steps");
}

TEST_P(Test_ONNX_layers, Softmax)
{
testONNXModels("softmax");
Expand Down

0 comments on commit 4557971

Please sign in to comment.