Skip to content

Commit

Permalink
Add more proper typing to StreamLogWriter (apache#37347)
Browse files Browse the repository at this point in the history
  • Loading branch information
Grub4K authored Nov 19, 2024
1 parent 8c91fce commit 931a6b9
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 11 deletions.
40 changes: 32 additions & 8 deletions airflow/utils/log/logging_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import enum
import logging
import sys
from io import IOBase
from io import TextIOBase, UnsupportedOperation
from logging import Handler, StreamHandler
from typing import IO, TYPE_CHECKING, Any, Optional, TypeVar, cast

Expand Down Expand Up @@ -150,18 +150,39 @@ def supports_external_link(self) -> bool:

# We have to ignore typing errors here because Python I/O classes are a mess, and they do not
# have the same type hierarchy defined as the `typing.IO` - they violate Liskov Substitution Principle
# While it is ok to make your class derive from IOBase (and its good thing to do as they provide
# While it is ok to make your class derive from TextIOBase (and its good thing to do as they provide
# base implementation for IO-implementing classes, it's impossible to make them work with
# IO generics (and apparently it has not even been intended)
# See more: https://giters.com/python/typeshed/issues/6077
class StreamLogWriter(IOBase, IO[str]): # type: ignore[misc]
class StreamLogWriter(TextIOBase, IO[str]): # type: ignore[misc]
"""
Allows to redirect stdout and stderr to logger.
:param log: The log level method to write to, ie. log.debug, log.warning
:param logger: The logging.Logger instance to write to
:param level: The log level method to write to, ie. logging.DEBUG, logging.WARNING
"""

encoding: None = None
encoding = "undefined"

@property
def mode(self):
return "w"

@property
def name(self):
return f"<logger: {self.logger.name}>"

def writable(self):
return True

def readable(self):
return False

def seekable(self):
return False

def fileno(self):
raise UnsupportedOperation("fileno")

def __init__(self, logger, level):
self.logger = logger
Expand Down Expand Up @@ -196,11 +217,14 @@ def write(self, message):
:param message: message to log
"""
if not message.endswith("\n"):
if message.endswith("\n"):
message = message.rstrip()
self._buffer += message
else:
self._buffer += message.rstrip()
self.flush()
else:
self._buffer += message

return len(message)

def flush(self):
"""Ensure all logging output has been flushed."""
Expand Down
15 changes: 12 additions & 3 deletions tests/utils/test_logging_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,13 @@ def test_write(self):
log = StreamLogWriter(logger, 1)

msg = "test_message"
log.write(msg)
length = log.write(msg)

assert length == len(msg)
assert log._buffer == msg

log.write(" \n")
length = log.write(" \n")
assert length == 0
logger.log.assert_called_once_with(1, msg)

assert log._buffer == ""
Expand Down Expand Up @@ -168,7 +170,7 @@ def test_encoding(self):
logger.log = mock.MagicMock()

log = StreamLogWriter(logger, 1)
assert log.encoding is None
assert log.encoding == "undefined"

def test_iobase_compatibility(self):
log = StreamLogWriter(None, 1)
Expand All @@ -177,6 +179,13 @@ def test_iobase_compatibility(self):
# has no specific effect
log.close()

def test_textio_compatibility(self):
log = StreamLogWriter(logging.getLogger(__name__), 1)

# Make sure some common operations won't fail
assert "b" not in log.mode
assert isinstance(log.name, str)


@pytest.mark.parametrize(["maintain_propagate"], [[SetContextPropagate.MAINTAIN_PROPAGATE], [None]])
def test_set_context_propagation(parent_child_handlers, child_logger, maintain_propagate):
Expand Down

0 comments on commit 931a6b9

Please sign in to comment.