Skip to content

Commit

Permalink
gh-118893: Evaluate all statements in the new REPL separately
Browse files Browse the repository at this point in the history
  • Loading branch information
pablogsal committed May 21, 2024
1 parent c4f9823 commit 40df02e
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 9 deletions.
33 changes: 29 additions & 4 deletions Lib/_pyrepl/simple_interact.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import linecache
import sys
import code
import ast
from types import ModuleType

from .readline import _get_reader, multiline_input
Expand Down Expand Up @@ -74,9 +75,36 @@ def __init__(
super().__init__(locals=locals, filename=filename, local_exit=local_exit) # type: ignore[call-arg]
self.can_colorize = _colorize.can_colorize()

def showsyntaxerror(self, filename=None):
super().showsyntaxerror(colorize=self.can_colorize)

def showtraceback(self):
super().showtraceback(colorize=self.can_colorize)

def runsource(self, source, filename="<input>", symbol="single"):
try:
tree = ast.parse(source)
except (OverflowError, SyntaxError, ValueError):
self.showsyntaxerror(filename)
return False
if tree.body:
*_, last_stmt = tree.body
for stmt in tree.body:
wrapper = ast.Interactive if stmt is last_stmt else ast.Module
the_symbol = symbol if stmt is last_stmt else "exec"
item = wrapper([stmt])
try:
code = compile(item, filename, the_symbol)
except (OverflowError, ValueError):
self.showsyntaxerror(filename)
return False

if code is None:
return True

self.runcode(code)
return False


def run_multiline_interactive_console(
mainmodule: ModuleType | None= None, future_flags: int = 0
Expand Down Expand Up @@ -144,10 +172,7 @@ def more_lines(unicodetext: str) -> bool:

input_name = f"<python-input-{input_n}>"
linecache._register_code(input_name, statement, "<stdin>") # type: ignore[attr-defined]
symbol = "single" if not contains_pasted_code else "exec"
more = console.push(_strip_final_indent(statement), filename=input_name, _symbol=symbol) # type: ignore[call-arg]
if contains_pasted_code and more:
more = console.push(_strip_final_indent(statement), filename=input_name, _symbol="single") # type: ignore[call-arg]
more = console.push(_strip_final_indent(statement), filename=input_name, _symbol="single") # type: ignore[call-arg]
assert not more
input_n += 1
except KeyboardInterrupt:
Expand Down
5 changes: 3 additions & 2 deletions Lib/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def runcode(self, code):
except:
self.showtraceback()

def showsyntaxerror(self, filename=None):
def showsyntaxerror(self, filename=None, **kwargs):
"""Display the syntax error that just occurred.
This doesn't display a stack trace because there isn't one.
Expand All @@ -106,6 +106,7 @@ def showsyntaxerror(self, filename=None):
The output is written by self.write(), below.
"""
colorize = kwargs.pop('colorize', False)
type, value, tb = sys.exc_info()
sys.last_exc = value
sys.last_type = type
Expand All @@ -123,7 +124,7 @@ def showsyntaxerror(self, filename=None):
value = SyntaxError(msg, (filename, lineno, offset, line))
sys.last_exc = sys.last_value = value
if sys.excepthook is sys.__excepthook__:
lines = traceback.format_exception_only(type, value)
lines = traceback.format_exception_only(type, value, colorize=colorize)
self.write(''.join(lines))
else:
# If someone has set sys.excepthook, we let that take precedence
Expand Down
87 changes: 86 additions & 1 deletion Lib/test/test_pyrepl.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import itertools
import os
import rlcompleter
import sys
import tempfile
import unittest
from code import InteractiveConsole
from functools import partial
from unittest import TestCase
from unittest.mock import MagicMock, patch
from textwrap import dedent
import contextlib
import io

from test.support import requires
from test.support.import_helper import import_module
Expand Down Expand Up @@ -1002,5 +1004,88 @@ def test_up_arrow_after_ctrl_r(self):
self.assert_screen_equals(reader, "")


class TestSimpleInteract(unittest.TestCase):
def test_multiple_statements(self):
namespace = {}
code = dedent("""\
class A:
def foo(self):
pass
class B:
def bar(self):
pass
a = 1
a
""")
console = InteractiveColoredConsole(namespace, filename="<stdin>")
with (
patch.object(InteractiveColoredConsole, "showsyntaxerror") as showsyntaxerror,
patch.object(InteractiveColoredConsole, "runsource", wraps=console.runsource) as runsource,
):
more = console.push(code, filename="<stdin>", _symbol="single") # type: ignore[call-arg]
self.assertFalse(more)
showsyntaxerror.assert_not_called()


def test_multiple_statements_output(self):
namespace = {}
code = dedent("""\
b = 1
b
a = 1
a
""")
console = InteractiveColoredConsole(namespace, filename="<stdin>")
f = io.StringIO()
with contextlib.redirect_stdout(f):
more = console.push(code, filename="<stdin>", _symbol="single") # type: ignore[call-arg]
self.assertFalse(more)
self.assertEqual(f.getvalue(), "1\n")

def test_empty(self):
namespace = {}
code = ""
console = InteractiveColoredConsole(namespace, filename="<stdin>")
f = io.StringIO()
with contextlib.redirect_stdout(f):
more = console.push(code, filename="<stdin>", _symbol="single") # type: ignore[call-arg]
self.assertFalse(more)
self.assertEqual(f.getvalue(), "")

def test_runsource_compiles_and_runs_code(self):
console = InteractiveColoredConsole()
source = "print('Hello, world!')"
with patch.object(console, "runcode") as mock_runcode:
console.runsource(source)
mock_runcode.assert_called_once()

def test_runsource_returns_false_for_successful_compilation(self):
console = InteractiveColoredConsole()
source = "print('Hello, world!')"
result = console.runsource(source)
self.assertFalse(result)

def test_runsource_returns_false_for_failed_compilation(self):
console = InteractiveColoredConsole()
source = "print('Hello, world!'"
result = console.runsource(source)
self.assertFalse(result)

def test_runsource_shows_syntax_error_for_failed_compilation(self):
console = InteractiveColoredConsole()
source = "print('Hello, world!'"
with patch.object(console, "showsyntaxerror") as mock_showsyntaxerror:
console.runsource(source)
mock_showsyntaxerror.assert_called_once()


if __name__ == '__main__':
unittest.main()


if __name__ == '__main__':
unittest.main()
5 changes: 3 additions & 2 deletions Lib/traceback.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def format_exception(exc, /, value=_sentinel, tb=_sentinel, limit=None, \
return list(te.format(chain=chain, colorize=colorize))


def format_exception_only(exc, /, value=_sentinel, *, show_group=False):
def format_exception_only(exc, /, value=_sentinel, *, show_group=False, **kwargs):
"""Format the exception part of a traceback.
The return value is a list of strings, each ending in a newline.
Expand All @@ -170,10 +170,11 @@ def format_exception_only(exc, /, value=_sentinel, *, show_group=False):
:exc:`BaseExceptionGroup`, the nested exceptions are included as
well, recursively, with indentation relative to their nesting depth.
"""
colorize = kwargs.get("colorize", False)
if value is _sentinel:
value = exc
te = TracebackException(type(value), value, None, compact=True)
return list(te.format_exception_only(show_group=show_group))
return list(te.format_exception_only(show_group=show_group, colorize=colorize))


# -- not official API but folk probably use these two functions.
Expand Down

0 comments on commit 40df02e

Please sign in to comment.