Skip to content

Commit

Permalink
Fix mypy errors in step functions of EmbodiedAgent, CriticAgent
Browse files Browse the repository at this point in the history
… and `Human` (#192)
  • Loading branch information
lightaime committed Jul 3, 2023
1 parent aea5b86 commit a77218c
Show file tree
Hide file tree
Showing 13 changed files with 77 additions and 59 deletions.
3 changes: 2 additions & 1 deletion camel/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from .base import BaseAgent
from .chat_agent import ChatAgent
from .chat_agent import ChatAgent, ChatAgentResponse
from .task_agent import TaskPlannerAgent, TaskSpecifyAgent
from .critic_agent import CriticAgent
from .tool_agents.base import BaseToolAgent
Expand All @@ -22,6 +22,7 @@
__all__ = [
'BaseAgent',
'ChatAgent',
'ChatAgentResponse',
'TaskSpecifyAgent',
'TaskPlannerAgent',
'CriticAgent',
Expand Down
5 changes: 3 additions & 2 deletions camel/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,18 @@
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from abc import ABC, abstractmethod
from typing import Any


class BaseAgent(ABC):
r"""An abstract base class for all CAMEL agents."""

@abstractmethod
def reset(self) -> None:
def reset(self, *args: Any, **kwargs: Any) -> Any:
r"""Resets the agent to its initial state."""
pass

@abstractmethod
def step(self) -> None:
def step(self, *args: Any, **kwargs: Any) -> Any:
r"""Performs a single step of the agent."""
pass
14 changes: 7 additions & 7 deletions camel/agents/chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(
self.role_name: str = system_message.role_name
self.role_type: RoleType = system_message.role_type
self.output_language: Optional[str] = output_language
if output_language is not None:
if self.output_language is not None:
self.set_output_language(self.output_language)

self.model: ModelType = (model if model is not None else
Expand Down Expand Up @@ -320,32 +320,32 @@ def handle_stream_response(
tuple: A tuple of list of output `ChatMessage`, list of
finish reasons, usage dictionary, and response id.
"""
content_dict = defaultdict(lambda: "")
finish_reasons = defaultdict(lambda: "")
content_dict: defaultdict = defaultdict(lambda: "")
finish_reasons_dict: defaultdict = defaultdict(lambda: "")
output_messages: List[BaseMessage] = []
response_id: str = ""
# All choices in one response share one role
role: str = ""
for chunk in response:
response_id = chunk["id"]
for choice in chunk["choices"]:
index = choice["index"]
delta = choice["delta"]
index: int = choice["index"]
delta: Dict = choice["delta"]
if len(delta) != 0:
# When response has not been stopped
# Notice that only the first chunk has the "role"
role = delta.get("role", role)
delta_content = delta.get("content", "")
content_dict[index] += delta_content
else:
finish_reasons[index] = choice["finish_reason"]
finish_reasons_dict[index] = choice["finish_reason"]
chat_message = BaseMessage(role_name=self.role_name,
role_type=self.role_type,
meta_dict=dict(),
content=content_dict[index])
output_messages.append(chat_message)
finish_reasons = [
finish_reasons[i] for i in range(len(finish_reasons))
finish_reasons_dict[i] for i in range(len(finish_reasons_dict))
]
usage_dict = self.get_usage_dict(output_messages, prompt_tokens)
return output_messages, finish_reasons, usage_dict, response_id
Expand Down
29 changes: 15 additions & 14 deletions camel/agents/critic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import copy
import random
import warnings
from typing import Any, Dict, Optional, Sequence

from colorama import Fore

from camel.agents import ChatAgent
from camel.agents import ChatAgent, ChatAgentResponse
from camel.messages import BaseMessage
from camel.typing import ModelType
from camel.utils import get_first_int, print_text_animated
Expand Down Expand Up @@ -141,33 +140,35 @@ def parse_critic(self, critic_msg: BaseMessage) -> Optional[str]:
choice = str(get_first_int(critic_msg.content))
return choice

def step(self, messages: Sequence[BaseMessage]) -> BaseMessage:
def reduce_step(
self,
input_messages: Sequence[BaseMessage],
) -> ChatAgentResponse:
r"""Performs one step of the conversation by flattening options to the
critic, getting the option, and parsing the choice.
Args:
messages (Sequence[BaseMessage]): A list of BaseMessage objects.
Returns:
BaseMessage: A `BaseMessage` object representing the critic's
choice.
ChatAgentResponse: A `ChatAgentResponse` object includes the
critic's choice.
"""
meta_chat_message = BaseMessage(
role_name=messages[0].role_name,
role_type=messages[0].role_type,
meta_dict=messages[0].meta_dict,
role_name=input_messages[0].role_name,
role_type=input_messages[0].role_type,
meta_dict=input_messages[0].meta_dict,
content="",
)

flatten_options = self.flatten_options(messages)
flatten_options = self.flatten_options(input_messages)
if self.verbose:
print_text_animated(self.logger_color +
f"\x1b[3m{flatten_options}\x1b[0m\n")
input_msg = copy.deepcopy(meta_chat_message)
input_msg.content = flatten_options
input_msg = meta_chat_message.create_new_instance(flatten_options)

option = self.get_option(input_msg)
output_msg = copy.deepcopy(meta_chat_message)
output_msg.content = option
output_msg = meta_chat_message.create_new_instance(option)

return output_msg
# TODO: The return `info` can be improved.
return ChatAgentResponse([output_msg], terminated=False, info={})
19 changes: 12 additions & 7 deletions camel/agents/embodied_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, List, Optional

from colorama import Fore

from camel.agents import BaseToolAgent, ChatAgent, HuggingFaceToolAgent
from camel.agents import (
BaseToolAgent,
ChatAgent,
ChatAgentResponse,
HuggingFaceToolAgent,
)
from camel.messages import BaseMessage
from camel.typing import ModelType
from camel.utils import print_text_animated
Expand Down Expand Up @@ -80,16 +85,16 @@ def get_action_space_prompt(self) -> str:
def step(
self,
input_message: BaseMessage,
) -> Tuple[BaseMessage, bool, Dict[str, Any]]:
) -> ChatAgentResponse:
r"""Performs a step in the conversation.
Args:
input_message (BaseMessage): The input message.
Returns:
Tuple[BaseMessage, bool, Dict[str, Any]]: A tuple
containing the output messages, termination status, and
additional information.
ChatAgentResponse: A struct containing the output messages,
a boolean indicating whether the chat session has terminated,
and information about the chat session.
"""
response = super().step(input_message)

Expand Down Expand Up @@ -128,4 +133,4 @@ def step(
f"\n> Embodied Actions:\n{content}")
message = BaseMessage(input_message.role_name, input_message.role_type,
input_message.meta_dict, content)
return message, response.terminated, response.info
return ChatAgentResponse([message], response.terminated, response.info)
4 changes: 2 additions & 2 deletions camel/agents/task_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
model_config: Optional[Any] = None,
task_specify_prompt: Optional[Union[str, TextPrompt]] = None,
word_limit: int = DEFAULT_WORD_LIMIT,
output_language: str = None,
output_language: Optional[str] = None,
) -> None:

if task_specify_prompt is None:
Expand Down Expand Up @@ -135,7 +135,7 @@ def __init__(
self,
model: Optional[ModelType] = None,
model_config: Any = None,
output_language: str = None,
output_language: Optional[str] = None,
) -> None:

self.task_planner_prompt = TextPrompt(
Expand Down
26 changes: 14 additions & 12 deletions camel/human.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from colorama import Fore

from camel.agents import ChatAgentResponse
from camel.messages import BaseMessage
from camel.utils import print_text_animated

Expand Down Expand Up @@ -86,36 +87,35 @@ def get_input(self) -> str:

return human_input

def parse_input(self, human_input: str,
meta_chat_message: BaseMessage) -> BaseMessage:
def parse_input(self, human_input: str) -> str:
r"""Parses the user's input and returns a `BaseMessage` object.
Args:
human_input (str): The user's input.
meta_chat_message (BaseMessage): A `BaseMessage` object.
Returns:
BaseMessage: A `BaseMessage` object.
content: A `str` object representing the user's input.
"""
if self.options_dict[human_input] == self.input_button:
meta_chat_message.content = input(self.logger_color +
"Please enter your message: ")
return meta_chat_message
content = input(self.logger_color + "Please enter your message: ")
elif self.options_dict[human_input] == self.kill_button:
exit(self.logger_color + f"Killed by {self.name}.")
else:
meta_chat_message.content = self.options_dict[human_input]
return meta_chat_message
content = self.options_dict[human_input]

def step(self, messages: Sequence[BaseMessage]) -> BaseMessage:
return content

def reduce_step(self,
messages: Sequence[BaseMessage]) -> ChatAgentResponse:
r"""Performs one step of the conversation by displaying options to the
user, getting their input, and parsing their choice.
Args:
messages (Sequence[BaseMessage]): A list of BaseMessage objects.
Returns:
BaseMessage: A `BaseMessage` object representing the user's choice.
ChatAgentResponse: A `ChatAgentResponse` object representing the
user's choice.
"""
meta_chat_message = BaseMessage(
role_name=messages[0].role_name,
Expand All @@ -125,4 +125,6 @@ def step(self, messages: Sequence[BaseMessage]) -> BaseMessage:
)
self.display_options(messages)
human_input = self.get_input()
return self.parse_input(human_input, meta_chat_message)
content = self.parse_input(human_input)
message = meta_chat_message.create_new_instance(content)
return ChatAgentResponse([message], terminated=False, info={})
5 changes: 3 additions & 2 deletions camel/societies/role_playing.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
sys_msg_generator_kwargs: Optional[Dict] = None,
extend_sys_msg_meta_dicts: Optional[List[Dict]] = None,
extend_task_specify_meta_dict: Optional[Dict] = None,
output_language: str = None,
output_language: Optional[str] = None,
) -> None:
self.with_task_specify = with_task_specify
self.with_task_planner = with_task_planner
Expand Down Expand Up @@ -250,7 +250,8 @@ def reduce_message_options(
raise ValueError("Got than one message to process. "
f"Num of messages: {len(messages)}.")
elif self.with_critic_in_the_loop and self.critic is not None:
processed_msg = self.critic.step(messages)
critic_response = self.critic.reduce_step(messages)
processed_msg = critic_response.msg
else:
processed_msg = messages[0]

Expand Down
4 changes: 2 additions & 2 deletions examples/embodiment/hugging_face_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def main():
"caption the image content, "
"save the images by species name."),
)
output_message, _, _ = embodied_agent.step(user_msg)
print(output_message.content)
response = embodied_agent.step(user_msg)
print(response.msg.content)


if __name__ == "__main__":
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,9 @@ markers = [

[tool.coverage.report]
include_namespace_packages = true

[[tool.mypy.overrides]]
module = [
"transformers.tools",
]
ignore_missing_imports = true
7 changes: 4 additions & 3 deletions test/agents/test_critic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_parse_critic(critic_agent: CriticAgent):


@pytest.mark.model_backend
def test_step(critic_agent: CriticAgent):
def test_reduce_step(critic_agent: CriticAgent):
messages = [
BaseMessage(
role_name="user",
Expand All @@ -112,5 +112,6 @@ def test_step(critic_agent: CriticAgent):
),
]

assert (critic_agent.step(messages)
== messages[0]) or (critic_agent.step(messages) == messages[1])
critic_response = critic_agent.reduce_step(messages)
assert (critic_response.msg == messages[0]) or (critic_response.msg
== messages[1])
8 changes: 4 additions & 4 deletions test/agents/test_embodied_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_step():
role_name=role_name,
content="Draw all the Camelidae species.",
)
output_message, terminated, info = embodied_agent.step(user_msg)
assert isinstance(output_message, BaseMessage)
assert not terminated
assert isinstance(info, dict)
response = embodied_agent.step(user_msg)
assert isinstance(response.msg, BaseMessage)
assert not response.terminated
assert isinstance(response.info, dict)
6 changes: 3 additions & 3 deletions test/test_human.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_get_input(monkeypatch):
assert human.get_input() == str(1)


def test_step(monkeypatch):
def test_reduce_step(monkeypatch):
human = Human()
msgs = [
BaseMessage.make_assistant_message(role_name="assistant",
Expand All @@ -49,5 +49,5 @@ def test_step(monkeypatch):
]

monkeypatch.setattr('builtins.input', lambda _: str(1))
msg = human.step(msgs)
assert msg.content == "Hello"
human_response = human.reduce_step(msgs)
assert human_response.msg.content == "Hello"

0 comments on commit a77218c

Please sign in to comment.