Skip to content

Commit

Permalink
chore: add typestubs and fix linting
Browse files Browse the repository at this point in the history
  • Loading branch information
timonviola committed Feb 6, 2025
1 parent 5dc950f commit 8784b5c
Show file tree
Hide file tree
Showing 358 changed files with 36,973 additions and 7 deletions.
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@ repos:
hooks:
- id: typos
stages: [commit]
exclude: "^typings/.*"
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.10.1 # Use the sha / tag you want to point at
hooks:
- id: mypy
language: system
pass_filenames: false
args: ['.']
- repo: https://github.com/compilerla/conventional-pre-commit
rev: v3.3.0
hooks:
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ venvPath = ".venv"
venv = "dev"

[tool.ruff]
exclude = [
"typings/*"
]
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
Expand Down Expand Up @@ -304,7 +307,8 @@ disable_error_code = [
"annotation-unchecked",
]
exclude = [
"examles/*",
"^examples/*",
"^typings/*"
]

[tool.versioningit]
Expand Down
8 changes: 5 additions & 3 deletions src/dagcellent/operators/mlflow/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import functools
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING, TypeVar
from warnings import warn

Expand All @@ -15,6 +14,8 @@

if TYPE_CHECKING:
# NOTE ruff fails for this check
from collections.abc import Callable

import mlflow.entities.model_registry # noqa: TCH004

from dagcellent.operators.mlflow._utils import MlflowModelStage
Expand All @@ -31,8 +32,9 @@ def _mlflow_request_wrapper(query: Callable[P, T]) -> T:
try:
res = query()
except mlflow.MlflowException as exc:
_LOGGER.error("Error during mlflow query.", exc_info=exc)
raise mlflow.MlflowException from exc
_msg = "Error during mlflow query."
_LOGGER.error(_msg, exc_info=exc)
raise mlflow.MlflowException(_msg) from exc
return res


Expand Down
2 changes: 1 addition & 1 deletion tests/dags/test_msql_reflect_otherdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
execute = SQLExecuteQueryOperator(
task_id="execute_query",
conn_id=CONN_ID,
sql=reflect_table.output,
sql=reflect_table.output, # type: ignore
database="model",
)

Expand Down
2 changes: 1 addition & 1 deletion tests/dags/test_msql_reflect_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
execute = SQLExecuteQueryOperator(
task_id="execute_query",
conn_id=CONN_ID,
sql=reflect_table.output,
sql=reflect_table.output, # type: ignore
database="model",
)

Expand Down
2 changes: 1 addition & 1 deletion tests/dags/test_mssql_reflect.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
create_pet_table = SQLExecuteQueryOperator(
task_id="create_pet_table",
conn_id=CONN_ID,
sql=reflect_table.output,
sql=reflect_table.output, # type: ignore
)

reflect_table >> create_pet_table
28 changes: 28 additions & 0 deletions typings/airflow/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
This type stub file was generated by pyright.
"""

import os
import sys
from typing import TYPE_CHECKING

from airflow import settings
from airflow.models.dag import DAG
from airflow.models.dataset import Dataset
from airflow.models.xcom_arg import XComArg

__version__ = ...
if os.environ.get("_AIRFLOW_PATCH_GEVENT"): ...
if sys.platform == "win32": ...
__all__ = ["__version__", "DAG", "Dataset", "XComArg"]
__path__ = ...
if not os.environ.get("_AIRFLOW__AS_LIBRARY", None): ...
__lazy_imports: dict[str, tuple[str, str, bool]] = ...
if TYPE_CHECKING: ...

def __getattr__(name: str): # -> bool | Any | ModuleType:
...

if not settings.LAZY_LOAD_PROVIDERS:
manager = ...
if not settings.LAZY_LOAD_PLUGINS: ...
3 changes: 3 additions & 0 deletions typings/airflow/api/common/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""
12 changes: 12 additions & 0 deletions typings/airflow/api/common/airflow_health.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""

from typing import Any

HEALTHY = ...
UNHEALTHY = ...

def get_airflow_health() -> dict[str, Any]:
"""Get the health for Airflow metadatabase, scheduler and triggerer."""
...
28 changes: 28 additions & 0 deletions typings/airflow/api/common/delete_dag.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
This type stub file was generated by pyright.
"""

from typing import TYPE_CHECKING

from airflow.utils.session import provide_session
from sqlalchemy.orm import Session

"""Delete DAGs APIs."""
if TYPE_CHECKING: ...
log = ...

@provide_session
def delete_dag(
dag_id: str, keep_records_in_log: bool = ..., session: Session = ...
) -> int:
"""
Delete a DAG by a dag_id.
:param dag_id: the dag_id of the DAG to delete
:param keep_records_in_log: whether keep records of the given dag_id
in the Log table in the backend database (for reasons like auditing).
The default value is True.
:param session: session used
:return count of deleted dags
"""
...
19 changes: 19 additions & 0 deletions typings/airflow/api/common/experimental/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
This type stub file was generated by pyright.
"""

from datetime import datetime
from typing import TYPE_CHECKING

from airflow.models import DagModel, DagRun

"""Experimental APIs."""
if TYPE_CHECKING: ...

def check_and_get_dag(dag_id: str, task_id: str | None = ...) -> DagModel:
"""Check DAG existence and in case it is specified that Task exists."""
...

def check_and_get_dagrun(dag: DagModel, execution_date: datetime) -> DagRun:
"""Get DagRun object and check that it exists."""
...
24 changes: 24 additions & 0 deletions typings/airflow/api/common/experimental/get_task_instance.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
This type stub file was generated by pyright.
"""

from datetime import datetime
from typing import TYPE_CHECKING

from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models import TaskInstance
from deprecated import deprecated

"""Task instance APIs."""
if TYPE_CHECKING: ...

@deprecated(
version="2.2.4",
reason="Use DagRun.get_task_instance instead",
category=RemovedInAirflow3Warning,
)
def get_task_instance(
dag_id: str, task_id: str, execution_date: datetime
) -> TaskInstance:
"""Return the task instance identified by the given dag_id, task_id and execution_date."""
...
Loading

0 comments on commit 8784b5c

Please sign in to comment.