Skip to content

Commit

Permalink
(enh) StuckDetector: fix+enhance syntax error loop detection (All-Hands-AI#3628)
Browse files Browse the repository at this point in the history

* fix StuckDetector and add more errors for detection

* more stringent error detection and more unit tests
  • Loading branch information
tobitege committed Aug 29, 2024
1 parent ae153aa commit a2d94c9
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 87 deletions.
125 changes: 100 additions & 25 deletions openhands/controller/stuck.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import cast

from openhands.controller.state.state import State
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.action import Action
Expand All @@ -16,6 +14,12 @@


class StuckDetector:
SYNTAX_ERROR_MESSAGES = [
'SyntaxError: unterminated string literal (detected at line',
'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
'SyntaxError: incomplete input',
]

def __init__(self, state: State):
self.state = state

Expand Down Expand Up @@ -119,36 +123,107 @@ def _is_stuck_repeating_action_observation(self, last_actions, last_observations

def _is_stuck_repeating_action_error(self, last_actions, last_observations):
    """Detect scenario 2: the same action repeated, each time ending in an error.

    Only the first three entries of ``last_actions`` and
    ``last_observations`` are inspected.
    NOTE(review): presumably the caller passes these most-recent-first, so
    the first three are the latest three steps -- confirm against the caller.

    Args:
        last_actions: Recent actions, compared with ``self._eq_no_pid``.
        last_observations: Observations paired with ``last_actions``.

    Returns:
        True when the three inspected actions are equal and their
        observations are either all ``ErrorObservation`` or all
        ``IPythonRunCellObservation`` carrying the same syntax error;
        False otherwise.
    """
    # scenario 2: same action, errors
    # it takes 3 actions and 3 observations to detect a loop; only three
    # entries are inspected below, so three is all that is required here.
    if len(last_actions) < 3 or len(last_observations) < 3:
        return False

    recent_actions = last_actions[:3]
    recent_observations = last_observations[:3]

    # are the last three actions the "same"?
    if not all(
        self._eq_no_pid(recent_actions[0], action) for action in recent_actions
    ):
        return False

    # and the last three observations are all errors?
    if all(isinstance(obs, ErrorObservation) for obs in recent_observations):
        logger.warning('Action, ErrorObservation loop detected')
        return True

    # or, are the last three observations all IPythonRunCellObservation
    # with a SyntaxError?
    if not all(
        isinstance(obs, IPythonRunCellObservation) for obs in recent_observations
    ):
        return False

    warning = 'Action, IPythonRunCellObservation loop detected'
    for error_message in self.SYNTAX_ERROR_MESSAGES:
        if error_message.startswith(
            'SyntaxError: unterminated string literal (detected at line'
        ):
            # This message ends with a line number, so the whole error line
            # has to match across observations.
            is_loop = self._check_for_consistent_line_error(
                recent_observations, error_message
            )
        elif error_message in [
            'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
            'SyntaxError: incomplete input',
        ]:
            is_loop = self._check_for_consistent_invalid_syntax(
                recent_observations, error_message
            )
        else:
            is_loop = False

        if is_loop:
            logger.warning(warning)
            return True
    return False

def _check_for_consistent_invalid_syntax(self, observations, error_message):
first_lines = []
valid_observations = []

for obs in observations:
content = obs.content
lines = content.strip().split('\n')

if len(lines) < 4:
return False

first_lines.append(lines[0]) # Store the first line of each observation

# Check last three lines
if lines[-2].startswith('[Jupyter current working directory:') and lines[
-1
].startswith('[Jupyter Python interpreter:'):
if error_message in lines[-3]:
valid_observations.append(obs)
break

# Check if:
# 1. All first lines are identical
# 2. We have exactly 3 valid observations
# 3. The error message line is identical in all valid observations
return (
len(set(first_lines)) == 1
and len(valid_observations) == 3
and len(
set(
obs.content.strip().split('\n')[:-2][-1]
for obs in valid_observations
)
< 10
for obs in last_observations
)
== 1
)

def _check_for_consistent_line_error(self, observations, error_message):
error_lines = []

for obs in observations:
content = obs.content
lines = content.strip().split('\n')

if len(lines) < 3:
return False

last_lines = lines[-3:]

# Check if the last two lines are our own
if not (
last_lines[-2].startswith('[Jupyter current working directory:')
and last_lines[-1].startswith('[Jupyter Python interpreter:')
):
logger.warning('Action, IPythonRunCellObservation loop detected')
return True
return False
return False

# Check for the error message in the 3rd-to-last line
if error_message in last_lines[-3]:
error_lines.append(last_lines[-3])

# Check if we found the error message in all 3 observations
# and the 3rd-to-last line is identical across all occurrences
return len(error_lines) == 3 and len(set(error_lines)) == 1

def _is_stuck_monologue(self, filtered_history):
# scenario 3: monologue
Expand Down
154 changes: 92 additions & 62 deletions tests/unit/test_is_stuck.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import random
from unittest.mock import Mock, patch

import pytest
Expand Down Expand Up @@ -27,6 +28,9 @@ def collect_events(stream):

logging.basicConfig(level=logging.DEBUG)

jupyter_line_1 = '\n[Jupyter current working directory:'
jupyter_line_2 = '\n[Jupyter Python interpreter:'


@pytest.fixture
def temp_dir(tmp_path_factory: TempPathFactory) -> str:
Expand All @@ -46,11 +50,44 @@ class TestStuckDetector:
@pytest.fixture
def stuck_detector(self, event_stream):
state = State(inputs={}, max_iterations=50)
# state.history = ShortTermHistory()
state.history.set_event_stream(event_stream)

return StuckDetector(state)

def _impl_syntax_error_events(
    self, event_stream: EventStream, error_message: str, random_line: bool
):
    """Add four identical IPython actions, each followed by a syntax error.

    Args:
        event_stream: Stream that receives the action/observation pairs.
        error_message: Syntax error text placed in each observation.
        random_line: When True, each observation gets a different reported
            line number and 2-3 blank lines after the error message; when
            False, all four observations are identical.
    """
    repetitions = 4
    # Sample without replacement so the "random" observations are guaranteed
    # to differ from each other; independent draws could coincide.
    line_numbers = (
        random.sample(range(88, 223), repetitions)
        if random_line
        else ['42'] * repetitions
    )

    for extra_number in line_numbers:
        ipython_action = IPythonRunCellAction(code='print("hello')
        event_stream.add_event(ipython_action, EventSource.AGENT)
        extra_line = '\n' * random.randint(2, 3) if random_line else ''
        ipython_observation = IPythonRunCellObservation(
            content=f' Cell In[1], line {extra_number}\n'
            'to_replace="""def largest(min_factor, max_factor):\n ^\n'
            f'{error_message}{extra_line}' + jupyter_line_1 + jupyter_line_2,
            code='print("hello',
        )
        ipython_observation._cause = ipython_action._id
        event_stream.add_event(ipython_observation, EventSource.USER)

def _impl_unterminated_string_error_events(
    self, event_stream: EventStream, random_line: bool
):
    """Add four identical IPython actions, each followed by an
    "unterminated string literal" error observation.

    Args:
        event_stream: Stream that receives the action/observation pairs.
        random_line: When True, each observation reports a different line
            number; when False, all four report line 1.
    """
    repetitions = 4
    # Sample without replacement: independent draws from 1..10 would give
    # the same line number in three observations about 1% of the time,
    # making the "not stuck" test flaky.
    line_numbers = (
        [str(number) for number in random.sample(range(1, 11), repetitions)]
        if random_line
        else ['1'] * repetitions
    )

    for line_number in line_numbers:
        ipython_action = IPythonRunCellAction(code='print("hello')
        event_stream.add_event(ipython_action, EventSource.AGENT)
        ipython_observation = IPythonRunCellObservation(
            content=f'print("hello\n ^\nSyntaxError: unterminated string literal (detected at line {line_number})'
            + jupyter_line_1
            + jupyter_line_2,
            code='print("hello',
        )
        ipython_observation._cause = ipython_action._id
        event_stream.add_event(ipython_observation, EventSource.USER)

def test_history_too_short(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
Expand Down Expand Up @@ -202,81 +239,75 @@ def test_is_stuck_repeating_action_error(
'Action, ErrorObservation loop detected'
)

def test_is_stuck_ipython_syntax_error(
def test_is_stuck_invalid_syntax_error(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
ipython_action_1 = IPythonRunCellAction(code='print("hello')
event_stream.add_event(ipython_action_1, EventSource.AGENT)
ipython_observation_1 = IPythonRunCellObservation(
content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)',
code='print("hello',
self._impl_syntax_error_events(
event_stream,
error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
random_line=False,
)
ipython_observation_1._cause = ipython_action_1._id
event_stream.add_event(ipython_observation_1, EventSource.USER)

ipython_action_2 = IPythonRunCellAction(code='print("hello')
event_stream.add_event(ipython_action_2, EventSource.AGENT)
ipython_observation_2 = IPythonRunCellObservation(
content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)',
code='print("hello',
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is True

def test_is_not_stuck_invalid_syntax_error(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
self._impl_syntax_error_events(
event_stream,
error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
random_line=True,
)
ipython_observation_2._cause = ipython_action_2._id
event_stream.add_event(ipython_observation_2, EventSource.USER)

ipython_action_3 = IPythonRunCellAction(code='print("hello')
event_stream.add_event(ipython_action_3, EventSource.AGENT)
ipython_observation_3 = IPythonRunCellObservation(
content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 3)',
code='print("hello',
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False

def test_is_stuck_incomplete_input_error(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
self._impl_syntax_error_events(
event_stream,
error_message='SyntaxError: incomplete input',
random_line=False,
)
ipython_observation_3._cause = ipython_action_3._id
event_stream.add_event(ipython_observation_3, EventSource.USER)

ipython_action_4 = IPythonRunCellAction(code='print("hello')
event_stream.add_event(ipython_action_4, EventSource.AGENT)
ipython_observation_4 = IPythonRunCellObservation(
content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 2)',
code='print("hello',
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is True

def test_is_not_stuck_incomplete_input_error(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
self._impl_syntax_error_events(
event_stream,
error_message='SyntaxError: incomplete input',
random_line=True,
)
ipython_observation_4._cause = ipython_action_4._id
event_stream.add_event(ipython_observation_4, EventSource.USER)

# stuck_detector.state.history.set_event_stream(event_stream)
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False

last_observations = [
ipython_observation_1,
ipython_observation_2,
ipython_observation_3,
ipython_observation_4,
]
for observation in last_observations:
has_string = (
observation.content[-100:].find(
'SyntaxError: unterminated string literal (detected at line'
)
!= -1
)
assert has_string

string_is_last = (
len(
observation.content.split(
'SyntaxError: unterminated string literal (detected at line'
)[-1]
)
< 10
)
assert string_is_last
def test_is_not_stuck_ipython_unterminated_string_error(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
self._impl_unterminated_string_error_events(event_stream, random_line=True)

with patch('logging.Logger.warning') as mock_warning:
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False

def test_is_stuck_ipython_unterminated_string_error(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
self._impl_unterminated_string_error_events(event_stream, random_line=False)

with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is True
mock_warning.assert_called_once_with(
'Action, IPythonRunCellObservation loop detected'
)

def test_is_stuck_ipython_syntax_error_not_at_end(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
# this test is to make sure we don't get false positives
# since the "at line x" is changing in between!
ipython_action_1 = IPythonRunCellAction(code='print("hello')
event_stream.add_event(ipython_action_1, EventSource.AGENT)
ipython_observation_1 = IPythonRunCellObservation(
Expand Down Expand Up @@ -480,8 +511,6 @@ def test_is_stuck_monologue(self, stuck_detector, event_stream):
event_stream.add_event(message_action_6, EventSource.AGENT)
message_action_6._source = EventSource.AGENT

# stuck_detector.state.history.set_event_stream(event_stream)

assert stuck_detector.is_stuck()

# Add an observation event between the repeated message actions
Expand All @@ -502,7 +531,8 @@ def test_is_stuck_monologue(self, stuck_detector, event_stream):
event_stream.add_event(message_action_8, EventSource.AGENT)
message_action_8._source = EventSource.AGENT

assert not stuck_detector.is_stuck()
with patch('logging.Logger.warning'):
assert not stuck_detector.is_stuck()


class TestAgentController:
Expand Down

0 comments on commit a2d94c9

Please sign in to comment.