Skip to content

Commit

Permalink
Use a configured logger instead of print statement
Browse files Browse the repository at this point in the history
Add find-links section to setup.cfg for easier install
Change LOGLEVEL env variable to LOG_LEVEL (in concordance with ark)
  • Loading branch information
jsilter committed May 22, 2024
1 parent 6618b4f commit 4dabda6
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 19 deletions.
22 changes: 7 additions & 15 deletions scripts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import pickle

import sybil.utils.logging_utils
from sybil import Serie, Sybil, visualize_attentions

script_directory = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -52,29 +53,20 @@ def _get_parser():
help="Name of the model to use for prediction. Default: sybil_ensemble",
)

parser.add_argument("-l", "--log", "--loglevel", default="INFO", dest="loglevel")
parser.add_argument("-l", "--log", "--loglevel", "--log-level",
default="INFO", dest="loglevel")

return parser


def logging_basic_config(args):
info_fmt = "[%(asctime)s] - %(message)s"
debug_fmt = "[%(asctime)s] [%(filename)s:%(lineno)d] %(levelname)s - %(message)s"
fmt = debug_fmt if args.loglevel.upper() == "DEBUG" else info_fmt

logging.basicConfig(
format=fmt, datefmt="%Y-%m-%d %H:%M:%S", level=args.loglevel.upper()
)


def inference(
image_dir,
output_dir,
model_name="sybil_ensemble",
return_attentions=False,
file_type="auto",
):
logger = logging.getLogger("inference")
logger = sybil.utils.logging_utils.get_logger()

input_files = os.listdir(image_dir)
input_files = [os.path.join(image_dir, x) for x in input_files if not x.startswith(".")]
Expand All @@ -94,11 +86,11 @@ def inference(

num_files = len(input_files)

logger.debug(f"Beginning prediction using {num_files} {file_type} files from {image_dir}")

# Load a trained model
model = Sybil(model_name)

logger.debug(f"Beginning prediction using {num_files} {file_type} files from {image_dir}")

# Get risk scores
serie = Serie(input_files, file_type=file_type)
series = [serie]
Expand Down Expand Up @@ -130,7 +122,7 @@ def inference(

def main():
args = _get_parser().parse_args()
logging_basic_config(args)
sybil.utils.logging_utils.configure_logger(args.loglevel)

os.makedirs(args.output_dir, exist_ok=True)

Expand Down
4 changes: 2 additions & 2 deletions scripts/run_inference_demo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ demo_scan_dir=sybil_demo_data
if [ ! -d "$demo_scan_dir" ]; then
# Download example data
curl -L -o sybil_example.zip "https://www.dropbox.com/scl/fi/covbvo6f547kak4em3cjd/sybil_example.zip?rlkey=7a13nhlc9uwga9x7pmtk1cf1c&dl=1"
tar -xf sybil_example.zip
unzip -q sybil_example.zip
fi

python3 scripts/inference.py \
--loglevel DEBUG \
--output-dir demo_prediction \
--return-attentions \
$demo_scan_dir
$demo_scan_dir
5 changes: 4 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ platforms = any
classifiers =
Programming Language :: Python

[easy_install]
find_links =
https://download.pytorch.org/whl/cu117/torch_stable.html

[options]
zip_safe = False
packages = find:
Expand All @@ -32,7 +36,6 @@ python_requires = >=3.8
# Version specifiers like >=2.2,<3.0 avoid problems due to API changes in
# new major versions. This works if the required packages follow Semantic Versioning.
# For more information, check out https://semver.org/.
# Use --find-links https://download.pytorch.org/whl/cu117/torch_stable.html for torch libraries
install_requires =
importlib-metadata; python_version>="3.8"
numpy==1.24.1
Expand Down
1 change: 1 addition & 0 deletions sybil/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
from sybil.model import Sybil
from sybil.serie import Serie
from sybil.utils.visualization import visualize_attentions
import sybil.utils.logging_utils

__all__ = ["Sybil", "Serie", "visualize_attentions"]
5 changes: 4 additions & 1 deletion sybil/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from sybil.serie import Serie
from sybil.models.sybil import SybilNet
from sybil.utils.logging_utils import get_logger
from sybil.utils.device_utils import get_default_device
from sybil.utils.metrics import get_survival_metrics


Expand Down Expand Up @@ -182,6 +184,7 @@ def __init__(
By default uses GPU, if available.
"""
self._logger = get_logger()
# Download if needed
if isinstance(name_or_path, str) and (name_or_path in NAME_TO_FILE):
name_or_path, calibrator_path = download_sybil(name_or_path, cache)
Expand Down Expand Up @@ -240,7 +243,7 @@ def load_model(self, path):

# Set eval
model.eval()
print(f"Loaded model from {path}")
self._logger.info(f"Loaded model from {path}")
return model

def _calibrate(self, scores: np.ndarray) -> np.ndarray:
Expand Down
68 changes: 68 additions & 0 deletions sybil/utils/logging_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import logging
import os


# Name of the package-level logger configured by configure_logger().
LOGGER_NAME = "sybil"
# Environment variable used to share the chosen log level with subprocesses.
LOGLEVEL_KEY = "LOG_LEVEL"


def _get_formatter(loglevel="INFO"):
warn_fmt = "[%(asctime)s] %(levelname)s -%(message)s"
debug_fmt = "[%(asctime)s] [%(filename)s:%(lineno)d] %(levelname)s - %(message)s"
fmt = debug_fmt if loglevel.upper() in {"DEBUG"} else warn_fmt
return logging.Formatter(
fmt=fmt,
datefmt="%Y-%b-%d %H:%M:%S %Z",
)


def remove_all_handlers(logger):
    """Detach every handler attached directly to *logger*.

    Checks ``logger.handlers`` rather than ``logger.hasHandlers()``: the
    latter also reports handlers on ancestor loggers, so the original loop
    raised IndexError on ``logger.handlers[0]`` whenever *logger* itself had
    no handlers but a parent (e.g. the root logger) did.
    """
    # Iterate over a copy since removeHandler mutates logger.handlers.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)


def configure_logger(loglevel=None, logger_name=LOGGER_NAME, logfile=None):
    """Configure and return the named logger.

    Parameters
    ----------
    loglevel : str, optional
        Level name such as "DEBUG". If None, it is read from the LOG_LEVEL
        environment variable (defaulting to "WARNING"); otherwise it is
        written back to that variable so subprocesses can retrieve it.
    logger_name : str
        Name of the logger to configure.
    logfile : str, optional
        If given, records are also appended to this file.

    Returns
    -------
    logging.Logger
        The configured logger, with a StreamHandler (and optional
        FileHandler) attached and propagation disabled.
    """

    # Set as environment variable so other processes can retrieve it
    if loglevel is None:
        loglevel = os.environ.get(LOGLEVEL_KEY, "WARNING")
    else:
        os.environ[LOGLEVEL_KEY] = loglevel

    logger = logging.getLogger(logger_name)
    logger.setLevel(loglevel)
    remove_all_handlers(logger)
    # We attach our own handlers; don't also forward records to ancestors.
    logger.propagate = False

    formatter = _get_formatter(loglevel)

    def _prep_handler(handler):
        # Replace any existing handler of the same concrete type so repeated
        # configuration doesn't double-handle. Exact type comparison (not
        # isinstance) is deliberate: FileHandler subclasses StreamHandler,
        # and adding a StreamHandler must not remove a FileHandler.
        # Fixed: iterate over a copy — removing from logger.handlers while
        # iterating it can skip entries.
        for ex_handler in list(logger.handlers):
            if type(ex_handler) is type(handler):
                # Remove old handler, don't want to double-handle
                logger.removeHandler(ex_handler)
        handler.setLevel(loglevel)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    sh = logging.StreamHandler()
    _prep_handler(sh)

    if logfile is not None:
        fh = logging.FileHandler(logfile, mode="a")
        _prep_handler(fh)

    return logger


def get_logger(base_name=LOGGER_NAME):
    """
    Return a per-process logger derived from *base_name*.

    Each subprocess gets its own logger (named with its PID) so handler
    configuration is never shared across process boundaries; all of them are
    configured at the same log level via configure_logger().
    """
    logger_name = f"{base_name}-process-{os.getpid()}"
    proc_logger = logging.getLogger(logger_name)
    # Configure lazily: only on first retrieval within this process.
    if not proc_logger.hasHandlers():
        configure_logger(logger_name=logger_name)
    return proc_logger

0 comments on commit 4dabda6

Please sign in to comment.