From def648ed3ff16065888107ef3d9913c03df23144 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 8 Jun 2024 19:03:45 -0700 Subject: [PATCH] fix(proxy_server.py): allow passing in a list of team members allows batch adding members to a team by passing in a list. fixes concurrency issue caused by calling team/member_add in parallel --- litellm/proxy/_types.py | 17 +++- litellm/proxy/management_helpers/utils.py | 63 +++++++++++++ litellm/proxy/proxy_server.py | 110 +++++++++------------- litellm/proxy/schema.prisma | 2 +- schema.prisma | 2 +- tests/test_team.py | 26 +++-- 6 files changed, 143 insertions(+), 77 deletions(-) create mode 100644 litellm/proxy/management_helpers/utils.py diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 6409773e38a4..255ab43d6995 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -757,9 +757,24 @@ class GlobalEndUsersSpend(LiteLLMBase): class TeamMemberAddRequest(LiteLLMBase): team_id: str - member: Member + member: Union[List[Member], Member] max_budget_in_team: Optional[float] = None # Users max budget within the team + def __init__(self, **data): + member_data = data.get("member") + if isinstance(member_data, list): + # If member is a list of dictionaries, convert each dictionary to a Member object + members = [Member(**item) for item in member_data] + # Replace member_data with the list of Member objects + data["member"] = members + elif isinstance(member_data, dict): + # If member is a dictionary, convert it to a single Member object + member = Member(**member_data) + # Replace member_data with the single Member object + data["member"] = member + # Call the superclass __init__ method to initialize the object + super().__init__(**data) + class TeamMemberDeleteRequest(LiteLLMBase): team_id: str diff --git a/litellm/proxy/management_helpers/utils.py b/litellm/proxy/management_helpers/utils.py new file mode 100644 index 000000000000..6c035d3ef802 --- /dev/null +++ b/litellm/proxy/management_helpers/utils.py @@ -0,0 +1,63 @@ +# What is this? +## Helper utils for the management endpoints (keys/users/teams) + +from litellm.proxy._types import LiteLLM_TeamTable, Member, UserAPIKeyAuth +from litellm.proxy.utils import PrismaClient +import uuid +from typing import Optional + + +async def add_new_member( + new_member: Member, + max_budget_in_team: Optional[float], + prisma_client: PrismaClient, + team_id: str, + user_api_key_dict: UserAPIKeyAuth, + litellm_proxy_admin_name: str, +): + """ + Add a new member to a team + + - add team id to user table + - add team member w/ budget to team member table + """ + ## ADD TEAM ID, to USER TABLE IF NEW ## + if new_member.user_id is not None: + await prisma_client.db.litellm_usertable.update( + where={"user_id": new_member.user_id}, + data={"teams": {"push": [team_id]}}, + ) + elif new_member.user_email is not None: + user_data = {"user_id": str(uuid.uuid4()), "user_email": new_member.user_email} + ## user email is not unique acc. to prisma schema -> future improvement + ### for now: check if it exists in db, if not - insert it + existing_user_row = await prisma_client.get_data( + key_val={"user_email": new_member.user_email}, + table_name="user", + query_type="find_all", + ) + if existing_user_row is None or ( + isinstance(existing_user_row, list) and len(existing_user_row) == 0 + ): + + await prisma_client.insert_data(data=user_data, table_name="user") + + # Check if trying to set a budget for team member + if max_budget_in_team is not None and new_member.user_id is not None: + # create a new budget item for this member + response = await prisma_client.db.litellm_budgettable.create( + data={ + "max_budget": max_budget_in_team, + "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + } + ) + + _budget_id = response.budget_id + await prisma_client.db.litellm_teammembership.create( + data={ + "team_id": team_id, + "user_id": new_member.user_id, + "budget_id": _budget_id, + } + ) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 133124db8e9b..73b4ac547c4b 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -90,6 +90,7 @@ def generate_feedback_box(): HttpxBinaryResponseContent, ) from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request +from litellm.proxy.management_helpers.utils import add_new_member from litellm.proxy.utils import ( PrismaClient, DBClient, @@ -10159,10 +10160,12 @@ async def team_member_add( raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) if data.member is None: - raise HTTPException(status_code=400, detail={"error": "No member passed in"}) + raise HTTPException( + status_code=400, detail={"error": "No member/members passed in"} + ) - existing_team_row = await prisma_client.get_data( # type: ignore - team_id=data.team_id, table_name="team", query_type="find_unique" + existing_team_row = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": data.team_id} ) if existing_team_row is None: raise HTTPException( @@ -10172,75 +10175,52 @@ async def team_member_add( }, ) - new_member = data.member + complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump()) - existing_team_row.members_with_roles.append(new_member) + if isinstance(data.member, Member): + # add to team db + new_member = data.member - complete_team_data = LiteLLM_TeamTable( - **_get_pydantic_json_dict(existing_team_row), - ) + complete_team_data.members_with_roles.append(new_member) - team_row = await prisma_client.update_data( - update_key_values=complete_team_data.json(exclude_none=True), - data=complete_team_data.json(exclude_none=True), - table_name="team", - team_id=data.team_id, - ) + elif isinstance(data.member, List): + # add to team db + new_members = data.member - ## ADD USER, IF NEW ## - user_data = { # type: ignore - "teams": [team_row["team_id"]], - "models": team_row["data"].models, - } - if new_member.user_id is not None: - user_data["user_id"] = new_member.user_id # type: ignore - await prisma_client.update_data( - user_id=new_member.user_id, - data=user_data, - update_key_values_custom_query={ - "teams": { - "push": [team_row["team_id"]], - } - }, - table_name="user", - ) - elif new_member.user_email is not None: - user_data["user_id"] = str(uuid.uuid4()) - user_data["user_email"] = new_member.user_email - ## user email is not unique acc. to prisma schema -> future improvement - ### for now: check if it exists in db, if not - insert it - existing_user_row = await prisma_client.get_data( - key_val={"user_email": new_member.user_email}, - table_name="user", - query_type="find_all", - ) - if existing_user_row is None or ( - isinstance(existing_user_row, list) and len(existing_user_row) == 0 - ): - - await prisma_client.insert_data(data=user_data, table_name="user") + complete_team_data.members_with_roles.extend(new_members) - # Check if trying to set a budget for team member - if data.max_budget_in_team is not None and new_member.user_id is not None: - # create a new budget item for this member - response = await prisma_client.db.litellm_budgettable.create( - data={ - "max_budget": data.max_budget_in_team, - "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, - "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, - } - ) + # ADD MEMBER TO TEAM + _db_team_members = [ + m.model_dump() for m in complete_team_data.members_with_roles + ] + updated_team = await prisma_client.db.litellm_teamtable.update( + where={"team_id": data.team_id}, + data={"members_with_roles": json.dumps(_db_team_members)}, # type: ignore + ) - _budget_id = response.budget_id - await prisma_client.db.litellm_teammembership.create( - data={ - "team_id": data.team_id, - "user_id": new_member.user_id, - "budget_id": _budget_id, - } - ) + if isinstance(data.member, Member): + await add_new_member( + new_member=data.member, + max_budget_in_team=data.max_budget_in_team, + prisma_client=prisma_client, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + team_id=data.team_id, + ) + elif isinstance(data.member, List): + tasks: List = [] + for m in data.member: + await add_new_member( + new_member=m, + max_budget_in_team=data.max_budget_in_team, + prisma_client=prisma_client, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + team_id=data.team_id, + ) + await asyncio.gather(*tasks) - return team_row + return updated_team @router.post( diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 8843761327d5..fbf535b47c84 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -91,7 +91,7 @@ model LiteLLM_TeamTable { updated_at DateTime @default(now()) @updatedAt @map("updated_at") model_spend Json @default("{}") model_max_budget Json @default("{}") - model_id Int? @unique + model_id Int? @unique // id for LiteLLM_ModelTable -> stores team-level model aliases litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id]) litellm_model_table LiteLLM_ModelTable? @relation(fields: [model_id], references: [id]) } diff --git a/schema.prisma b/schema.prisma index 7cc688ee8eab..cbe14c8a3bf7 100644 --- a/schema.prisma +++ b/schema.prisma @@ -91,7 +91,7 @@ model LiteLLM_TeamTable { updated_at DateTime @default(now()) @updatedAt @map("updated_at") model_spend Json @default("{}") model_max_budget Json @default("{}") - model_id Int? @unique + model_id Int? @unique // id for LiteLLM_ModelTable -> stores team-level model aliases litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id]) litellm_model_table LiteLLM_ModelTable? @relation(fields: [model_id], references: [id]) } diff --git a/tests/test_team.py b/tests/test_team.py index 3f7ed71b52b5..467767be0ecb 100644 --- a/tests/test_team.py +++ b/tests/test_team.py @@ -49,7 +49,7 @@ async def new_user( async def add_member( - session, i, team_id, user_id=None, user_email=None, max_budget=None + session, i, team_id, user_id=None, user_email=None, max_budget=None, members=None ): url = "http://0.0.0.0:4000/team/member_add" headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} @@ -58,10 +58,13 @@ async def add_member( data["member"]["user_email"] = user_email elif user_id is not None: data["member"]["user_id"] = user_id + elif members is not None: + data["member"] = members if max_budget is not None: data["max_budget_in_team"] = max_budget + print("sent data: {}".format(data)) async with session.post(url, headers=headers, json=data) as response: status = response.status response_text = await response.text() @@ -339,7 +342,7 @@ async def test_team_info(): async def test_team_update_sc_2(): """ - Create team - - Add 1 user (doesn't exist in db) + - Add 3 users (doesn't exist in db) - Change team alias - Check if it works - Assert team object unchanged besides team alias @@ -353,15 +356,20 @@ async def test_team_update_sc_2(): {"role": "admin", "user_id": admin_user}, ] team_data = await new_team(session=session, i=0, member_list=member_list) - ## Create new normal user - new_normal_user = f"krrish_{uuid.uuid4()}@berri.ai" + ## Create 10 normal users + members = [ + {"role": "user", "user_id": f"krrish_{uuid.uuid4()}@berri.ai"} + for _ in range(10) + ] await add_member( - session=session, - i=0, - team_id=team_data["team_id"], - user_id=None, - user_email=new_normal_user, + session=session, i=0, team_id=team_data["team_id"], members=members ) + ## ASSERT TEAM SIZE + team_info = await get_team_info( + session=session, get_team=team_data["team_id"], call_key="sk-1234" + ) + + assert len(team_info["team_info"]["members_with_roles"]) == 12 ## CHANGE TEAM ALIAS