forked from All-Hands-AI/OpenHands
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add tests for agent controller (All-Hands-AI#3357)
* Add tests for agent controller
* Remove dead code
* Remove dead code
- Loading branch information
Showing 1 changed file with 194 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
import asyncio | ||
from unittest.mock import AsyncMock, MagicMock | ||
|
||
import pytest | ||
|
||
from opendevin.controller.agent import Agent | ||
from opendevin.controller.agent_controller import AgentController | ||
from opendevin.controller.state.state import TrafficControlState | ||
from opendevin.core.exceptions import LLMMalformedActionError | ||
from opendevin.core.schema import AgentState | ||
from opendevin.events import EventStream | ||
from opendevin.events.action import ChangeAgentStateAction, MessageAction | ||
|
||
|
||
@pytest.fixture
def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
    """Create a fresh temporary directory and return its path as a string."""
    created_path = tmp_path_factory.mktemp('test_event_stream')
    return str(created_path)
|
||
|
||
@pytest.fixture(scope='function')
def event_loop():
    """Provide a new asyncio event loop per test and close it afterwards."""
    # asyncio.new_event_loop() delegates to the current policy's
    # new_event_loop(), so this yields the same kind of loop object.
    fresh_loop = asyncio.new_event_loop()
    yield fresh_loop
    fresh_loop.close()
|
||
|
||
@pytest.fixture
def mock_agent():
    """Return a ``MagicMock`` restricted to the ``Agent`` interface."""
    agent_double = MagicMock(spec=Agent)
    return agent_double
|
||
|
||
@pytest.fixture
def mock_event_stream():
    """Return a ``MagicMock`` restricted to the ``EventStream`` interface."""
    stream_double = MagicMock(spec=EventStream)
    return stream_double
|
||
|
||
@pytest.mark.asyncio
async def test_set_agent_state(mock_agent, mock_event_stream):
    """Check that ``set_agent_state_to`` is reflected by ``get_agent_state``.

    The controller is moved to RUNNING and then to PAUSED; after each
    transition the reported state must equal the requested one.
    """
    controller_kwargs = dict(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )
    agent_controller = AgentController(**controller_kwargs)

    # Same order as the two explicit transitions: RUNNING first, then PAUSED.
    for requested_state in (AgentState.RUNNING, AgentState.PAUSED):
        await agent_controller.set_agent_state_to(requested_state)
        assert agent_controller.get_agent_state() == requested_state
|
||
|
||
@pytest.mark.asyncio
async def test_on_event_message_action(mock_agent, mock_event_stream):
    """Check that a ``MessageAction`` event leaves a RUNNING agent RUNNING."""
    controller_kwargs = dict(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )
    agent_controller = AgentController(**controller_kwargs)

    # Set the state directly on the state object rather than through
    # set_agent_state_to, so only on_event is exercised.
    agent_controller.state.agent_state = AgentState.RUNNING
    await agent_controller.on_event(MessageAction(content='Test message'))

    assert agent_controller.get_agent_state() == AgentState.RUNNING
|
||
|
||
@pytest.mark.asyncio
async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
    """Check that a ``ChangeAgentStateAction`` event switches the agent state.

    Starting from RUNNING, an event requesting PAUSED must leave the
    controller reporting PAUSED.
    """
    controller_kwargs = dict(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )
    agent_controller = AgentController(**controller_kwargs)

    agent_controller.state.agent_state = AgentState.RUNNING
    pause_request = ChangeAgentStateAction(agent_state=AgentState.PAUSED)
    await agent_controller.on_event(pause_request)

    assert agent_controller.get_agent_state() == AgentState.PAUSED
|
||
|
||
@pytest.mark.asyncio
async def test_report_error(mock_agent, mock_event_stream):
    """Check that ``report_error`` records the message and emits one event.

    After the call, ``state.last_error`` must hold the message and the event
    stream's ``add_event`` must have been called exactly once.
    """
    controller_kwargs = dict(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )
    agent_controller = AgentController(**controller_kwargs)

    expected_error = 'Test error'
    await agent_controller.report_error(expected_error)

    assert agent_controller.state.last_error == expected_error
    agent_controller.event_stream.add_event.assert_called_once()
|
||
|
||
@pytest.mark.asyncio
async def test_step_with_exception(mock_agent, mock_event_stream):
    """Check that an ``LLMMalformedActionError`` from the agent is reported.

    The agent's ``step`` is made to raise, and ``report_error`` is replaced
    with an ``AsyncMock`` so the call and its argument can be inspected.
    """
    controller_kwargs = dict(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )
    agent_controller = AgentController(**controller_kwargs)

    agent_controller.state.agent_state = AgentState.RUNNING
    agent_controller.report_error = AsyncMock()
    agent_controller.agent.step.side_effect = LLMMalformedActionError(
        'Malformed action'
    )

    await agent_controller._step()

    # Verify that report_error was called with the correct error message
    agent_controller.report_error.assert_called_once_with('Malformed action')
|
||
|
||
@pytest.mark.asyncio
async def test_step_max_iterations(mock_agent, mock_event_stream):
    """Check the iteration limit in non-headless mode.

    With ``state.iteration`` equal to ``max_iterations``, one ``_step`` must
    switch traffic control from NORMAL to THROTTLING and pause the agent.
    """
    controller_kwargs = dict(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=False,
    )
    agent_controller = AgentController(**controller_kwargs)

    agent_controller.state.agent_state = AgentState.RUNNING
    agent_controller.state.iteration = 10
    # Precondition: no throttling before the step runs.
    assert agent_controller.state.traffic_control_state == TrafficControlState.NORMAL

    await agent_controller._step()

    assert (
        agent_controller.state.traffic_control_state
        == TrafficControlState.THROTTLING
    )
    assert agent_controller.state.agent_state == AgentState.PAUSED
|
||
|
||
@pytest.mark.asyncio
async def test_step_max_iterations_headless(mock_agent, mock_event_stream):
    """Check the iteration limit in headless mode.

    With ``state.iteration`` equal to ``max_iterations``, one ``_step`` must
    switch traffic control from NORMAL to THROTTLING and put the agent in
    the ERROR state.
    """
    controller_kwargs = dict(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )
    agent_controller = AgentController(**controller_kwargs)

    agent_controller.state.agent_state = AgentState.RUNNING
    agent_controller.state.iteration = 10
    # Precondition: no throttling before the step runs.
    assert agent_controller.state.traffic_control_state == TrafficControlState.NORMAL

    await agent_controller._step()

    assert (
        agent_controller.state.traffic_control_state
        == TrafficControlState.THROTTLING
    )
    # In headless mode, throttling results in an error
    assert agent_controller.state.agent_state == AgentState.ERROR
|
||
|
||
@pytest.mark.asyncio
async def test_step_max_budget(mock_agent, mock_event_stream):
    """Check the per-task budget limit in non-headless mode.

    With the accumulated cost (10.1) above ``max_budget_per_task`` (10), one
    ``_step`` must switch traffic control from NORMAL to THROTTLING and
    pause the agent.
    """
    controller_kwargs = dict(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        max_budget_per_task=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=False,
    )
    agent_controller = AgentController(**controller_kwargs)

    agent_controller.state.agent_state = AgentState.RUNNING
    agent_controller.state.metrics.accumulated_cost = 10.1
    # Precondition: no throttling before the step runs.
    assert agent_controller.state.traffic_control_state == TrafficControlState.NORMAL

    await agent_controller._step()

    assert (
        agent_controller.state.traffic_control_state
        == TrafficControlState.THROTTLING
    )
    assert agent_controller.state.agent_state == AgentState.PAUSED
|
||
|
||
@pytest.mark.asyncio
async def test_step_max_budget_headless(mock_agent, mock_event_stream):
    """Check the per-task budget limit in headless mode.

    With the accumulated cost (10.1) above ``max_budget_per_task`` (10), one
    ``_step`` must switch traffic control from NORMAL to THROTTLING and put
    the agent in the ERROR state.
    """
    controller_kwargs = dict(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        max_budget_per_task=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )
    agent_controller = AgentController(**controller_kwargs)

    agent_controller.state.agent_state = AgentState.RUNNING
    agent_controller.state.metrics.accumulated_cost = 10.1
    # Precondition: no throttling before the step runs.
    assert agent_controller.state.traffic_control_state == TrafficControlState.NORMAL

    await agent_controller._step()

    assert (
        agent_controller.state.traffic_control_state
        == TrafficControlState.THROTTLING
    )
    # In headless mode, throttling results in an error
    assert agent_controller.state.agent_state == AgentState.ERROR