Automatic Logger Level Update (pytorch#1206)
* automatic logger level update

* add config log level unit test

* fix lint

* fix import errors

* revert init_logging move

* fix import error

* refactor out logging

* move logging back to util
williamwen42 committed Sep 15, 2022
1 parent 0fde95c commit bf82d1b
Showing 6 changed files with 54 additions and 20 deletions.
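In short: after this commit, assigning to torchdynamo.config.log_level immediately re-levels the torchdynamo and torchinductor loggers, with no restart or re-initialization needed. A minimal usage sketch of the new behavior, modeled on the test added below (the function name is illustrative; torch and torchdynamo are assumed to be importable):

import logging

import torch
import torchdynamo

@torchdynamo.optimize("eager")
def add(a, b):
    return a + b

torchdynamo.config.log_level = logging.DEBUG    # loggers switch to DEBUG automatically
add(torch.randn(10), torch.randn(10))           # compilation now emits debug output

torchdynamo.config.log_level = logging.WARNING  # loggers drop back to WARNING
add(torch.randn(10), torch.randn(10))           # quiet again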
2 changes: 1 addition & 1 deletion pytest.ini
@@ -1,5 +1,5 @@
[pytest]
testpaths =
    test
log_cli = True
log_cli = False
log_cli_level = INFO
16 changes: 16 additions & 0 deletions test/test_misc.py
@@ -4,6 +4,7 @@
import dataclasses
import dis
import enum
import logging
import math
import os
import sys
@@ -2330,6 +2331,21 @@ def f3(x):
        f3(torch.ones(6))
        self.assertEqual(cnt.frame_count, 0)

    def test_config_log_level(self):
        @torchdynamo.optimize("eager")
        def fn(a, b):
            return a + b

        with self.assertLogs(logger="torchdynamo", level=logging.DEBUG) as log:
            torchdynamo.config.log_level = logging.DEBUG
            fn(torch.randn(10), torch.randn(10))
            cur_len = len(log.records)
            self.assertGreater(cur_len, 0)

            torchdynamo.config.log_level = logging.WARNING
            fn(torch.randn(10), torch.randn(10))
            self.assertEqual(cur_len, len(log.records))


class TestTracer(JitTestCase):
    def test_jit_save(self):
6 changes: 3 additions & 3 deletions test/test_recompile_ux.py
@@ -85,7 +85,7 @@ def model(input):
        with unittest.mock.patch.object(
            torchdynamo.config, "cache_size_limit", expected_recompiles
        ):
            with self.assertLogs(level="WARNING") as logs:
            with self.assertLogs(logger="torchdynamo", level="WARNING") as logs:
                for _ in range(10):
                    bsz = torch.randint(low=0, high=1000, size=())
                    x = torch.randn((bsz, 3, 4))
@@ -152,7 +152,7 @@ def cache_fail_test(cached_input, missed_input, expected_failure):
            # warmup
            opt_func(cached_input)

            with self.assertLogs(level="WARNING") as logs:
            with self.assertLogs(logger="torchdynamo", level="WARNING") as logs:
                opt_func = torchdynamo.optimize("eager")(func)
                opt_func(missed_input)
            self.assert_single_log_contains(logs, expected_failure)
@@ -190,7 +190,7 @@ def func(a, b):
        # warmup
        opt_func(a, b)

        with self.assertLogs(level="WARNING") as logs:
        with self.assertLogs(logger="torchdynamo", level="WARNING") as logs:
            opt_func = torchdynamo.optimize("eager")(func)
            opt_func(a, 1)
        self.assert_single_log_contains(
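The three changes above pass an explicit logger name to assertLogs, targeting the "torchdynamo" logger that actually emits the warnings being checked. A self-contained sketch of that capture pattern (stdlib only; the test class and message here are illustrative, not part of this commit):

import logging
import unittest

class AssertLogsExample(unittest.TestCase):
    def test_capture_named_logger(self):
        # capture only records emitted through the "torchdynamo" logger hierarchy
        with self.assertLogs(logger="torchdynamo", level="WARNING") as logs:
            logging.getLogger("torchdynamo").warning("cache size limit reached")
        self.assertEqual(len(logs.records), 1)

if __name__ == "__main__":
    unittest.main()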
11 changes: 9 additions & 2 deletions torchdynamo/config.py
@@ -7,6 +7,8 @@

import torch

import torchdynamo.utils

try:
    import torch._prims
    import torch._refs
@@ -21,13 +23,15 @@
# INFO print compiled functions + graphs
# WARN print warnings (including graph breaks)
# ERROR print exceptions (and what user code was being processed when it occurred)
# NOTE: changing log_level will automatically update the levels of all torchdynamo loggers
log_level = logging.WARNING
# Verbose will print full stack traces on warnings and errors
verbose = False

# the name of a file to write the logs to
log_file_name = None

# Verbose will print full stack traces on warnings and errors
verbose = False

# verify the correctness of optimized backend
verify_correctness = False

@@ -130,6 +134,9 @@ class _AccessLimitingConfig(ModuleType):
    def __setattr__(self, name, value):
        if name not in _allowed_config_names:
            raise AttributeError(f"{__name__}.{name} does not exist")
        # automatically set logger level whenever config.log_level is modified
        if name == "log_level":
            torchdynamo.utils.set_loggers_level(value)
        return object.__setattr__(self, name, value)


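The hook takes effect because _AccessLimitingConfig subclasses ModuleType and is installed as the config module's class (the installation line is outside this hunk), so every `torchdynamo.config.<name> = value` assignment passes through the __setattr__ above. A minimal standalone sketch of that pattern with hypothetical names ("mypackage" and the module name are placeholders, not part of this commit):

# config_sketch.py -- hypothetical, minimal version of the module-__setattr__ pattern
import logging
import sys
from types import ModuleType

log_level = logging.WARNING  # module-level "config" value

class _InterceptingConfig(ModuleType):
    def __setattr__(self, name, value):
        # keep the package logger in sync whenever log_level is reassigned
        if name == "log_level":
            logging.getLogger("mypackage").setLevel(value)
        return object.__setattr__(self, name, value)

# after this, external writes such as `import config_sketch; config_sketch.log_level = logging.DEBUG`
# go through __setattr__ above
sys.modules[__name__].__class__ = _InterceptingConfig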
2 changes: 0 additions & 2 deletions torchdynamo/testing.py
@@ -1,7 +1,6 @@
import contextlib
import dis
import functools
import logging
import os.path
import types
import unittest
@@ -194,7 +193,6 @@ def tearDownClass(cls):
    @classmethod
    def setUpClass(cls):
        cls._exit_stack = contextlib.ExitStack()
        cls._exit_stack.enter_context(patch.object(config, "log_level", logging.DEBUG))
        cls._exit_stack.enter_context(
            patch.object(config, "raise_on_backend_error", True)
        )
37 changes: 25 additions & 12 deletions torchdynamo/utils.py
@@ -8,6 +8,7 @@
import gc
import inspect
import itertools
import logging
import logging.config
import math
import operator
@@ -29,9 +30,8 @@
from torch import fx
from torch.nn.modules.lazy import LazyModuleMixin

import torchdynamo.config

from . import config
import torchdynamo
import torchdynamo.config as config

counters = collections.defaultdict(collections.Counter)
troubleshooting_url = (
@@ -117,6 +117,20 @@ def fmt_fn(values, item_fn=lambda x: x):
        return headers, values


# Return all loggers that torchdynamo is responsible for
def get_loggers():
    return [
        logging.getLogger("torchdynamo"),
        logging.getLogger("torchinductor"),
    ]


# Set the level of all loggers that torchdynamo is responsible for
def set_loggers_level(level):
    for logger in get_loggers():
        logger.setLevel(level)


LOGGING_CONFIG = {
"version": 1,
"formatters": {
@@ -146,18 +160,17 @@ def fmt_fn(values, item_fn=lambda x: x):
}


# initialize torchdynamo loggers
def init_logging():
if "PYTEST_CURRENT_TEST" not in os.environ:
logging.config.dictConfig(LOGGING_CONFIG)
td_logger = logging.getLogger("torchdynamo")
td_logger.setLevel(config.log_level)
ti_logger = logging.getLogger("torchinductor")
ti_logger.setLevel(config.log_level)
if config.log_file_name is not None:
log_file = logging.FileHandler(config.log_file_name)
log_file.setLevel(config.log_level)
td_logger.addHandler(log_file)
ti_logger.addHandler(log_file)
for logger in get_loggers():
logger.addHandler(log_file)

set_loggers_level(config.log_level)


# filter out all frames after entering dynamo
@@ -793,14 +806,14 @@ def format_func_info(code):

@contextlib.contextmanager
def disable_cache_limit():
    prior = torchdynamo.config.cache_size_limit
    torchdynamo.config.cache_size_limit = sys.maxsize
    prior = config.cache_size_limit
    config.cache_size_limit = sys.maxsize

    try:
        yield
    finally:
        pass
        torchdynamo.config.cache_size_limit = prior
        config.cache_size_limit = prior


# map from transformed code back to original user code
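The new get_loggers()/set_loggers_level() helpers are also usable directly from torchdynamo.utils. A hedged sketch of attaching a custom handler to every logger torchdynamo owns (the handler and format choices here are illustrative, not part of this commit):

import logging

import torchdynamo.utils

handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(name)s %(levelname)s: %(message)s"))

# get_loggers() returns the "torchdynamo" and "torchinductor" loggers
for logger in torchdynamo.utils.get_loggers():
    logger.addHandler(handler)

# the same call that config.log_level assignments now make under the hood
torchdynamo.utils.set_loggers_level(logging.INFO)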