Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add online_read_async for dynamodb #4244

Merged
merged 16 commits into from
Jun 5, 2024
Prev Previous commit
Next Next commit
fix resource vs client response handling
Signed-off-by: robhowley <rhowley@seatgeek.com>
  • Loading branch information
robhowley committed May 31, 2024
commit a1d1adaa67be30e35611df577fd3cdc390826ebe
23 changes: 18 additions & 5 deletions sdk/python/feast/infra/online_stores/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,9 @@ async def online_read_async(
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
table_name = _get_table_name(online_config, config, table)

def to_tbl_resp(raw_client_response):
return {"entity_id": raw_client_response["entity_id"]["S"]}

async with self._get_aiodynamodb_client(online_config.region) as client:
while True:
batch = list(itertools.islice(entity_ids_iter, batch_size))
Expand All @@ -300,7 +303,7 @@ async def online_read_async(
RequestItems=batch_entity_ids,
)
batch_result = self._process_batch_get_response(
table_name, response, entity_ids, batch
table_name, response, entity_ids, batch, to_tbl_response=to_tbl_resp
)
result.extend(batch_result)
return result
Expand All @@ -325,13 +328,19 @@ def _get_dynamodb_resource(self, region: str, endpoint_url: Optional[str] = None
)
return self._dynamodb_resource

def _sort_dynamodb_response(self, responses: list, order: list) -> Any:
def _sort_dynamodb_response(
self,
responses: list,
order: list,
to_tbl_response: Callable = lambda raw_dict: raw_dict,
) -> Any:
"""DynamoDB Batch Get Item doesn't return items in a particular order."""
# Assign an index to order
order_with_index = {value: idx for idx, value in enumerate(order)}
# Sort table responses by index
table_responses_ordered: Any = [
(order_with_index[tbl_res["entity_id"]], tbl_res) for tbl_res in responses
(order_with_index[tbl_res["entity_id"]], tbl_res)
for tbl_res in map(to_tbl_response, responses)
]
table_responses_ordered = sorted(
table_responses_ordered, key=lambda tup: tup[0]
Expand Down Expand Up @@ -368,13 +377,17 @@ def _write_batch_non_duplicates(
if progress:
progress(1)

def _process_batch_get_response(self, table_name, response, entity_ids, batch):
def _process_batch_get_response(
self, table_name, response, entity_ids, batch, **sort_kwargs
):
response = response.get("Responses")
table_responses = response.get(table_name)

batch_result = []
if table_responses:
table_responses = self._sort_dynamodb_response(table_responses, entity_ids)
table_responses = self._sort_dynamodb_response(
table_responses, entity_ids, **sort_kwargs
)
entity_idx = 0
for tbl_res in table_responses:
entity_id = tbl_res["entity_id"]
Expand Down
Loading