Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Elasticsearch vector database #4188

Merged
merged 18 commits into from
May 13, 2024
Previous commit
Next commit
format
Signed-off-by: cmuhao <sduxuhao@gmail.com>
  • Loading branch information
HaoXuAI committed May 12, 2024
commit 06890213249dd4db3f99a49f20f56fd4f4a0db0c
84 changes: 46 additions & 38 deletions sdk/python/feast/infra/online_stores/contrib/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,12 @@ def _get_client(self, config: RepoConfig) -> Elasticsearch:
online_store_config = config.online_store
assert isinstance(online_store_config, ElasticSearchOnlineStoreConfig)

user = online_store_config.user if online_store_config.user is not None else ''
password = online_store_config.password if online_store_config.password is not None else ''
user = online_store_config.user if online_store_config.user is not None else ""
password = (
online_store_config.password
if online_store_config.password is not None
else ""
)

if self._client:
return self._client
Expand All @@ -62,10 +66,10 @@ def _get_client(self, config: RepoConfig) -> Elasticsearch:
{
"host": online_store_config.host or "localhost",
"port": online_store_config.port or 9200,
"scheme": online_store_config.scheme or "http"
"scheme": online_store_config.scheme or "http",
}
],
basic_auth=(user, password)
basic_auth=(user, password),
)
return self._client

Expand All @@ -78,13 +82,13 @@ def _bulk_batch_actions(self, table: FeatureView, batch: List[Dict[str, Any]]):
}

def online_write_batch(
self,
config: RepoConfig,
table: FeatureView,
data: List[
Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
],
progress: Optional[Callable[[int], Any]],
self,
config: RepoConfig,
table: FeatureView,
data: List[
Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
],
progress: Optional[Callable[[int], Any]],
) -> None:
insert_values = []
for entity_key, values, timestamp, created_ts in data:
Expand All @@ -97,7 +101,9 @@ def online_write_batch(
if created_ts is not None:
created_ts = _to_naive_utc(created_ts)
for feature_name, value in values.items():
encoded_value = base64.b64encode(value.SerializeToString()).decode("utf-8")
encoded_value = base64.b64encode(value.SerializeToString()).decode(
"utf-8"
)
vector_val = json.loads(get_list_val_str(value))
insert_values.append(
{
Expand All @@ -112,16 +118,16 @@ def online_write_batch(

batch_size = config.online_config.write_batch_size
for i in range(0, len(insert_values), batch_size):
batch = insert_values[i: i + batch_size]
batch = insert_values[i : i + batch_size]
actions = self._bulk_batch_actions(table, batch)
helpers.bulk(self._get_client(config), actions)

def online_read(
self,
config: RepoConfig,
table: FeatureView,
entity_keys: List[EntityKeyProto],
requested_features: Optional[List[str]] = None,
self,
config: RepoConfig,
table: FeatureView,
entity_keys: List[EntityKeyProto],
requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
if not requested_features:
body = {
Expand Down Expand Up @@ -167,16 +173,18 @@ def create_index(self, config: RepoConfig, table: FeatureView):
},
}
}
self._get_client(config).indices.create(index=table.name, mappings=index_mapping)
self._get_client(config).indices.create(
index=table.name, mappings=index_mapping
)

def update(
self,
config: RepoConfig,
tables_to_delete: Sequence[FeatureView],
tables_to_keep: Sequence[FeatureView],
entities_to_delete: Sequence[Entity],
entities_to_keep: Sequence[Entity],
partial: bool,
self,
config: RepoConfig,
tables_to_delete: Sequence[FeatureView],
tables_to_keep: Sequence[FeatureView],
entities_to_delete: Sequence[Entity],
entities_to_keep: Sequence[Entity],
partial: bool,
):
# implement the update method
for table in tables_to_delete:
Expand All @@ -185,10 +193,10 @@ def update(
self.create_index(config, table)

def teardown(
self,
config: RepoConfig,
tables: Sequence[FeatureView],
entities: Sequence[Entity],
self,
config: RepoConfig,
tables: Sequence[FeatureView],
entities: Sequence[Entity],
):
project = config.project
try:
Expand All @@ -199,14 +207,14 @@ def teardown(
raise

def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: str,
embedding: List[float],
top_k: int,
*args,
**kwargs,
self,
config: RepoConfig,
table: FeatureView,
requested_feature: str,
embedding: List[float],
top_k: int,
*args,
**kwargs,
) -> List[
Tuple[
Optional[datetime],
Expand Down