Skip to content

Commit

Permalink
first pass as better pre commit config
Browse files Browse the repository at this point in the history
  • Loading branch information
emmyoop committed Jan 10, 2024
1 parent 303481c commit cf96a21
Show file tree
Hide file tree
Showing 24 changed files with 158 additions and 218 deletions.
43 changes: 43 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Configuration for pre-commit hooks (see https://pre-commit.com/).

exclude: ^(dbt/common/events/types_pb2.py)

# Force all unspecified python hooks to run python 3.8
default_language_version:
python: python3
Expand All @@ -15,3 +17,44 @@ repos:
exclude_types:
- "markdown"
- id: check-case-conflict
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
- id: black
alias: black-check
stages: [manual]
args:
- "--check"
- "--diff"
# - repo: https://github.com/pycqa/flake8
# rev: 4.0.1
# hooks:
# - id: flake8
# - id: flake8
# alias: flake8-check
# stages: [manual]
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.4.1
# hooks:
# - id: mypy
# # N.B.: Mypy is... a bit fragile.
# #
# # By using `language: system` we run this hook in the local
# # environment instead of a pre-commit isolated one. This is needed
# # to ensure mypy correctly parses the project.

# # It may cause trouble
# # in that it adds environmental variables out of our control to the
# # mix. Unfortunately, there's nothing we can do about per pre-commit's
# # author.
# # See https://github.com/pre-commit/pre-commit/issues/730 for details.
# args: [--show-error-codes]
# files: ^dbt/common/
# language: system
# - id: mypy
# alias: mypy-check
# stages: [manual]
# args: [--show-error-codes, --pretty]
# files: ^dbt/common
# language: system
19 changes: 5 additions & 14 deletions dbt_common/clients/_jinja_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,7 @@ def _expect_block_close(self):
are quote and `%}` - nothing else can hide the %} and be valid jinja.
"""
while True:
end_match = self._expect_match(
'tag close ("%}")', QUOTE_START_PATTERN, TAG_CLOSE_PATTERN
)
end_match = self._expect_match('tag close ("%}")', QUOTE_START_PATTERN, TAG_CLOSE_PATTERN)
self.advance(end_match.end())
if end_match.groupdict().get("tag_close") is not None:
return
Expand Down Expand Up @@ -234,15 +232,11 @@ def handle_tag(self, match):
else:
self.advance(match.end())
self._expect_block_close()
return Tag(
block_type_name=block_type_name, block_name=block_name, start=start_pos, end=self.pos
)
return Tag(block_type_name=block_type_name, block_name=block_name, start=start_pos, end=self.pos)

def find_tags(self):
while True:
match = self._first_match(
BLOCK_START_PATTERN, COMMENT_START_PATTERN, EXPR_START_PATTERN
)
match = self._first_match(BLOCK_START_PATTERN, COMMENT_START_PATTERN, EXPR_START_PATTERN)
if match is None:
break

Expand All @@ -262,8 +256,7 @@ def find_tags(self):
yield self.handle_tag(match)
else:
raise DbtInternalError(
"Invalid regex match in next_block, expected block start, "
"expr start, or comment start"
"Invalid regex match in next_block, expected block start, " "expr start, or comment start"
)

def __iter__(self):
Expand Down Expand Up @@ -355,6 +348,4 @@ def find_blocks(self, allowed_blocks=None, collect_raw_data=True):
yield BlockData(raw_data)

def lex_for_blocks(self, allowed_blocks=None, collect_raw_data=True):
return list(
self.find_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data)
)
return list(self.find_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data))
8 changes: 2 additions & 6 deletions dbt_common/clients/agate_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@ def build_type_tester(
agate.data_types.Date(null_values=("null", ""), date_format="%Y-%m-%d"),
agate.data_types.DateTime(null_values=("null", ""), datetime_format="%Y-%m-%d %H:%M:%S"),
ISODateTime(null_values=("null", "")),
agate.data_types.Boolean(
true_values=("true",), false_values=("false",), null_values=("null", "")
),
agate.data_types.Boolean(true_values=("true",), false_values=("false",), null_values=("null", "")),
agate.data_types.Text(null_values=string_null_values),
]
force = {k: agate.data_types.Text(null_values=string_null_values) for k in text_columns}
Expand Down Expand Up @@ -132,9 +130,7 @@ def table_from_data_flat(data, column_names: Iterable[str]) -> agate.Table:

rows.append(row)

return table_from_rows(
rows=rows, column_names=column_names, text_only_columns=text_only_columns
)
return table_from_rows(rows=rows, column_names=column_names, text_only_columns=text_only_columns)


def empty_table():
Expand Down
8 changes: 2 additions & 6 deletions dbt_common/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,7 @@ def __getitem__(self, name):

def __getattr__(self, name):
if name == "name" or _is_dunder_name(name):
raise AttributeError(
"'{}' object has no attribute '{}'".format(type(self).__name__, name)
)
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))

self.name = name

Expand Down Expand Up @@ -500,6 +498,4 @@ def extract_toplevel_blocks(
:return: A list of `BlockTag`s matching the allowed block types and (if
`collect_raw_data` is `True`) `BlockData` objects.
"""
return BlockIterator(data).lex_for_blocks(
allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data
)
return BlockIterator(data).lex_for_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data)
8 changes: 2 additions & 6 deletions dbt_common/contracts/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ def __setitem__(self, key, value):

def __delitem__(self, key):
if hasattr(self, key):
msg = (
'Error, tried to delete config key "{}": Cannot delete ' "built-in keys"
).format(key)
msg = ('Error, tried to delete config key "{}": Cannot delete ' "built-in keys").format(key)
raise CompilationError(msg)
else:
del self._extra[key]
Expand Down Expand Up @@ -143,9 +141,7 @@ def _merge_dicts(cls, src: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, An
)
return result

def update_from(
self: T, data: Dict[str, Any], config_cls: Type[BaseConfig], validate: bool = True
) -> T:
def update_from(self: T, data: Dict[str, Any], config_cls: Type[BaseConfig], validate: bool = True) -> T:
"""Given a dict of keys, update the current config from them, validate
it, and return a new config with the updated values
"""
Expand Down
8 changes: 2 additions & 6 deletions dbt_common/contracts/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,8 @@ class ColumnLevelConstraint(dbtClassMixin):
# It could be a predicate (check type), or a sequence sql keywords (e.g. unique type),
# so the vague naming of 'expression' is intended to capture this range.
expression: Optional[str] = None
warn_unenforced: bool = (
True # Warn if constraint cannot be enforced by platform but will be in DDL
)
warn_unsupported: bool = (
True # Warn if constraint is not supported by the platform and won't be in DDL
)
warn_unenforced: bool = True # Warn if constraint cannot be enforced by platform but will be in DDL
warn_unsupported: bool = True # Warn if constraint is not supported by the platform and won't be in DDL


@dataclass
Expand Down
8 changes: 2 additions & 6 deletions dbt_common/events/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ def __init__(self, *args, **kwargs) -> None:
if class_name == "Formatting" and len(args) > 0:
kwargs["msg"] = args[0]
args = ()
assert (
len(args) == 0
), f"[{class_name}] Don't use positional arguments when constructing logging events"
assert len(args) == 0, f"[{class_name}] Don't use positional arguments when constructing logging events"
if "base_msg" in kwargs:
kwargs["base_msg"] = str(kwargs["base_msg"])
if "msg" in kwargs:
Expand Down Expand Up @@ -94,9 +92,7 @@ def __getattr__(self, key):
return super().__getattribute__("pb_msg").__getattribute__(key)

def to_dict(self):
return MessageToDict(
self.pb_msg, preserving_proto_field_name=True, including_default_value_fields=True
)
return MessageToDict(self.pb_msg, preserving_proto_field_name=True, including_default_value_fields=True)

def to_json(self) -> str:
return MessageToJson(
Expand Down
4 changes: 1 addition & 3 deletions dbt_common/events/event_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ def fire_event(self, e: BaseEvent, level: Optional[EventLevel] = None) -> None:
callback(msg)

def add_logger(self, config: LoggerConfig) -> None:
logger = (
_JsonLogger(config) if config.line_format == LineFormat.Json else _TextLogger(config)
)
logger = _JsonLogger(config) if config.line_format == LineFormat.Json else _TextLogger(config)
self.loggers.append(logger)

def flush(self) -> None:
Expand Down
4 changes: 1 addition & 3 deletions dbt_common/events/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ def format_fancy_output_line(
else:
status_time = " in {execution_time:0.2f}s".format(execution_time=execution_time)

output = "{justified} [{status}{status_time}]".format(
justified=justified, status=status, status_time=status_time
)
output = "{justified} [{status}{status_time}]".format(justified=justified, status=status, status_time=status_time)

return output

Expand Down
22 changes: 5 additions & 17 deletions dbt_common/events/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,9 @@ def msg_to_dict(msg: EventMsg) -> dict:
)
except Exception as exc:
event_type = type(msg).__name__
fire_event(
Note(msg=f"type {event_type} is not serializable. {str(exc)}"), level=EventLevel.WARN
)
fire_event(Note(msg=f"type {event_type} is not serializable. {str(exc)}"), level=EventLevel.WARN)
# We don't want an empty NodeInfo in output
if (
"data" in msg_dict
and "node_info" in msg_dict["data"]
and msg_dict["data"]["node_info"]["node_name"] == ""
):
if "data" in msg_dict and "node_info" in msg_dict["data"] and msg_dict["data"]["node_info"]["node_name"] == "":
del msg_dict["data"]["node_info"]
return msg_dict

Expand All @@ -124,17 +118,13 @@ def warn_or_error(event, node=None) -> None:

# an alternative to fire_event which only creates and logs the event value
# if the condition is met. Does nothing otherwise.
def fire_event_if(
conditional: bool, lazy_e: Callable[[], BaseEvent], level: Optional[EventLevel] = None
) -> None:
def fire_event_if(conditional: bool, lazy_e: Callable[[], BaseEvent], level: Optional[EventLevel] = None) -> None:
if conditional:
fire_event(lazy_e(), level=level)


# a special case of fire_event_if, to only fire events in our unit/functional tests
def fire_event_if_test(
lazy_e: Callable[[], BaseEvent], level: Optional[EventLevel] = None
) -> None:
def fire_event_if_test(lazy_e: Callable[[], BaseEvent], level: Optional[EventLevel] = None) -> None:
fire_event_if(conditional=("pytest" in sys.modules), lazy_e=lazy_e, level=level)


Expand All @@ -150,9 +140,7 @@ def get_metadata_vars() -> Dict[str, str]:
global metadata_vars
if not metadata_vars:
metadata_vars = {
k[len(_METADATA_ENV_PREFIX) :]: v
for k, v in os.environ.items()
if k.startswith(_METADATA_ENV_PREFIX)
k[len(_METADATA_ENV_PREFIX) :]: v for k, v in os.environ.items() if k.startswith(_METADATA_ENV_PREFIX)
}
return metadata_vars

Expand Down
8 changes: 2 additions & 6 deletions dbt_common/events/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ def send_to_logger(l, level: str, log_line: str):
elif level == "error":
l.error(log_line)
else:
raise AssertionError(
f"While attempting to log {log_line}, encountered the unhandled level: {level}"
)
raise AssertionError(f"While attempting to log {log_line}, encountered the unhandled level: {level}")


@dataclass
Expand Down Expand Up @@ -152,9 +150,7 @@ def create_debug_line(self, msg: EventMsg) -> str:
log_line = f"\n\n{separator} {ts} | {self.invocation_id} {separator}\n"
scrubbed_msg: str = self.scrubber(msg.info.msg) # type: ignore
level = msg.info.level
log_line += (
f"{self._get_color_tag()}{ts} [{level:<5}]{self._get_thread_name()} {scrubbed_msg}"
)
log_line += f"{self._get_color_tag()}{ts} [{level:<5}]{self._get_thread_name()} {scrubbed_msg}"
return log_line

def _get_color_tag(self) -> str:
Expand Down
Loading

0 comments on commit cf96a21

Please sign in to comment.