forked from All-Hands-AI/OpenHands
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add debug dir for prompts (All-Hands-AI#205)
* add debug dir for prompts * add indent to dumps * only wrap completion in debug mode * fix mypy
- Loading branch information
Showing
3 changed files
with
40 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -191,3 +191,4 @@ yarn-error.log* | |
# agent | ||
.envrc | ||
/workspace | ||
/debug |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,55 @@ | ||
import os | ||
import uuid | ||
|
||
from litellm import completion as litellm_completion | ||
from functools import partial | ||
import os | ||
|
||
DEFAULT_MODEL = os.getenv("LLM_MODEL", "gpt-4-0125-preview") | ||
DEFAULT_API_KEY = os.getenv("LLM_API_KEY") | ||
PROMPT_DEBUG_DIR = os.getenv("PROMPT_DEBUG_DIR", "") | ||
|
||
class LLM: | ||
def __init__(self, model=DEFAULT_MODEL, api_key=DEFAULT_API_KEY): | ||
def __init__(self, model=DEFAULT_MODEL, api_key=DEFAULT_API_KEY, debug_dir=PROMPT_DEBUG_DIR): | ||
self.model = model if model else DEFAULT_MODEL | ||
self.api_key = api_key if api_key else DEFAULT_API_KEY | ||
|
||
self._debug_dir = debug_dir | ||
self._debug_idx = 0 | ||
self._debug_id = uuid.uuid4().hex | ||
self._completion = partial(litellm_completion, model=self.model, api_key=self.api_key) | ||
|
||
if self._debug_dir: | ||
print(f"Logging prompts to {self._debug_dir}/{self._debug_id}") | ||
completion_unwrapped = self._completion | ||
def wrapper(*args, **kwargs): | ||
if "messages" in kwargs: | ||
messages = kwargs["messages"] | ||
else: | ||
messages = args[1] | ||
resp = completion_unwrapped(*args, **kwargs) | ||
message_back = resp['choices'][0]['message']['content'] | ||
self.write_debug(messages, message_back) | ||
return resp | ||
self._completion = wrapper # type: ignore | ||
|
||
@property | ||
def completion(self): | ||
""" | ||
Decorator for the litellm completion function. | ||
""" | ||
return self._completion | ||
|
||
def write_debug(self, messages, response): | ||
if not self._debug_dir: | ||
return | ||
dir = self._debug_dir + "/" + self._debug_id + "/" + str(self._debug_idx) | ||
os.makedirs(dir, exist_ok=True) | ||
prompt_out = "" | ||
for message in messages: | ||
prompt_out += "<" + message["role"] + ">\n" | ||
prompt_out += message["content"] + "\n\n" | ||
with open(f"{dir}/prompt.md", "w") as f: | ||
f.write(prompt_out) | ||
with open(f"{dir}/response.md", "w") as f: | ||
f.write(response) | ||
self._debug_idx += 1 | ||
|