Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(enh) StuckDetector: fix+enhance syntax error loop detection #3628

Merged
merged 4 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
more stringent error detection and more unit tests
  • Loading branch information
tobitege committed Aug 29, 2024
commit 9a1d8fd32bac9d3bf404e86d14a9c2f464cf98e2
128 changes: 92 additions & 36 deletions openhands/controller/stuck.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ class StuckDetector:
'SyntaxError: unterminated string literal (detected at line',
'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
'SyntaxError: incomplete input',
"E999 SyntaxError: unmatched ')'",
]

def __init__(self, state: State):
Expand Down Expand Up @@ -127,47 +126,104 @@ def _is_stuck_repeating_action_error(self, last_actions, last_observations):
# it takes 3 actions and 3 observations to detect a loop
# check if the last three actions are the same and result in errors

if len(last_actions) >= 3 and len(last_observations) >= 3:
# are the last three actions the same?
if all(
self._eq_no_pid(last_actions[0], action) for action in last_actions[:3]
if len(last_actions) < 4 or len(last_observations) < 4:
return False

# are the last three actions the "same"?
if all(self._eq_no_pid(last_actions[0], action) for action in last_actions[:3]):
# and the last three observations are all errors?
if all(isinstance(obs, ErrorObservation) for obs in last_observations[:3]):
logger.warning('Action, ErrorObservation loop detected')
return True
# or, are the last three observations all IPythonRunCellObservation with SyntaxError?
elif all(
isinstance(obs, IPythonRunCellObservation)
for obs in last_observations[:3]
):
# and the last three observations all errors?
if all(
isinstance(obs, ErrorObservation) for obs in last_observations[:3]
):
logger.warning('Action, ErrorObservation loop detected')
return True
# or, are the last three observations all IPythonRunCellObservation with SyntaxError?
elif all(
isinstance(obs, IPythonRunCellObservation)
for obs in last_observations[:3]
):
for error_message in self.SYNTAX_ERROR_MESSAGES:
if all(
self._check_for_error_message(obs, error_message)
for obs in last_observations[:3]
warning = 'Action, IPythonRunCellObservation loop detected'
for error_message in self.SYNTAX_ERROR_MESSAGES:
if error_message.startswith(
'SyntaxError: unterminated string literal (detected at line'
):
if self._check_for_consistent_line_error(
last_observations[:3], error_message
):
logger.warning(
'Action, IPythonRunCellObservation loop detected'
)
logger.warning(warning)
return True
elif error_message in [
'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
'SyntaxError: incomplete input',
]:
if self._check_for_consistent_invalid_syntax(
last_observations[:3], error_message
):
logger.warning(warning)
return True
return False

def _check_for_error_message(
self, obs: IPythonRunCellObservation, error_message: str
) -> bool:
content = obs.content
# 4 lines, because the last 2 lines can be about interpreter and prompts
# since the introduction of the EventStreamRuntime
lines = content.strip().split('\n')
last_four_lines = lines[-4:]

for line in last_four_lines:
if error_message in line:
return True
def _check_for_consistent_invalid_syntax(self, observations, error_message):
    """Return True when every observation reports the same syntax error.

    Each observation's ``content`` is stripped and split into lines. An
    observation counts as a match when its last two lines are the
    ``[Jupyter current working directory:`` / ``[Jupyter Python interpreter:``
    footer lines and ``error_message`` occurs in the line just above them.

    Args:
        observations: objects exposing a ``content`` string; the caller
            passes the three most recent observations.
        error_message: the syntax error text to look for.

    Returns:
        True only if all first lines are identical, exactly three
        observations matched, and the matched error lines are identical.
        False otherwise, including when any observation has fewer than
        four lines.
    """
    cwd_marker = '[Jupyter current working directory:'
    interpreter_marker = '[Jupyter Python interpreter:'
    first_lines = []
    error_lines = []

    for obs in observations:
        lines = obs.content.strip().split('\n')

        # At minimum: a first line, the error line and the two footer lines.
        if len(lines) < 4:
            return False

        first_lines.append(lines[0])  # Store the first line of each observation

        has_footer = lines[-2].startswith(cwd_marker) and lines[-1].startswith(
            interpreter_marker
        )
        # Keep scanning after a match: stopping at the first one would cap the
        # number of matches at one, and the "exactly 3" check below could
        # never succeed.
        if has_footer and error_message in lines[-3]:
            error_lines.append(lines[-3])

    # Check if:
    # 1. All first lines are identical
    # 2. We have exactly 3 matching observations
    # 3. The error message line is identical in all matching observations
    return (
        len(set(first_lines)) == 1
        and len(error_lines) == 3
        and len(set(error_lines)) == 1
    )

return False
def _check_for_consistent_line_error(self, observations, error_message):
    """Tell whether three observations end with the very same error line.

    For each observation the stripped ``content`` is split into lines and
    only the final three are inspected: the last two must be the
    ``[Jupyter current working directory:`` and
    ``[Jupyter Python interpreter:`` lines, and the one before them is the
    candidate error line.

    Returns False as soon as an observation has fewer than three lines or
    lacks those two trailing lines. Otherwise returns True only when
    ``error_message`` was found in the candidate line of three observations
    and those candidate lines are all identical.
    """
    cwd_prefix = '[Jupyter current working directory:'
    interpreter_prefix = '[Jupyter Python interpreter:'
    matches = []

    for observation in observations:
        tail = observation.content.strip().split('\n')[-3:]

        # Fewer than three lines in total: nothing to compare.
        if len(tail) < 3:
            return False

        candidate, cwd_line, interpreter_line = tail

        # The two trailing lines must be the Jupyter footer lines.
        if not cwd_line.startswith(cwd_prefix) or not interpreter_line.startswith(
            interpreter_prefix
        ):
            return False

        if error_message in candidate:
            matches.append(candidate)

    # All three observations must carry the error, on an identical line.
    return len(matches) == 3 and len(set(matches)) == 1

def _is_stuck_monologue(self, filtered_history):
# scenario 3: monologue
Expand Down
160 changes: 95 additions & 65 deletions tests/unit/test_is_stuck.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import random
from unittest.mock import Mock, patch

import pytest
Expand Down Expand Up @@ -27,6 +28,9 @@ def collect_events(stream):

logging.basicConfig(level=logging.DEBUG)

jupyter_line_1 = '\n[Jupyter current working directory:'
jupyter_line_2 = '\n[Jupyter Python interpreter:'


@pytest.fixture
def temp_dir(tmp_path_factory: TempPathFactory) -> str:
Expand All @@ -46,11 +50,44 @@ class TestStuckDetector:
@pytest.fixture
def stuck_detector(self, event_stream):
    """Build a StuckDetector over a fresh State wired to ``event_stream``."""
    detector_state = State(inputs={}, max_iterations=50)
    detector_state.history.set_event_stream(event_stream)
    return StuckDetector(detector_state)

def _impl_syntax_error_events(
    self, event_stream: EventStream, error_message: str, random_line: bool
):
    """Add four identical IPython actions, each answered by a syntax-error observation.

    Args:
        event_stream: stream the action/observation pairs are added to.
        error_message: error text placed after the code excerpt in each
            observation's content.
        random_line: when False, every observation is identical (reported
            line 42, no extra blank lines). When True, each observation gets
            a different reported line number and two or three blank lines
            between the error message and the Jupyter footer lines.
    """
    for i in range(4):
        ipython_action = IPythonRunCellAction(code='print("hello')
        event_stream.add_event(ipython_action, EventSource.AGENT)
        # Vary deterministically rather than with unseeded random values, so
        # the "varying" fixture is reproducible and its four observations are
        # guaranteed to differ from each other.
        extra_number = 88 + i if random_line else 42
        extra_line = '\n' * (2 + i % 2) if random_line else ''
        ipython_observation = IPythonRunCellObservation(
            content=f' Cell In[1], line {extra_number}\n'
            'to_replace="""def largest(min_factor, max_factor):\n ^\n'
            f'{error_message}{extra_line}' + jupyter_line_1 + jupyter_line_2,
            code='print("hello',
        )
        ipython_observation._cause = ipython_action._id
        event_stream.add_event(ipython_observation, EventSource.USER)

def _impl_unterminated_string_error_events(
    self, event_stream: EventStream, random_line: bool
):
    """Add four identical IPython actions, each answered by an unterminated-string error.

    Args:
        event_stream: stream the action/observation pairs are added to.
        random_line: when False, every observation reports "detected at
            line 1". When True, the four observations report lines 1, 2, 3
            and 4, so no two of them share an error line.
    """
    for i in range(4):
        ipython_action = IPythonRunCellAction(code='print("hello')
        event_stream.add_event(ipython_action, EventSource.AGENT)
        # Distinct, deterministic line numbers: drawing them at random from a
        # small range can produce the same number several times in a row,
        # which makes a "not stuck" expectation fail intermittently.
        line_number = str(i + 1) if random_line else '1'
        ipython_observation = IPythonRunCellObservation(
            content=f'print("hello\n ^\nSyntaxError: unterminated string literal (detected at line {line_number})'
            + jupyter_line_1
            + jupyter_line_2,
            code='print("hello',
        )
        ipython_observation._cause = ipython_action._id
        event_stream.add_event(ipython_observation, EventSource.USER)

def test_history_too_short(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
Expand Down Expand Up @@ -202,79 +239,75 @@ def test_is_stuck_repeating_action_error(
'Action, ErrorObservation loop detected'
)

def test_is_stuck_ipython_syntax_error(
def test_is_stuck_invalid_syntax_error(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
ipython_action_1 = IPythonRunCellAction(code='print("hello')
event_stream.add_event(ipython_action_1, EventSource.AGENT)
ipython_observation_1 = IPythonRunCellObservation(
content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)',
code='print("hello',
self._impl_syntax_error_events(
event_stream,
error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
random_line=False,
)
ipython_observation_1._cause = ipython_action_1._id
event_stream.add_event(ipython_observation_1, EventSource.USER)

ipython_action_2 = IPythonRunCellAction(code='print("hello')
event_stream.add_event(ipython_action_2, EventSource.AGENT)
ipython_observation_2 = IPythonRunCellObservation(
content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)',
code='print("hello',
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is True

def test_is_not_stuck_invalid_syntax_error(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
self._impl_syntax_error_events(
event_stream,
error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
random_line=True,
)
ipython_observation_2._cause = ipython_action_2._id
event_stream.add_event(ipython_observation_2, EventSource.USER)

ipython_action_3 = IPythonRunCellAction(code='print("hello')
event_stream.add_event(ipython_action_3, EventSource.AGENT)
ipython_observation_3 = IPythonRunCellObservation(
content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 3)',
code='print("hello',
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False

def test_is_stuck_incomplete_input_error(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
self._impl_syntax_error_events(
event_stream,
error_message='SyntaxError: incomplete input',
random_line=False,
)
ipython_observation_3._cause = ipython_action_3._id
event_stream.add_event(ipython_observation_3, EventSource.USER)

ipython_action_4 = IPythonRunCellAction(code='print("hello')
event_stream.add_event(ipython_action_4, EventSource.AGENT)
ipython_observation_4 = IPythonRunCellObservation(
content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 2)',
code='print("hello',
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is True

def test_is_not_stuck_incomplete_input_error(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
self._impl_syntax_error_events(
event_stream,
error_message='SyntaxError: incomplete input',
random_line=True,
)
ipython_observation_4._cause = ipython_action_4._id
event_stream.add_event(ipython_observation_4, EventSource.USER)

last_observations = [
ipython_observation_1,
ipython_observation_2,
ipython_observation_3,
ipython_observation_4,
]
for observation in last_observations:
has_string = (
observation.content[-100:].find(
'SyntaxError: unterminated string literal (detected at line'
)
!= -1
)
assert has_string

string_is_last = (
len(
observation.content.split(
'SyntaxError: unterminated string literal (detected at line'
)[-1]
)
< 10
)
assert string_is_last
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False

with patch('logging.Logger.warning') as mock_warning:
def test_is_not_stuck_ipython_unterminated_string_error(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
self._impl_unterminated_string_error_events(event_stream, random_line=True)

with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False

def test_is_stuck_ipython_unterminated_string_error(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
self._impl_unterminated_string_error_events(event_stream, random_line=False)

with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is True
mock_warning.assert_called_once_with(
'Action, IPythonRunCellObservation loop detected'
)

def test_is_stuck_ipython_syntax_error_not_at_end(
self, stuck_detector: StuckDetector, event_stream: EventStream
):
# this test is to make sure we don't get false positives
# since the "at line x" is changing in between!
Copy link
Collaborator

@enyst enyst Aug 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔
FWIW, I wonder if requiring the error message to be identical isn't too restrictive now, hah (I'll see myself out lol). Please correct me if I'm not reading this right (looking at the diff): around line 190 in stuck.py, it looks like the error, too, has to be entirely the same. It's an interesting question, because I recall the stuck LLM tries to make minor variations when it sees a syntax error; for example, it might try to run stuff starting with quotes or backquotes. It's still a syntax error, but it can be on another word or another line.

(FWIW I tend to think that this test was succeeding because the "syntax error" string had to be near the end of the line or something like that, not requiring an identical line number. I could be wrong though.)

Maybe that's all okay, since we're reducing from 4 to 3. 🤔 Three-but-identical might work well enough. The LLM did seem to have a favorite thing to try lol, while variations of the code to run (e.g. when it tried running with quotes) were rarer than the main thing it wanted. If this works in the scenario you found, we can go with it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error doesn't have to be exactly the same, but the first line in those errors is like this:
Cell In[1], line 42
So, similar to the other syntax error, where a change of the line number "breaks" the loop, I added a requirement that the first line be the same in all 3 observations. The first line does not contain the actual error, only its location.
Then it is up to the bottom 3 lines to match, basically (error type and 2 lines of our prompt).

I'll try to re-run a couple of instances from the Aider bench, which is where I originally collected those errors. :)

ipython_action_1 = IPythonRunCellAction(code='print("hello')
event_stream.add_event(ipython_action_1, EventSource.AGENT)
ipython_observation_1 = IPythonRunCellObservation(
Expand Down Expand Up @@ -312,10 +345,8 @@ def test_is_stuck_ipython_syntax_error_not_at_end(
event_stream.add_event(ipython_observation_4, EventSource.USER)

with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is True
mock_warning.assert_called_once_with(
'Action, IPythonRunCellObservation loop detected'
)
assert stuck_detector.is_stuck() is False
mock_warning.assert_not_called()

def test_is_stuck_repeating_action_observation_pattern(
self, stuck_detector: StuckDetector, event_stream: EventStream
Expand Down Expand Up @@ -480,8 +511,6 @@ def test_is_stuck_monologue(self, stuck_detector, event_stream):
event_stream.add_event(message_action_6, EventSource.AGENT)
message_action_6._source = EventSource.AGENT

# stuck_detector.state.history.set_event_stream(event_stream)

assert stuck_detector.is_stuck()

# Add an observation event between the repeated message actions
Expand All @@ -502,7 +531,8 @@ def test_is_stuck_monologue(self, stuck_detector, event_stream):
event_stream.add_event(message_action_8, EventSource.AGENT)
message_action_8._source = EventSource.AGENT

assert not stuck_detector.is_stuck()
with patch('logging.Logger.warning'):
assert not stuck_detector.is_stuck()


class TestAgentController:
Expand Down