Some mypy fixes for file-related packages #197

Closed · wants to merge 1 commit
File: (name not shown)
@@ -34,7 +34,7 @@ def validate_authorized_yadocs(data: dict) -> None:
)


def validate_yadocs_data(data):
def validate_yadocs_data(data: dict) -> None:
if not ((data["public_link"] is None) ^ (data["private_path"] is None)):
raise ValueError("Expected exactly one of [`private_path`, `public_link`] to be specified")
if data["public_link"] is None and data["oauth_token"] is None and data["connection_id"] is None:
@@ -110,11 +110,11 @@ def get_obj_type(self, obj: dict[str, Any]) -> str:
assert isinstance(type_field, FileType)
return type_field.name

def get_data_type(self, data):
data_type = data.get(self.type_field)
def get_data_type(self, data: dict[str, Any]) -> str:
default_data_type = FileType.gsheets.value
data_type = data.get(self.type_field, default_data_type)
if self.type_field not in data:
data[self.type_field] = FileType.gsheets.value
data_type = FileType.gsheets.value
data[self.type_field] = default_data_type
if self.type_field in data and self.type_field_remove:
data.pop(self.type_field)
return data_type
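A note on the get_data_type rework: with the data typed as dict[str, Any] the change is mostly about not repeating the default, but it also follows the general dict.get typing rule that matters once value types are concrete: get(key) is Optional, get(key, default) is not. A minimal sketch of that rule (generic example, not this repo's schema types):

```python
d: dict[str, str] = {"type": "gsheets"}

a = d.get("type")             # mypy infers Optional[str]
b = d.get("type", "gsheets")  # mypy infers str

def get_data_type_sketch(data: dict[str, str]) -> str:
    # A default keeps the declared str return without an extra None check.
    return data.get("type", "gsheets")
```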
File: (name not shown)
@@ -172,12 +172,6 @@ async def post(self) -> web.StreamResponse:
class DocumentsView(FileUploaderBaseView):
REQUIRED_RESOURCES: ClassVar[frozenset[RequiredResource]] = frozenset() # Don't skip CSRF check

FILE_TYPE_TO_DATA_FILE_PREPARER_MAP: dict[
FileType, Callable[[str, RedisModelManager, Optional[str]], Awaitable[DataFile]]
] = {
FileType.yadocs: yadocs_data_file_preparer,
}

async def post(self) -> web.StreamResponse:
req_data = await self._load_post_request_schema_data(files_schemas.FileDocumentsRequestSchema)

@@ -191,7 +185,7 @@ async def post(self) -> web.StreamResponse:
connection_id: Optional[str] = req_data["connection_id"]
authorized: bool = req_data["authorized"]

df = await self.FILE_TYPE_TO_DATA_FILE_PREPARER_MAP[file_type](
df = await yadocs_data_file_preparer(
oauth_token=oauth_token,
private_path=private_path,
public_link=public_link,
File: (name not shown)
@@ -59,7 +59,7 @@ class ProcessExcelTask(BaseTaskMeta):
name = TaskName("process_excel")

file_id: str = attr.ib()
exec_mode: Optional[TaskExecutionMode] = attr.ib(default=TaskExecutionMode.BASIC)
exec_mode: TaskExecutionMode = attr.ib(default=TaskExecutionMode.BASIC)
tenant_id: Optional[str] = attr.ib(default=None)
connection_id: Optional[str] = attr.ib(default=None)

File: (name not shown)
@@ -273,6 +273,8 @@ async def run(self) -> TaskResult:
old_s3_filename = source.s3_filename
else:
assert old_tenant_id
assert conn.uuid
assert source.s3_filename_suffix
old_s3_filename = "_".join(
(old_tenant_id, conn.uuid, source.s3_filename_suffix)
)
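Context for the added asserts: str.join needs an iterable of str, so if conn.uuid and source.s3_filename_suffix are declared Optional[str] (which the asserts suggest), mypy rejects the tuple until they are narrowed. A minimal sketch of the narrowing pattern (hypothetical helper, assuming the attributes are Optional):

```python
from typing import Optional

def build_old_s3_filename(tenant_id: Optional[str], uuid: Optional[str], suffix: Optional[str]) -> str:
    # Without the asserts mypy rejects the join: the tuple items are
    # Optional[str], but str.join expects an iterable of plain str.
    assert tenant_id
    assert uuid
    assert suffix
    return "_".join((tenant_id, uuid, suffix))  # all three narrowed to str
```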
@@ -282,6 +284,7 @@
if len(old_fname_parts) >= 2 and all(part for part in old_fname_parts):
# assume that first part is old tenant id
old_tenant_id = old_fname_parts[0]
assert old_tenant_id
old_tenants.add(old_tenant_id)

s3_filename_suffix = (
File: (name not shown)
@@ -119,7 +119,7 @@ async def run(self) -> TaskResult:
download_error = FileProcessingError.from_exception(e)
dfile.status = FileProcessingStatus.failed
dfile.error = download_error
if self.meta.exec_mode != TaskExecutionMode.BASIC:
if self.meta.exec_mode != TaskExecutionMode.BASIC and dfile.sources is not None:
for src in dfile.sources:
src.status = FileProcessingStatus.failed
src.error = download_error
File: (name not shown)
@@ -25,6 +25,7 @@
DataSource,
ExcelFileSourceSettings,
FileProcessingError,
FileSourceSettings,
YaDocsFileSourceSettings,
YaDocsUserSourceDataSourceProperties,
)
@@ -53,7 +54,7 @@ class ProcessExcelTask(BaseExecutorTask[task_interface.ProcessExcelTask, FileUpl

async def run(self) -> TaskResult:
dfile: Optional[DataFile] = None
sources_to_update_by_sheet_id: dict[int, list[DataSource]] = defaultdict(list)
sources_to_update_by_sheet_id: dict[str, list[DataSource]] = defaultdict(list)
usm = self._ctx.get_async_usm()
task_processor = self._ctx.make_task_processor(self._request_id)
redis = self._ctx.redis_service.get_redis()
@@ -104,6 +105,9 @@ async def run(self) -> TaskResult:
file_data = await resp.json()

for src in dfile.sources:
# sources can be pre-filled only during update, which can happen only to document at the moment
# TODO needs refactoring
assert isinstance(src.user_source_dsrc_properties, YaDocsUserSourceDataSourceProperties)
sources_to_update_by_sheet_id[src.user_source_dsrc_properties.sheet_id].append(src)

for spreadsheet in file_data:
@@ -158,7 +162,7 @@ async def run(self) -> TaskResult:
exc_to_save = ex if isinstance(ex, exc.DLFileUploaderBaseError) else exc.ParseFailed()
src.error = FileProcessingError.from_exception(exc_to_save)
connection_error_tracker.add_error(src.id, src.error)
sheet_settings = None
sheet_settings: Optional[FileSourceSettings] = None

for src in sheet_data_sources:
if src.is_applicable:
12 changes: 6 additions & 6 deletions lib/dl_task_processor/dl_task_processor/arq_wrapper.py
@@ -44,8 +44,8 @@
class ArqCronWrapper:
_task: BaseTaskMeta = attr.ib()
__qualname__ = "ArqCronWrapper"
# because of asyncio.iscoroutinefunction in the arq core
_is_coroutine = asyncio.coroutines._is_coroutine
# special asyncio marker; because of asyncio.iscoroutinefunction in the arq core
_is_coroutine = asyncio.coroutines._is_coroutine # type: ignore

async def __call__(self, ctx: Dict[Any, Any], *args: Any, **kwargs: Any) -> Any: # pragma: no cover
return await arq_base_task(
@@ -86,7 +86,7 @@ def create_arq_redis_settings(settings: _BIRedisSettings) -> ArqRedisSettings:
port=settings.PORT,
password=settings.PASSWORD,
database=settings.DB,
ssl=settings.SSL,
ssl=settings.SSL or False,
)
elif settings.MODE == RedisMode.sentinel:
redis_targets = [(host, settings.PORT) for host in settings.HOSTS]
@@ -96,7 +96,7 @@ def create_arq_redis_settings(settings: _BIRedisSettings) -> ArqRedisSettings:
sentinel=True,
sentinel_master=settings.CLUSTER_NAME,
database=settings.DB,
ssl=settings.SSL,
ssl=settings.SSL or False,
)
else:
raise ValueError(f"Unknown redis mode {settings.MODE}")
@@ -132,7 +132,6 @@ def make_cron_task(task: BaseTaskMeta, schedule: CronSchedule) -> CronTask:

async def arq_base_task(context: Dict, params: Dict) -> None:
LOGGER.info("Run arq task with params %s", params)
context[EXECUTOR_KEY]: Executor
# transition
# i'll delete it later
if "task_params" in params:
@@ -156,7 +155,8 @@ async def arq_base_task(context: Dict, params: Dict) -> None:
instance_id=instance_id,
attempt=context["job_try"] - 1, # it starts from 1 o_O
)
job_result = await context[EXECUTOR_KEY].run_job(task_instance)
executor: Executor = context[EXECUTOR_KEY]
job_result = await executor.run_job(task_instance)
if isinstance(job_result, Retry):
# https://arq-docs.helpmanual.io/#retrying-jobs-and-cancellation
raise ArqRetry(
4 changes: 2 additions & 2 deletions lib/dl_task_processor/dl_task_processor/state.py
@@ -26,7 +26,7 @@ def set_state(self, task: TaskInstance, state: str) -> None:
pass

def get_state(self, task: TaskInstance) -> list:
pass
return []


# i will change it later
@@ -75,7 +75,7 @@ async def wait_task(task: TaskInstance, state: TaskState, timeout: float = 10, i
"""
timeout == the bottom line
"""
spent_time = 0
spent_time = 0.0
while spent_time < timeout:
current_state = state.get_state(task)
# Has the task reached the final state?
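The 0 to 0.0 change is presumably needed because spent_time is later incremented by the float interval; mypy fixes the variable's type to int at its first assignment. A minimal sketch of the inference issue (assumed motivation, the loop body is truncated above):

```python
import asyncio

async def wait_task_sketch(timeout: float = 10, interval: float = 0.5) -> None:
    spent_time = 0.0  # with a plain 0, mypy infers int and flags the += below
    while spent_time < timeout:
        await asyncio.sleep(interval)
        spent_time += interval  # float += float is fine; int += float is a mypy error
```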
12 changes: 5 additions & 7 deletions lib/dl_task_processor/dl_task_processor/task.py
@@ -2,6 +2,7 @@
from collections.abc import Iterable
import enum
from typing import (
Any,
ClassVar,
Generic,
NewType,
@@ -56,13 +57,10 @@ class TaskInstance:
request_id: Optional[str] = attr.ib(default=None)


@attr.s
class BaseTaskMeta(metaclass=abc.ABCMeta):
name: ClassVar[TaskName]

def __init__(self, *args, **kwargs) -> None:
# lets trick typing
pass

def get_params(self, with_name: bool = False) -> dict:
if with_name:
return dict(
@@ -119,7 +117,7 @@ class Retry(TaskResult):

@attr.s
class BaseExecutorTask(Generic[_BASE_TASK_META_TV, _BASE_TASK_CONTEXT_TV], metaclass=abc.ABCMeta):
cls_meta: ClassVar[Type[_BASE_TASK_META_TV]]
cls_meta: ClassVar[Type[_BASE_TASK_META_TV]] # type: ignore[misc] # TODO: fix
Reviewer comment (Contributor):
are we ignoring this one specifically to make dl_task_processor green on mypy? if not, we should probably fix it along with the other ClassVar[Type[...]] errors
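For reference, the suppressed error here is mypy's "ClassVar cannot contain type variables", reported under the misc code; it fires for any ClassVar[Type[SomeTypeVar]] declaration, so the alternative to the ignore is declaring a concrete cls_meta on each subclass. A standalone repro (sketch, not this repo's code):

```python
from typing import ClassVar, Generic, Type, TypeVar

_META_TV = TypeVar("_META_TV")

class BaseExecutorTaskSketch(Generic[_META_TV]):
    # mypy: ClassVar cannot contain type variables  [misc]
    cls_meta: ClassVar[Type[_META_TV]]
```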

meta: _BASE_TASK_META_TV = attr.ib()
_ctx: _BASE_TASK_CONTEXT_TV = attr.ib()
_instance_id: InstanceID = attr.ib()
Expand Down Expand Up @@ -154,7 +152,7 @@ async def run(self) -> TaskResult:

@attr.s
class TaskRegistry:
_tasks: dict[TaskName, BaseExecutorTask] = attr.ib()
_tasks: dict[TaskName, Type[BaseExecutorTask]] = attr.ib()

@classmethod
def create(cls, tasks: Iterable[Type[BaseExecutorTask]]) -> "TaskRegistry":
Expand All @@ -163,7 +161,7 @@ def create(cls, tasks: Iterable[Type[BaseExecutorTask]]) -> "TaskRegistry":
) == sorted(list(set([t.name() for t in tasks]))), "Some tasks has the same name"
return cls(tasks={task.name(): task for task in tasks})

def get_task(self, name: TaskName) -> BaseExecutorTask:
def get_task(self, name: TaskName) -> Type[BaseExecutorTask]:
return self._tasks[name]

def get_task_meta(self, name: TaskName) -> Type[BaseTaskMeta]:
28 changes: 13 additions & 15 deletions lib/dl_task_processor/dl_task_processor/worker.py
@@ -6,7 +6,7 @@
Dict,
Iterable,
Optional,
Union,
Sequence,
)

from arq import Worker as _ArqWorker
@@ -31,7 +31,7 @@
class WorkerSettings:
# we should not allow forever-fail tasks because it can stop the whole system
# but (if you really want it) you can provide float('inf')
retry_hard_limit: Union[int, float] = attr.ib(default=100)
retry_hard_limit: int = attr.ib(default=100)
job_timeout: int = attr.ib(default=600) # seconds
health_check_interval: int = attr.ib(default=30)
health_check_suffix: str = attr.ib(default="bihealthcheck")
@@ -44,7 +44,7 @@ class ArqWorker:
_redis_settings: RedisSettings = attr.ib()
_worker_settings: WorkerSettings = attr.ib()
_arq_worker: _ArqWorker = attr.ib(default=None)
_cron_tasks: Iterable[CronTask] = attr.ib(default=[])
_cron_tasks: Sequence[CronTask] = attr.ib(default=[])

@property
def health_check_key(self) -> str:
@@ -55,18 +55,16 @@ async def start(self) -> None:
self._arq_worker = _ArqWorker(
# let's trick strange typing in arq
# everybody does it o_O
**{
"functions": [arq_base_task],
"on_startup": self.start_executor,
"on_shutdown": self.stop_executor,
"max_tries": self._worker_settings.retry_hard_limit,
"job_timeout": timedelta(seconds=self._worker_settings.job_timeout),
"retry_jobs": True,
"health_check_key": self.health_check_key,
"handle_signals": False,
"health_check_interval": self._worker_settings.health_check_interval,
"cron_jobs": self._cron_tasks,
},
functions=[arq_base_task], # type: ignore[list-item]
Reviewer comment (Contributor):
should we ignore it? is there some typing issue with arq?
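As far as I can tell (worth verifying against arq's source), Worker(functions=...) is typed as a sequence of arq.worker.Function objects or worker coroutines whose call signature is (ctx, *args, **kwargs), while arq_base_task(context: Dict, params: Dict) requires exactly one extra positional argument, so it does not satisfy that callable shape and mypy flags the list item. A hedged sketch of the mismatch (the protocol below is an approximation, not arq's actual definition):

```python
from typing import Any, Dict, Protocol

class WorkerCoroutineLike(Protocol):
    # Approximation of the callable shape an arq worker function is expected to have
    async def __call__(self, ctx: Dict[Any, Any], *args: Any, **kwargs: Any) -> Any: ...

async def arq_base_task(context: Dict, params: Dict) -> None:  # fixed second positional arg
    ...

handler: WorkerCoroutineLike = arq_base_task  # mypy: incompatible assignment, hence the ignore
```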

on_startup=self.start_executor,
on_shutdown=self.stop_executor,
max_tries=self._worker_settings.retry_hard_limit,
job_timeout=timedelta(seconds=self._worker_settings.job_timeout),
retry_jobs=True,
health_check_key=self.health_check_key,
handle_signals=False,
health_check_interval=self._worker_settings.health_check_interval,
cron_jobs=self._cron_tasks,
redis_pool=redis_pool,
)
await self._arq_worker.main()