Skip to content

Commit

Permalink
Add tests for agent controller (All-Hands-AI#3357)
Browse files Browse the repository at this point in the history
* Add tests for agent controller

* Remove dead code

* Remove dead code
  • Loading branch information
neubig committed Aug 14, 2024
1 parent b6243bb commit 50b1256
Showing 1 changed file with 194 additions and 0 deletions.
194 changes: 194 additions & 0 deletions tests/unit/test_agent_controller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock

import pytest

from opendevin.controller.agent import Agent
from opendevin.controller.agent_controller import AgentController
from opendevin.controller.state.state import TrafficControlState
from opendevin.core.exceptions import LLMMalformedActionError
from opendevin.core.schema import AgentState
from opendevin.events import EventStream
from opendevin.events.action import ChangeAgentStateAction, MessageAction


@pytest.fixture
def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
return str(tmp_path_factory.mktemp('test_event_stream'))


@pytest.fixture(scope='function')
def event_loop():
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()


@pytest.fixture
def mock_agent():
return MagicMock(spec=Agent)


@pytest.fixture
def mock_event_stream():
return MagicMock(spec=EventStream)


@pytest.mark.asyncio
async def test_set_agent_state(mock_agent, mock_event_stream):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
await controller.set_agent_state_to(AgentState.RUNNING)
assert controller.get_agent_state() == AgentState.RUNNING

await controller.set_agent_state_to(AgentState.PAUSED)
assert controller.get_agent_state() == AgentState.PAUSED


@pytest.mark.asyncio
async def test_on_event_message_action(mock_agent, mock_event_stream):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
controller.state.agent_state = AgentState.RUNNING
message_action = MessageAction(content='Test message')
await controller.on_event(message_action)
assert controller.get_agent_state() == AgentState.RUNNING


@pytest.mark.asyncio
async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
controller.state.agent_state = AgentState.RUNNING
change_state_action = ChangeAgentStateAction(agent_state=AgentState.PAUSED)
await controller.on_event(change_state_action)
assert controller.get_agent_state() == AgentState.PAUSED


@pytest.mark.asyncio
async def test_report_error(mock_agent, mock_event_stream):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
error_message = 'Test error'
await controller.report_error(error_message)
assert controller.state.last_error == error_message
controller.event_stream.add_event.assert_called_once()


@pytest.mark.asyncio
async def test_step_with_exception(mock_agent, mock_event_stream):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
controller.state.agent_state = AgentState.RUNNING
controller.report_error = AsyncMock()
controller.agent.step.side_effect = LLMMalformedActionError('Malformed action')
await controller._step()

# Verify that report_error was called with the correct error message
controller.report_error.assert_called_once_with('Malformed action')


@pytest.mark.asyncio
async def test_step_max_iterations(mock_agent, mock_event_stream):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=False,
)
controller.state.agent_state = AgentState.RUNNING
controller.state.iteration = 10
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
await controller._step()
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
assert controller.state.agent_state == AgentState.PAUSED


@pytest.mark.asyncio
async def test_step_max_iterations_headless(mock_agent, mock_event_stream):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
controller.state.agent_state = AgentState.RUNNING
controller.state.iteration = 10
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
await controller._step()
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
# In headless mode, throttling results in an error
assert controller.state.agent_state == AgentState.ERROR


@pytest.mark.asyncio
async def test_step_max_budget(mock_agent, mock_event_stream):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
max_budget_per_task=10,
sid='test',
confirmation_mode=False,
headless_mode=False,
)
controller.state.agent_state = AgentState.RUNNING
controller.state.metrics.accumulated_cost = 10.1
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
await controller._step()
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
assert controller.state.agent_state == AgentState.PAUSED


@pytest.mark.asyncio
async def test_step_max_budget_headless(mock_agent, mock_event_stream):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
max_budget_per_task=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
controller.state.agent_state = AgentState.RUNNING
controller.state.metrics.accumulated_cost = 10.1
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
await controller._step()
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
# In headless mode, throttling results in an error
assert controller.state.agent_state == AgentState.ERROR

0 comments on commit 50b1256

Please sign in to comment.