diff --git a/qlib/config.py b/qlib/config.py index 7910dab736..2fa7d4535a 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -173,7 +173,11 @@ def register_from_C(config, skip_register=True): "filters": ["field_not_found"], } }, - "loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}}, + # Normally this should be set to `False` to avoid duplicated logging [1]. + # However, due to bug in pytest, it requires log message to propagate to root logger to be captured by `caplog` [2]. + # [1] https://github.com/microsoft/qlib/pull/1661 + # [2] https://github.com/pytest-dev/pytest/issues/3697 + "loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"], "propagate": False}}, # To let qlib work with other packages, we shouldn't disable existing loggers. # Note that this param is default to True according to the documentation of logging. "disable_existing_loggers": False, diff --git a/tests/rl/test_logger.py b/tests/rl/test_logger.py index c8ceca92ad..e100e5046b 100644 --- a/tests/rl/test_logger.py +++ b/tests/rl/test_logger.py @@ -1,8 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - from random import randint, choice from pathlib import Path +import logging import re from typing import Any, Tuple @@ -69,6 +69,10 @@ def learn(self, batch): def test_simple_env_logger(caplog): set_log_with_config(C.logging_config) + # In order for caplog to capture log messages, we configure it here: + # allow logs from the qlib logger to be passed to the parent logger. + C.logging_config["loggers"]["qlib"]["propagate"] = True + logging.config.dictConfig(C.logging_config) for venv_cls_name in ["dummy", "shmem", "subproc"]: writer = ConsoleWriter() csv_writer = CsvWriter(Path(__file__).parent / ".output") @@ -80,13 +84,12 @@ def test_simple_env_logger(caplog): output_file = pd.read_csv(Path(__file__).parent / ".output" / "result.csv") assert output_file.columns.tolist() == ["reward", "a", "c"] assert len(output_file) >= 30 - line_counter = 0 for line in caplog.text.splitlines(): line = line.strip() if line: line_counter += 1 - assert re.match(r".*reward .* a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line) + assert re.match(r".*reward .* {2}a .* \(([456])\.\d+\) {2}c .* \((14|15|16)\.\d+\)", line) assert line_counter >= 3 @@ -137,15 +140,17 @@ def learn(self, batch): def test_logger_with_env_wrapper(): with DataQueue(list(range(20)), shuffle=False) as data_iterator: - env_wrapper_factory = lambda: EnvWrapper( - SimpleSimulator, - DummyStateInterpreter(), - DummyActionInterpreter(), - data_iterator, - logger=LogCollector(LogLevel.DEBUG), - ) - - # loglevel can be debug here because metrics can all dump into csv + + def env_wrapper_factory(): + return EnvWrapper( + SimpleSimulator, + DummyStateInterpreter(), + DummyActionInterpreter(), + data_iterator, + logger=LogCollector(LogLevel.DEBUG), + ) + + # loglevel can be debugged here because metrics can all dump into csv # otherwise, csv writer might crash csv_writer = CsvWriter(Path(__file__).parent / ".output", loglevel=LogLevel.DEBUG) venv = vectorize_env(env_wrapper_factory, "shmem", 4, csv_writer) @@ -155,7 +160,7 @@ def test_logger_with_env_wrapper(): output_df = pd.read_csv(Path(__file__).parent / ".output" / "result.csv") assert len(output_df) == 20 - # obs has a increasing trend + # obs has an increasing trend assert output_df["obs"].to_numpy()[:10].sum() < output_df["obs"].to_numpy()[10:].sum() assert (output_df["test_a"] == 233).all() assert (output_df["test_b"] == 200).all()