diff --git a/lib/dl_file_uploader_api_lib/dl_file_uploader_api_lib/schemas/files.py b/lib/dl_file_uploader_api_lib/dl_file_uploader_api_lib/schemas/files.py
index 7958dcd53..58d88cf1f 100644
--- a/lib/dl_file_uploader_api_lib/dl_file_uploader_api_lib/schemas/files.py
+++ b/lib/dl_file_uploader_api_lib/dl_file_uploader_api_lib/schemas/files.py
@@ -34,7 +34,7 @@ def validate_authorized_yadocs(data: dict) -> None:
     )
 
 
-def validate_yadocs_data(data):
+def validate_yadocs_data(data: dict) -> None:
     if not ((data["public_link"] is None) ^ (data["private_path"] is None)):
         raise ValueError("Expected exactly one of [`private_path`, `public_link`] to be specified")
     if data["public_link"] is None and data["oauth_token"] is None and data["connection_id"] is None:
@@ -110,11 +110,11 @@ def get_obj_type(self, obj: dict[str, Any]) -> str:
         assert isinstance(type_field, FileType)
         return type_field.name
 
-    def get_data_type(self, data):
-        data_type = data.get(self.type_field)
+    def get_data_type(self, data: dict[str, Any]) -> str:
+        default_data_type = FileType.gsheets.value
+        data_type = data.get(self.type_field, default_data_type)
         if self.type_field not in data:
-            data[self.type_field] = FileType.gsheets.value
-            data_type = FileType.gsheets.value
+            data[self.type_field] = default_data_type
         if self.type_field in data and self.type_field_remove:
             data.pop(self.type_field)
         return data_type
diff --git a/lib/dl_file_uploader_api_lib/dl_file_uploader_api_lib/views/files.py b/lib/dl_file_uploader_api_lib/dl_file_uploader_api_lib/views/files.py
index 0eb895d0f..b9f20723f 100644
--- a/lib/dl_file_uploader_api_lib/dl_file_uploader_api_lib/views/files.py
+++ b/lib/dl_file_uploader_api_lib/dl_file_uploader_api_lib/views/files.py
@@ -172,12 +172,6 @@ async def post(self) -> web.StreamResponse:
 class DocumentsView(FileUploaderBaseView):
     REQUIRED_RESOURCES: ClassVar[frozenset[RequiredResource]] = frozenset()  # Don't skip CSRF check
 
-    FILE_TYPE_TO_DATA_FILE_PREPARER_MAP: dict[
-        FileType, Callable[[str, RedisModelManager, Optional[str]], Awaitable[DataFile]]
-    ] = {
-        FileType.yadocs: yadocs_data_file_preparer,
-    }
-
     async def post(self) -> web.StreamResponse:
         req_data = await self._load_post_request_schema_data(files_schemas.FileDocumentsRequestSchema)
 
@@ -191,7 +185,7 @@ async def post(self) -> web.StreamResponse:
         connection_id: Optional[str] = req_data["connection_id"]
         authorized: bool = req_data["authorized"]
 
-        df = await self.FILE_TYPE_TO_DATA_FILE_PREPARER_MAP[file_type](
+        df = await yadocs_data_file_preparer(
             oauth_token=oauth_token,
             private_path=private_path,
             public_link=public_link,
diff --git a/lib/dl_file_uploader_task_interface/dl_file_uploader_task_interface/tasks.py b/lib/dl_file_uploader_task_interface/dl_file_uploader_task_interface/tasks.py
index 7bbd28a4c..9aad088d2 100644
--- a/lib/dl_file_uploader_task_interface/dl_file_uploader_task_interface/tasks.py
+++ b/lib/dl_file_uploader_task_interface/dl_file_uploader_task_interface/tasks.py
@@ -59,7 +59,7 @@ class ProcessExcelTask(BaseTaskMeta):
     name = TaskName("process_excel")
 
     file_id: str = attr.ib()
-    exec_mode: Optional[TaskExecutionMode] = attr.ib(default=TaskExecutionMode.BASIC)
+    exec_mode: TaskExecutionMode = attr.ib(default=TaskExecutionMode.BASIC)
     tenant_id: Optional[str] = attr.ib(default=None)
     connection_id: Optional[str] = attr.ib(default=None)
 
diff --git a/lib/dl_file_uploader_worker_lib/dl_file_uploader_worker_lib/tasks/cleanup.py b/lib/dl_file_uploader_worker_lib/dl_file_uploader_worker_lib/tasks/cleanup.py
index ddbd99c17..1b67ea4ab 100644
--- a/lib/dl_file_uploader_worker_lib/dl_file_uploader_worker_lib/tasks/cleanup.py
+++ b/lib/dl_file_uploader_worker_lib/dl_file_uploader_worker_lib/tasks/cleanup.py
@@ -273,6 +273,8 @@ async def run(self) -> TaskResult:
                             old_s3_filename = source.s3_filename
                         else:
                             assert old_tenant_id
+                            assert conn.uuid
+                            assert source.s3_filename_suffix
                             old_s3_filename = "_".join(
                                 (old_tenant_id, conn.uuid, source.s3_filename_suffix)
                             )
@@ -282,6 +284,7 @@ async def run(self) -> TaskResult:
                 if len(old_fname_parts) >= 2 and all(part for part in old_fname_parts):
                     # assume that first part is old tenant id
                     old_tenant_id = old_fname_parts[0]
+                    assert old_tenant_id
                     old_tenants.add(old_tenant_id)
 
                     s3_filename_suffix = (
diff --git a/lib/dl_file_uploader_worker_lib/dl_file_uploader_worker_lib/tasks/download_yadocs.py b/lib/dl_file_uploader_worker_lib/dl_file_uploader_worker_lib/tasks/download_yadocs.py
index 140872857..3667bbc05 100644
--- a/lib/dl_file_uploader_worker_lib/dl_file_uploader_worker_lib/tasks/download_yadocs.py
+++ b/lib/dl_file_uploader_worker_lib/dl_file_uploader_worker_lib/tasks/download_yadocs.py
@@ -119,7 +119,7 @@ async def run(self) -> TaskResult:
                 download_error = FileProcessingError.from_exception(e)
                 dfile.status = FileProcessingStatus.failed
                 dfile.error = download_error
-                if self.meta.exec_mode != TaskExecutionMode.BASIC:
+                if self.meta.exec_mode != TaskExecutionMode.BASIC and dfile.sources is not None:
                     for src in dfile.sources:
                         src.status = FileProcessingStatus.failed
                         src.error = download_error
diff --git a/lib/dl_file_uploader_worker_lib/dl_file_uploader_worker_lib/tasks/excel.py b/lib/dl_file_uploader_worker_lib/dl_file_uploader_worker_lib/tasks/excel.py
index fbb6744c8..bd46ad4e5 100644
--- a/lib/dl_file_uploader_worker_lib/dl_file_uploader_worker_lib/tasks/excel.py
+++ b/lib/dl_file_uploader_worker_lib/dl_file_uploader_worker_lib/tasks/excel.py
@@ -25,6 +25,7 @@
     DataSource,
     ExcelFileSourceSettings,
     FileProcessingError,
+    FileSourceSettings,
     YaDocsFileSourceSettings,
     YaDocsUserSourceDataSourceProperties,
 )
@@ -53,7 +54,7 @@ class ProcessExcelTask(BaseExecutorTask[task_interface.ProcessExcelTask, FileUpl
 
     async def run(self) -> TaskResult:
         dfile: Optional[DataFile] = None
-        sources_to_update_by_sheet_id: dict[int, list[DataSource]] = defaultdict(list)
+        sources_to_update_by_sheet_id: dict[str, list[DataSource]] = defaultdict(list)
         usm = self._ctx.get_async_usm()
         task_processor = self._ctx.make_task_processor(self._request_id)
         redis = self._ctx.redis_service.get_redis()
@@ -104,6 +105,9 @@ async def run(self) -> TaskResult:
                     file_data = await resp.json()
 
                 for src in dfile.sources:
+                    # sources can be pre-filled only during update, which can happen only to document at the moment
+                    # TODO needs refactoring
+                    assert isinstance(src.user_source_dsrc_properties, YaDocsUserSourceDataSourceProperties)
                     sources_to_update_by_sheet_id[src.user_source_dsrc_properties.sheet_id].append(src)
 
                 for spreadsheet in file_data:
@@ -158,7 +162,7 @@ async def run(self) -> TaskResult:
                         exc_to_save = ex if isinstance(ex, exc.DLFileUploaderBaseError) else exc.ParseFailed()
                         src.error = FileProcessingError.from_exception(exc_to_save)
                         connection_error_tracker.add_error(src.id, src.error)
-                sheet_settings = None
+                sheet_settings: Optional[FileSourceSettings] = None
 
                 for src in sheet_data_sources:
                     if src.is_applicable:
diff --git a/lib/dl_task_processor/dl_task_processor/arq_wrapper.py b/lib/dl_task_processor/dl_task_processor/arq_wrapper.py
index daa55f688..63352014d 100644
--- a/lib/dl_task_processor/dl_task_processor/arq_wrapper.py
+++ b/lib/dl_task_processor/dl_task_processor/arq_wrapper.py
@@ -44,8 +44,8 @@ class ArqCronWrapper:
     _task: BaseTaskMeta = attr.ib()
 
     __qualname__ = "ArqCronWrapper"
-    # because of asyncio.iscoroutinefunction in the arq core
-    _is_coroutine = asyncio.coroutines._is_coroutine
+    # special asyncio marker; because of asyncio.iscoroutinefunction in the arq core
+    _is_coroutine = asyncio.coroutines._is_coroutine  # type: ignore
 
     async def __call__(self, ctx: Dict[Any, Any], *args: Any, **kwargs: Any) -> Any:  # pragma: no cover
         return await arq_base_task(
@@ -86,7 +86,7 @@ def create_arq_redis_settings(settings: _BIRedisSettings) -> ArqRedisSettings:
             port=settings.PORT,
             password=settings.PASSWORD,
             database=settings.DB,
-            ssl=settings.SSL,
+            ssl=settings.SSL or False,
         )
     elif settings.MODE == RedisMode.sentinel:
         redis_targets = [(host, settings.PORT) for host in settings.HOSTS]
@@ -96,7 +96,7 @@ def create_arq_redis_settings(settings: _BIRedisSettings) -> ArqRedisSettings:
             sentinel=True,
             sentinel_master=settings.CLUSTER_NAME,
             database=settings.DB,
-            ssl=settings.SSL,
+            ssl=settings.SSL or False,
        )
     else:
         raise ValueError(f"Unknown redis mode {settings.MODE}")
@@ -132,7 +132,6 @@ def make_cron_task(task: BaseTaskMeta, schedule: CronSchedule) -> CronTask:
 async def arq_base_task(context: Dict, params: Dict) -> None:
     LOGGER.info("Run arq task with params %s", params)
-    context[EXECUTOR_KEY]: Executor
 
     # transition
     # i'll delete it later
     if "task_params" in params:
@@ -156,7 +155,8 @@ async def arq_base_task(context: Dict, params: Dict) -> None:
         instance_id=instance_id,
         attempt=context["job_try"] - 1,  # it starts from 1 o_O
     )
-    job_result = await context[EXECUTOR_KEY].run_job(task_instance)
+    executor: Executor = context[EXECUTOR_KEY]
+    job_result = await executor.run_job(task_instance)
     if isinstance(job_result, Retry):
         # https://arq-docs.helpmanual.io/#retrying-jobs-and-cancellation
         raise ArqRetry(
diff --git a/lib/dl_task_processor/dl_task_processor/state.py b/lib/dl_task_processor/dl_task_processor/state.py
index f29919b28..3fbbc374d 100644
--- a/lib/dl_task_processor/dl_task_processor/state.py
+++ b/lib/dl_task_processor/dl_task_processor/state.py
@@ -26,7 +26,7 @@ def set_state(self, task: TaskInstance, state: str) -> None:
         pass
 
     def get_state(self, task: TaskInstance) -> list:
-        pass
+        return []
 
 
 # i will change it later
@@ -75,7 +75,7 @@ async def wait_task(task: TaskInstance, state: TaskState, timeout: float = 10, i
     """
     timeout == the bottom line
     """
-    spent_time = 0
+    spent_time = 0.0
     while spent_time < timeout:
         current_state = state.get_state(task)
         # Has the task reached the final state?
diff --git a/lib/dl_task_processor/dl_task_processor/task.py b/lib/dl_task_processor/dl_task_processor/task.py
index 9fea1ba4f..64e31c850 100644
--- a/lib/dl_task_processor/dl_task_processor/task.py
+++ b/lib/dl_task_processor/dl_task_processor/task.py
@@ -2,6 +2,7 @@
 from collections.abc import Iterable
 import enum
 from typing import (
+    Any,
     ClassVar,
     Generic,
     NewType,
@@ -56,13 +57,10 @@ class TaskInstance:
     request_id: Optional[str] = attr.ib(default=None)
 
 
+@attr.s
 class BaseTaskMeta(metaclass=abc.ABCMeta):
     name: ClassVar[TaskName]
 
-    def __init__(self, *args, **kwargs) -> None:
-        # lets trick typing
-        pass
-
     def get_params(self, with_name: bool = False) -> dict:
         if with_name:
             return dict(
@@ -119,7 +117,7 @@ class Retry(TaskResult):
 
 @attr.s
 class BaseExecutorTask(Generic[_BASE_TASK_META_TV, _BASE_TASK_CONTEXT_TV], metaclass=abc.ABCMeta):
-    cls_meta: ClassVar[Type[_BASE_TASK_META_TV]]
+    cls_meta: ClassVar[Type[_BASE_TASK_META_TV]]  # type: ignore[misc]  # TODO: fix
     meta: _BASE_TASK_META_TV = attr.ib()
     _ctx: _BASE_TASK_CONTEXT_TV = attr.ib()
     _instance_id: InstanceID = attr.ib()
@@ -154,7 +152,7 @@ async def run(self) -> TaskResult:
 
 @attr.s
 class TaskRegistry:
-    _tasks: dict[TaskName, BaseExecutorTask] = attr.ib()
+    _tasks: dict[TaskName, Type[BaseExecutorTask]] = attr.ib()
 
     @classmethod
     def create(cls, tasks: Iterable[Type[BaseExecutorTask]]) -> "TaskRegistry":
@@ -163,7 +161,7 @@ def create(cls, tasks: Iterable[Type[BaseExecutorTask]]) -> "TaskRegistry":
         ) == sorted(list(set([t.name() for t in tasks]))), "Some tasks has the same name"
         return cls(tasks={task.name(): task for task in tasks})
 
-    def get_task(self, name: TaskName) -> BaseExecutorTask:
+    def get_task(self, name: TaskName) -> Type[BaseExecutorTask]:
         return self._tasks[name]
 
     def get_task_meta(self, name: TaskName) -> Type[BaseTaskMeta]:
diff --git a/lib/dl_task_processor/dl_task_processor/worker.py b/lib/dl_task_processor/dl_task_processor/worker.py
index 25d2e34b2..1fe1e15dd 100644
--- a/lib/dl_task_processor/dl_task_processor/worker.py
+++ b/lib/dl_task_processor/dl_task_processor/worker.py
@@ -6,7 +6,7 @@
     Dict,
     Iterable,
     Optional,
-    Union,
+    Sequence,
 )
 
 from arq import Worker as _ArqWorker
@@ -31,7 +31,7 @@ class WorkerSettings:
 
     # we should not allow forever-fail tasks because it can stop the whole system
     # but (if you really want it) you can provide float('inf')
-    retry_hard_limit: Union[int, float] = attr.ib(default=100)
+    retry_hard_limit: int = attr.ib(default=100)
     job_timeout: int = attr.ib(default=600)  # seconds
     health_check_interval: int = attr.ib(default=30)
     health_check_suffix: str = attr.ib(default="bihealthcheck")
@@ -44,7 +44,7 @@ class ArqWorker:
     _redis_settings: RedisSettings = attr.ib()
     _worker_settings: WorkerSettings = attr.ib()
     _arq_worker: _ArqWorker = attr.ib(default=None)
-    _cron_tasks: Iterable[CronTask] = attr.ib(default=[])
+    _cron_tasks: Sequence[CronTask] = attr.ib(default=[])
 
     @property
     def health_check_key(self) -> str:
@@ -55,18 +55,16 @@ async def start(self) -> None:
         self._arq_worker = _ArqWorker(
             # let's trick strange typing in arq
             # everybody does it o_O
-            **{
-                "functions": [arq_base_task],
-                "on_startup": self.start_executor,
-                "on_shutdown": self.stop_executor,
-                "max_tries": self._worker_settings.retry_hard_limit,
-                "job_timeout": timedelta(seconds=self._worker_settings.job_timeout),
-                "retry_jobs": True,
-                "health_check_key": self.health_check_key,
-                "handle_signals": False,
-                "health_check_interval": self._worker_settings.health_check_interval,
-                "cron_jobs": self._cron_tasks,
-            },
+            functions=[arq_base_task],  # type: ignore[list-item]
+            on_startup=self.start_executor,
+            on_shutdown=self.stop_executor,
+            max_tries=self._worker_settings.retry_hard_limit,
+            job_timeout=timedelta(seconds=self._worker_settings.job_timeout),
+            retry_jobs=True,
+            health_check_key=self.health_check_key,
+            handle_signals=False,
+            health_check_interval=self._worker_settings.health_check_interval,
+            cron_jobs=self._cron_tasks,
             redis_pool=redis_pool,
         )
         await self._arq_worker.main()