Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pagination and partial name search to List Jobs APIs #581

Merged
merged 16 commits into from
May 21, 2024
Prev Previous commit
Next Next commit
Update SDK to use paginated API
  • Loading branch information
Krithika Sundararajan committed May 14, 2024
commit 91a68c6fa36a1edf55ed709d0eb8f61d2e0a6254
46 changes: 44 additions & 2 deletions api/api/prediction_job_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ func (c *PredictionJobController) List(r *http.Request, vars map[string]string,
return Ok(jobs)
}

// ListInPage method lists all prediction jobs of a model and version ID, in the given page.
func (c *PredictionJobController) ListInPage(r *http.Request, vars map[string]string, _ interface{}) *Response {
// ListInPage method lists all prediction jobs of a model and version ID, with pagination.
func (c *PredictionJobController) ListByPage(r *http.Request, vars map[string]string, _ interface{}) *Response {
ctx := r.Context()

modelID, _ := models.ParseID(vars["model_id"])
Expand Down Expand Up @@ -264,3 +264,45 @@ func (c *PredictionJobController) ListAllInProject(r *http.Request, vars map[str

return Ok(jobs)
}

// ListAllInProject lists all prediction jobs of a project, with pagination
func (c *PredictionJobController) ListAllInProjectByPage(r *http.Request, vars map[string]string, body interface{}) *Response {
ctx := r.Context()

var query service.ListPredictionJobQuery
err := decoder.Decode(&query, r.URL.Query())
if err != nil {
return BadRequest(fmt.Sprintf("Bad query %s", r.URL.Query()))
}

projectID, _ := models.ParseID(vars["project_id"])
page, pageErr := strconv.Atoi(vars["page"])
pageSize, pageSizeErr := strconv.Atoi(vars["page_size"])

project, err := c.ProjectsService.GetByID(ctx, int32(projectID))
if err != nil {
return NotFound(fmt.Sprintf("Project not found: %v", err))
}

// We will append page and pageSize to the query if they are set.
paginationOpts := pagination.Options{}
if pageErr == nil {
pageInt32 := int32(page)
paginationOpts.Page = &pageInt32
}
if pageSizeErr == nil {
pageSizeInt32 := int32(pageSize)
paginationOpts.PageSize = &pageSizeInt32
}
query.Pagination = paginationOpts

jobs, paging, err := c.PredictionJobService.ListPredictionJobs(ctx, project, &query)
if err != nil {
return InternalServerError(fmt.Sprintf("Error listing prediction jobs: %v", err))
}

return Ok(ListJobsPaginatedResponse{
Results: jobs,
Paging: paging,
})
}
4 changes: 2 additions & 2 deletions api/api/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,9 @@ func NewRouter(appCtx AppContext) (*mux.Router, error) {

// Prediction Job API
{http.MethodGet, "/projects/{project_id:[0-9]+}/jobs", nil, predictionJobController.ListAllInProject, "ListAllPredictionJobInProject"},
//{http.MethodGet, "/projects/{project_id:[0-9]+}/jobs-by-page", nil, predictionJobController.ListAllInProjectInPage, "ListAllInProjectInPage"},
{http.MethodGet, "/projects/{project_id:[0-9]+}/jobs-by-page", nil, predictionJobController.ListAllInProjectByPage, "ListAllInProjectInPage"},
{http.MethodGet, "/models/{model_id:[0-9]+}/versions/{version_id:[0-9]+}/jobs", nil, predictionJobController.List, "ListPredictionJob"},
{http.MethodGet, "/models/{model_id:[0-9]+}/versions/{version_id:[0-9]+}/jobs-by-page", nil, predictionJobController.ListInPage, "ListPredictionJobInPage"},
{http.MethodGet, "/models/{model_id:[0-9]+}/versions/{version_id:[0-9]+}/jobs-by-page", nil, predictionJobController.ListByPage, "ListPredictionJobInPage"},
{http.MethodGet, "/models/{model_id:[0-9]+}/versions/{version_id:[0-9]+}/jobs/{job_id:[0-9]+}", nil, predictionJobController.Get, "GetPredictionJob"},
{http.MethodPut, "/models/{model_id:[0-9]+}/versions/{version_id:[0-9]+}/jobs/{job_id:[0-9]+}/stop", nil, predictionJobController.Stop, "StopPredictionJob"},
{http.MethodPost, "/models/{model_id:[0-9]+}/versions/{version_id:[0-9]+}/jobs", models.PredictionJob{}, predictionJobController.Create, "CreatePredictionJob"},
Expand Down
6 changes: 1 addition & 5 deletions api/client/api_prediction_jobs.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 0 additions & 15 deletions python/sdk/client/api/prediction_jobs_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def models_model_id_versions_version_id_jobs_by_page_get(
self,
model_id: StrictInt,
version_id: StrictInt,
project_id: StrictInt,
page: Optional[StrictInt] = None,
page_size: Annotated[Optional[StrictInt], Field(description="Number of items on each page. It defaults to 50.")] = None,
_request_timeout: Union[
Expand All @@ -79,8 +78,6 @@ def models_model_id_versions_version_id_jobs_by_page_get(
:type model_id: int
:param version_id: (required)
:type version_id: int
:param project_id: (required)
:type project_id: int
:param page:
:type page: int
:param page_size: Number of items on each page. It defaults to 50.
Expand Down Expand Up @@ -110,7 +107,6 @@ def models_model_id_versions_version_id_jobs_by_page_get(
_param = self._models_model_id_versions_version_id_jobs_by_page_get_serialize(
model_id=model_id,
version_id=version_id,
project_id=project_id,
page=page,
page_size=page_size,
_request_auth=_request_auth,
Expand Down Expand Up @@ -140,7 +136,6 @@ def models_model_id_versions_version_id_jobs_by_page_get_with_http_info(
self,
model_id: StrictInt,
version_id: StrictInt,
project_id: StrictInt,
page: Optional[StrictInt] = None,
page_size: Annotated[Optional[StrictInt], Field(description="Number of items on each page. It defaults to 50.")] = None,
_request_timeout: Union[
Expand All @@ -163,8 +158,6 @@ def models_model_id_versions_version_id_jobs_by_page_get_with_http_info(
:type model_id: int
:param version_id: (required)
:type version_id: int
:param project_id: (required)
:type project_id: int
:param page:
:type page: int
:param page_size: Number of items on each page. It defaults to 50.
Expand Down Expand Up @@ -194,7 +187,6 @@ def models_model_id_versions_version_id_jobs_by_page_get_with_http_info(
_param = self._models_model_id_versions_version_id_jobs_by_page_get_serialize(
model_id=model_id,
version_id=version_id,
project_id=project_id,
page=page,
page_size=page_size,
_request_auth=_request_auth,
Expand Down Expand Up @@ -224,7 +216,6 @@ def models_model_id_versions_version_id_jobs_by_page_get_without_preload_content
self,
model_id: StrictInt,
version_id: StrictInt,
project_id: StrictInt,
page: Optional[StrictInt] = None,
page_size: Annotated[Optional[StrictInt], Field(description="Number of items on each page. It defaults to 50.")] = None,
_request_timeout: Union[
Expand All @@ -247,8 +238,6 @@ def models_model_id_versions_version_id_jobs_by_page_get_without_preload_content
:type model_id: int
:param version_id: (required)
:type version_id: int
:param project_id: (required)
:type project_id: int
:param page:
:type page: int
:param page_size: Number of items on each page. It defaults to 50.
Expand Down Expand Up @@ -278,7 +267,6 @@ def models_model_id_versions_version_id_jobs_by_page_get_without_preload_content
_param = self._models_model_id_versions_version_id_jobs_by_page_get_serialize(
model_id=model_id,
version_id=version_id,
project_id=project_id,
page=page,
page_size=page_size,
_request_auth=_request_auth,
Expand All @@ -303,7 +291,6 @@ def _models_model_id_versions_version_id_jobs_by_page_get_serialize(
self,
model_id,
version_id,
project_id,
page,
page_size,
_request_auth,
Expand All @@ -330,8 +317,6 @@ def _models_model_id_versions_version_id_jobs_by_page_get_serialize(
_path_params['model_id'] = model_id
if version_id is not None:
_path_params['version_id'] = version_id
if project_id is not None:
_path_params['project_id'] = project_id
# process the query parameters
if page is not None:

Expand Down
16 changes: 14 additions & 2 deletions python/sdk/merlin/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1559,12 +1559,24 @@ def list_prediction_job(self) -> List[PredictionJob]:
:return: list of prediction jobs
"""
job_client = client.PredictionJobsApi(self._api_client)
res = job_client.models_model_id_versions_version_id_jobs_get(

res = job_client.models_model_id_versions_version_id_jobs_by_page_get(
model_id=self.model.id, version_id=self.id
)
jobs = []
for j in res:
for j in res.results:
jobs.append(PredictionJob(j, self._api_client))

# Paginated response. Parse the rest of the pages.
total_pages = res.paging.pages
page = 2
while page < total_pages:
res = job_client.models_model_id_versions_version_id_jobs_by_page_get(
model_id=self.model.id, version_id=self.id, page=page
)
for j in res.results:
jobs.append(PredictionJob(j, self._api_client))

return jobs

def start_server(
Expand Down
31 changes: 31 additions & 0 deletions python/sdk/pyfunc.Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2020 The Merlin Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

ARG BASE_IMAGE=ghcr.io/caraml-dev/merlin/merlin-pyfunc-base:0.41.0
FROM ${BASE_IMAGE}

# Download and install user model dependencies
ARG MODEL_DEPENDENCIES_URL
COPY ${MODEL_DEPENDENCIES_URL} conda.yaml

ARG MERLIN_DEP_CONSTRAINT
RUN process_conda_env.sh conda.yaml "merlin-pyfunc-server" "${MERLIN_DEP_CONSTRAINT}"
RUN conda env create --name merlin-model --file conda.yaml

# Download and dry-run user model artifacts and code
ARG MODEL_ARTIFACTS_URL
COPY ${MODEL_ARTIFACTS_URL} model
RUN /bin/bash -c ". activate merlin-model && merlin-pyfunc-server --model_dir model --dry_run"

CMD ["run.sh"]
20 changes: 20 additions & 0 deletions python/sdk/standard.Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2020 The Merlin Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

ARG BASE_IMAGE
ARG MODEL_PATH

FROM ${BASE_IMAGE}

COPY ${MODEL_PATH} /mnt/models
13 changes: 9 additions & 4 deletions python/sdk/test/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,10 +941,15 @@ def test_undeploy_default_env(self, version):
def test_list_prediction_job(self, version):
responses.add(
method="GET",
url="/v1/models/1/versions/1/jobs",
body=json.dumps(
[job_1.to_dict(), job_2.to_dict()], default=serialize_datetime
),
url="/v1/models/1/versions/1/jobs-by-page",
body=json.dumps({
"results": [job_1.to_dict(), job_2.to_dict()],
"paging": {
"page": 1,
"pages": 1,
"total": 2,
},
}, default=serialize_datetime),
status=200,
content_type="application/json",
)
Expand Down
5 changes: 0 additions & 5 deletions swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1404,11 +1404,6 @@ paths:
required: true
schema:
type: integer
- name: project_id
in: path
required: true
schema:
type: integer
- name: page
in: query
schema:
Expand Down