Skip to content

Commit

Permalink
Add support of multi vector in jni (opensearch-project#1364)
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <heemin@amazon.com>
  • Loading branch information
heemin32 committed Jan 16, 2024
1 parent a1f5c77 commit 3d7e6b7
Show file tree
Hide file tree
Showing 19 changed files with 1,022 additions and 28 deletions.
23 changes: 20 additions & 3 deletions jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,21 @@ if (${CONFIG_FAISS} STREQUAL ON OR ${CONFIG_ALL} STREQUAL ON OR ${CONFIG_TEST} S
set(FAISS_ENABLE_PYTHON OFF)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/external/faiss EXCLUDE_FROM_ALL)

add_library(${TARGET_LIB_FAISS} SHARED ${CMAKE_CURRENT_SOURCE_DIR}/src/org_opensearch_knn_jni_FaissService.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_wrapper.cpp)
add_library(
${TARGET_LIB_FAISS} SHARED
${CMAKE_CURRENT_SOURCE_DIR}/src/org_opensearch_knn_jni_FaissService.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_wrapper.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/knn_extension/faiss/utils/BitSet.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/knn_extension/faiss/MultiVectorResultCollector.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/knn_extension/faiss/MultiVectorResultCollectorFactory.cpp)
target_link_libraries(${TARGET_LIB_FAISS} faiss ${TARGET_LIB_COMMON} OpenMP::OpenMP_CXX)
target_include_directories(${TARGET_LIB_FAISS} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include $ENV{JAVA_HOME}/include $ENV{JAVA_HOME}/include/${JVM_OS_TYPE} ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss)
target_include_directories(${TARGET_LIB_FAISS} PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/include/knn_extension/faiss
${CMAKE_CURRENT_SOURCE_DIR}/include/knn_extension/faiss/utils
$ENV{JAVA_HOME}/include
$ENV{JAVA_HOME}/include/${JVM_OS_TYPE}
${CMAKE_CURRENT_SOURCE_DIR}/external/faiss)
set_target_properties(${TARGET_LIB_FAISS} PROPERTIES SUFFIX ${LIB_EXT})
set_target_properties(${TARGET_LIB_FAISS} PROPERTIES POSITION_INDEPENDENT_CODE ON)

Expand Down Expand Up @@ -198,7 +210,12 @@ if ("${WIN32}" STREQUAL "")
jni_test
tests/faiss_wrapper_test.cpp
tests/nmslib_wrapper_test.cpp
tests/test_util.cpp)
tests/test_util.cpp
tests/knn_extension/faiss/utils/BitSetTest.cpp
tests/knn_extension/faiss/utils/HeapTest.cpp
tests/knn_extension/faiss/MultiVectorResultCollectorTest.cpp
tests/knn_extension/faiss/MultiVectorResultCollectorFactoryTest.cpp
)

target_link_libraries(
jni_test
Expand Down
5 changes: 2 additions & 3 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#define OPENSEARCH_KNN_FAISS_WRAPPER_H

#include "jni_util.h"

#include <jni.h>

namespace knn_jni {
Expand All @@ -38,13 +37,13 @@ namespace knn_jni {
//
// Return an array of KNNQueryResults
jobjectArray QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ);
jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ);

// Execute a query against the index located in memory at indexPointerJ along with Filters
//
// Return an array of KNNQueryResults
jobjectArray QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ);
jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ, jintArray parentIdsJ);

// Free the index located in memory at indexPointerJ
void Free(jlong indexPointer);
Expand Down
69 changes: 69 additions & 0 deletions jni/include/knn_extension/faiss/MultiVectorResultCollector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <faiss/impl/ResultCollector.h>
#include <faiss/MetricType.h>
#include "knn_extension/faiss/utils/BitSet.h"
#include <unordered_map>

namespace os_faiss {

using idx_t = faiss::idx_t;
/**
* Implementation of ResultCollector to support multi vector
*
* Only supports HNSW algorithm
*
* Example:
* When there is two lucene document with two nested fields, the parent_bit_set value of 100100 is provided where
* parent doc ids are 2, and 5. Doc id for nested fields of parent document 2 are 0, and 1. Doc id for nested fields
* of parent document 5 are 3, and 4. For faiss, only nested fields are stored. Therefore corresponding doc ids for
* nested fields 0, 1, 3, 4 is 0, 1, 2, 3 in faiss. This mapping data is stored in id_map parameter.
*
* When collect method is called
* 1. It switches from faiss id to lucene id and look for its parent id.
* 2. See if the parent id already exist in heap using either parent_id_to_id or parent_id_to_index.
* 3. If it does not exist, add the parent id and distance value in the heap(bh_ids, bh_val) and update parent_id_to_id, and parent_id_to_index.
* 4. If it does exist, update the distance value(bh_val), parent_id_to_id, and parent_id_to_index.
*
* When post_process method is called
* 1. Convert lucene parent ID to faiss doc ID using parent_id_to_id
*/
struct MultiVectorResultCollector:faiss::ResultCollector {
// BitSet of lucene parent doc ID
const BitSet* parent_bit_set;

// Mapping data from Faiss doc ID to Lucene doc ID
const std::vector<int64_t>* id_map;

// Lucene parent doc ID to to Faiss doc ID
// Lucene parent doc ID to index in heap(bh_val, bh_ids)
std::unordered_map<idx_t, idx_t> parent_id_to_id;
std::unordered_map<idx_t, size_t> parent_id_to_index;
MultiVectorResultCollector(const BitSet* parent_bit_set, const std::vector<int64_t>* id_map);

/**
*
* @param k max size of bh_val, and bh_ids
* @param nres number of results in bh_val, and bh_ids
* @param bh_val binary heap storing values (For this case distance from query to result)
* @param bh_ids binary heap storing document IDs
* @param val a new value to add in bh_val
* @param ids a new doc id to add in bh_ids
*/
void collect(
int k,
int& nres,
float* bh_val,
int64_t* bh_ids,
float val,
int64_t ids) override;
void post_process(int64_t nres, int64_t* bh_ids) override;
};

} // namespace os_faiss

Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <faiss/impl/ResultCollectorFactory.h>
#include "knn_extension/faiss/utils/BitSet.h"

namespace os_faiss {
/**
* Create MultiVectorResultCollector for single query request
*
* Creating new collector is required because MultiVectorResultCollector has instance variables
* which should be isolated for each query.
*/
struct MultiVectorResultCollectorFactory:faiss::ResultCollectorFactory {
BitSet* parent_bit_set;

MultiVectorResultCollectorFactory(BitSet* parent_bit_set);
faiss::ResultCollector* new_collector() override;
void delete_collector(faiss::ResultCollector* resultCollector) override;
};

} // namespace os_faiss
51 changes: 51 additions & 0 deletions jni/include/knn_extension/faiss/utils/BitSet.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <faiss/MetricType.h>
#include <faiss/impl/platform_macros.h>
#include <limits>

using idx_t = faiss::idx_t;

struct BitSet {
const int NO_MORE_DOCS = std::numeric_limits<int>::max();
/**
* Returns the index of the first set bit starting at the index specified.
* NO_MORE_DOCS is returned if there are no more set bits.
*/
virtual idx_t next_set_bit(idx_t index) const = 0;
virtual ~BitSet() = default;
};


/**
* BitSet of fixed length (numBits), implemented using an array of unit64.
* See https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java
*
* Here a block is 64 bit. However, for simplicity let's assume its size is 8 bits.
* Then, if have an array of 3, 7, and 10, it will be represented in bitmap as follow.
* [0] [1]
* bitmap: 10001000 00000100
*
* for next_set_bit call with 4
* 1. it looks for bitmap[0]
* 2. bitmap[0] >> 4
* 3. count trailing zero of the result from step 2 which is 3
* 4. return 4(current index) + 3(result from step 3)
*/
struct FixedBitSet : public BitSet {
// Length of bitmap
size_t numBits;

// Pointer to an array of uint64_t
// Using uint64_t to leverage function __builtin_ctzll which is defined in faiss/impl/platform_macros.h
uint64_t* bitmap;

FixedBitSet(const int* int_array, const int length);
idx_t next_set_bit(idx_t index) const;
~FixedBitSet();
};
Loading

0 comments on commit 3d7e6b7

Please sign in to comment.