Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix agent mypy error #206

Merged
merged 6 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions camel/agents/chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __init__(
self.stored_messages: List[ChatRecord]
self.init_messages()

def reset(self) -> List[ChatRecord]:
def reset(self):
r"""Resets the :obj:`ChatAgent` to its initial state and returns the
stored messages.

Expand All @@ -132,7 +132,6 @@ def reset(self) -> List[ChatRecord]:
"""
self.terminated = False
self.init_messages()
return self.stored_messages

def set_output_language(self, output_language: str) -> BaseMessage:
r"""Sets the output language for the system message. This method
Expand Down
2 changes: 1 addition & 1 deletion camel/agents/critic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def get_option(self, input_message: BaseMessage) -> str:
msg_content = input_message.content
i = 0
while i < self.retry_attempts:
critic_response = super().step(input_message)
critic_response = self.step(input_message)

if critic_response.msgs is None or len(critic_response.msgs) == 0:
raise RuntimeError("Got None critic messages.")
Expand Down
17 changes: 8 additions & 9 deletions camel/agents/task_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
self.task_specify_prompt = task_specify_prompt_template.format(
word_limit=word_limit)
else:
self.task_specify_prompt = task_specify_prompt
self.task_specify_prompt = TextPrompt(task_specify_prompt)

model_config = model_config or ChatGPTConfig(temperature=1.0)

Expand All @@ -75,15 +75,15 @@ def __init__(
super().__init__(system_message, model, model_config,
output_language=output_language)

def step(
def run(
self,
original_task_prompt: Union[str, TextPrompt],
task_prompt: Union[str, TextPrompt],
meta_dict: Optional[Dict[str, Any]] = None,
) -> TextPrompt:
r"""Specify the given task prompt by providing more details.

Args:
original_task_prompt (Union[str, TextPrompt]): The original task
task_prompt (Union[str, TextPrompt]): The original task
prompt.
meta_dict (Optional[Dict[str, Any]]): A dictionary containing
additional information to include in the prompt.
Expand All @@ -94,15 +94,15 @@ def step(
"""
self.reset()
self.task_specify_prompt = self.task_specify_prompt.format(
task=original_task_prompt)
task=task_prompt)

if meta_dict is not None:
self.task_specify_prompt = (self.task_specify_prompt.format(
**meta_dict))

task_msg = BaseMessage.make_user_message(
role_name="Task Specifier", content=self.task_specify_prompt)
specifier_response = super().step(task_msg)
specifier_response = self.step(task_msg)
if (specifier_response.msgs is None
or len(specifier_response.msgs) == 0):
raise RuntimeError("Task specification failed.")
Expand Down Expand Up @@ -151,7 +151,7 @@ def __init__(
super().__init__(system_message, model, model_config,
output_language=output_language)

def step(
def run(
self,
task_prompt: Union[str, TextPrompt],
) -> TextPrompt:
Expand All @@ -171,8 +171,7 @@ def step(

task_msg = BaseMessage.make_user_message(
role_name="Task Planner", content=self.task_planner_prompt)
# sub_tasks_msgs, terminated, _
task_response = super().step(task_msg)
task_response = self.step(task_msg)

if task_response.msgs is None:
raise RuntimeError("Got None Subtasks messages.")
Expand Down
7 changes: 4 additions & 3 deletions camel/societies/role_playing.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def init_specified_task_prompt(
output_language=output_language,
**(task_specify_agent_kwargs or {}),
)
self.specified_task_prompt = task_specify_agent.step(
self.specified_task_prompt = task_specify_agent.run(
self.task_prompt,
meta_dict=task_specify_meta_dict,
)
Expand All @@ -191,10 +191,11 @@ def init_planned_task_prompt(self,
output_language=output_language,
**(task_planner_agent_kwargs or {}),
)
self.planned_task_prompt = task_planner_agent.step(
self.task_prompt)
self.planned_task_prompt = task_planner_agent.run(self.task_prompt)
self.task_prompt = (f"{self.task_prompt}\n"
f"{self.planned_task_prompt}")
else:
self.planned_task_prompt = None

def get_sys_message_info(
self, assistant_role_name: str, user_role_name: str,
Expand Down
2 changes: 1 addition & 1 deletion examples/code/role_playing_multiprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def generate_data(language_idx: int, language_name: str, domain_idx: int,
task_type=TaskType.CODE,
model_config=ChatGPTConfig(temperature=1.4),
)
specified_task_prompt = task_specify_agent.step(
specified_task_prompt = task_specify_agent.run(
original_task_prompt,
meta_dict=dict(domain=domain_name, language=language_name),
)
Expand Down
8 changes: 4 additions & 4 deletions test/agents/test_task_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_task_specify_ai_society_agent(model: Optional[ModelType]):
print(f"Original task prompt:\n{original_task_prompt}\n")
task_specify_agent = TaskSpecifyAgent(
model_config=ChatGPTConfig(temperature=1.0), model=model)
specified_task_prompt = task_specify_agent.step(
specified_task_prompt = task_specify_agent.run(
original_task_prompt, meta_dict=dict(assistant_role="Musician",
user_role="Student"))
assert ("{" and "}" not in task_specify_agent.task_specify_prompt)
Expand All @@ -47,7 +47,7 @@ def test_task_specify_code_agent(model: Optional[ModelType]):
model_config=ChatGPTConfig(temperature=1.0),
model=model,
)
specified_task_prompt = task_specify_agent.step(
specified_task_prompt = task_specify_agent.run(
original_task_prompt, meta_dict=dict(domain="Chemistry",
language="Python"))
assert ("{" and "}" not in task_specify_agent.task_specify_prompt)
Expand All @@ -63,11 +63,11 @@ def test_task_planner_agent(model: Optional[ModelType]):
model_config=ChatGPTConfig(temperature=1.0),
model=model,
)
specified_task_prompt = task_specify_agent.step(
specified_task_prompt = task_specify_agent.run(
original_task_prompt, meta_dict=dict(domain="Chemistry",
language="Python"))
print(f"Specified task prompt:\n{specified_task_prompt}\n")
task_planner_agent = TaskPlannerAgent(
model_config=ChatGPTConfig(temperature=1.0), model=model)
planned_task_prompt = task_planner_agent.step(specified_task_prompt)
planned_task_prompt = task_planner_agent.run(specified_task_prompt)
print(f"Planned task prompt:\n{planned_task_prompt}\n")