Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix duplicate log #1661

Merged
merged 7 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion qlib/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,11 @@ def register_from_C(config, skip_register=True):
"filters": ["field_not_found"],
}
},
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}},
# Normally this should be set to `False` to avoid duplicated logging [1].
# However, due to bug in pytest, it requires log message to propagate to root logger to be captured by `caplog` [2].
# [1] https://github.com/microsoft/qlib/pull/1661
# [2] https://github.com/pytest-dev/pytest/issues/3697
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"], "propagate": False}},
# To let qlib work with other packages, we shouldn't disable existing loggers.
# Note that this param is default to True according to the documentation of logging.
"disable_existing_loggers": False,
Expand Down
31 changes: 18 additions & 13 deletions tests/rl/test_logger.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from random import randint, choice
from pathlib import Path
import logging

import re
from typing import Any, Tuple
Expand Down Expand Up @@ -69,6 +69,10 @@ def learn(self, batch):

def test_simple_env_logger(caplog):
set_log_with_config(C.logging_config)
# In order for caplog to capture log messages, we configure it here:
# allow logs from the qlib logger to be passed to the parent logger.
C.logging_config["loggers"]["qlib"]["propagate"] = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment

logging.config.dictConfig(C.logging_config)
for venv_cls_name in ["dummy", "shmem", "subproc"]:
writer = ConsoleWriter()
csv_writer = CsvWriter(Path(__file__).parent / ".output")
Expand All @@ -80,13 +84,12 @@ def test_simple_env_logger(caplog):
output_file = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
assert output_file.columns.tolist() == ["reward", "a", "c"]
assert len(output_file) >= 30

line_counter = 0
for line in caplog.text.splitlines():
line = line.strip()
if line:
line_counter += 1
assert re.match(r".*reward .* a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
assert re.match(r".*reward .* {2}a .* \(([456])\.\d+\) {2}c .* \((14|15|16)\.\d+\)", line)
assert line_counter >= 3


Expand Down Expand Up @@ -137,15 +140,17 @@ def learn(self, batch):

def test_logger_with_env_wrapper():
with DataQueue(list(range(20)), shuffle=False) as data_iterator:
env_wrapper_factory = lambda: EnvWrapper(
SimpleSimulator,
DummyStateInterpreter(),
DummyActionInterpreter(),
data_iterator,
logger=LogCollector(LogLevel.DEBUG),
)

# loglevel can be debug here because metrics can all dump into csv

def env_wrapper_factory():
return EnvWrapper(
SimpleSimulator,
DummyStateInterpreter(),
DummyActionInterpreter(),
data_iterator,
logger=LogCollector(LogLevel.DEBUG),
)

# loglevel can be debugged here because metrics can all dump into csv
# otherwise, csv writer might crash
csv_writer = CsvWriter(Path(__file__).parent / ".output", loglevel=LogLevel.DEBUG)
venv = vectorize_env(env_wrapper_factory, "shmem", 4, csv_writer)
Expand All @@ -155,7 +160,7 @@ def test_logger_with_env_wrapper():

output_df = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
assert len(output_df) == 20
# obs has a increasing trend
# obs has an increasing trend
assert output_df["obs"].to_numpy()[:10].sum() < output_df["obs"].to_numpy()[10:].sum()
assert (output_df["test_a"] == 233).all()
assert (output_df["test_b"] == 200).all()
Expand Down
Loading