From d0b7d4138f6af637441523cf1092c02c51c39abf Mon Sep 17 00:00:00 2001 From: Marcos Prieto Date: Tue, 25 Jun 2024 10:29:08 +0200 Subject: [PATCH] Misc typing fixes in not annotated functions Quick fixes with the goal of enabling mypy's check_untyped_defs --- lms/config.py | 2 +- lms/db/_columns.py | 5 +++-- lms/extensions/feature_flags/__init__.py | 2 +- lms/models/application_instance.py | 10 ++++------ lms/models/assignment.py | 2 +- lms/models/group_info.py | 6 ++---- lms/models/grouping.py | 2 +- lms/models/lti_role.py | 15 +++++++-------- lms/product/d2l/_plugin/course_copy.py | 2 +- lms/resources/_js_config/__init__.py | 4 ++-- lms/services/canvas.py | 3 ++- lms/services/canvas_api/_basic.py | 12 ++++++------ lms/services/canvas_api/client.py | 4 ++-- lms/services/course.py | 2 +- lms/services/d2l_api/client.py | 2 +- lms/services/organization.py | 2 +- lms/validation/_lti_launch_params.py | 2 +- lms/views/admin/_schemas.py | 2 +- lms/views/admin/application_instance/search.py | 4 +--- lms/views/admin/course.py | 1 + lms/views/api/blackboard/files.py | 2 +- 21 files changed, 41 insertions(+), 45 deletions(-) diff --git a/lms/config.py b/lms/config.py index 89def79784..9fcfda05da 100644 --- a/lms/config.py +++ b/lms/config.py @@ -39,7 +39,7 @@ class _Setting: """The properties of a setting and how to read it.""" name: str - read_from: str | None = None + read_from: str = None # type: ignore value_mapper: Callable | None = None def __post_init__(self): diff --git a/lms/db/_columns.py b/lms/db/_columns.py index d9d8bc6515..daa054b1a0 100644 --- a/lms/db/_columns.py +++ b/lms/db/_columns.py @@ -1,4 +1,5 @@ import sqlalchemy as sa +from sqlalchemy.orm import Mapped, mapped_column def varchar_enum( # noqa: PLR0913, PLR0917 @@ -8,9 +9,9 @@ def varchar_enum( # noqa: PLR0913, PLR0917 nullable=False, server_default=None, unique=False, -) -> sa.Column: +) -> Mapped: """Return a SA column type to store the python enum.Enum as a varchar in a table.""" - return sa.Column( + return mapped_column( sa.Enum( enum, # In order to maintain maximum flexibility we will only enforce the diff --git a/lms/extensions/feature_flags/__init__.py b/lms/extensions/feature_flags/__init__.py index 2e1604d5ba..d123ab01a5 100644 --- a/lms/extensions/feature_flags/__init__.py +++ b/lms/extensions/feature_flags/__init__.py @@ -176,7 +176,7 @@ def feature(request, feature_flag_name): def add_feature_flag_providers(_config, *providers): """Adapt feature_flags.add_providers().""" - providers = [config.maybe_dotted(provider) for provider in providers] + providers = [config.maybe_dotted(provider) for provider in providers] # type:ignore return feature_flags.add_providers(*providers) # Register the Pyramid request method and config directive. These are this diff --git a/lms/models/application_instance.py b/lms/models/application_instance.py index d897b276e2..bb309b2631 100644 --- a/lms/models/application_instance.py +++ b/lms/models/application_instance.py @@ -92,7 +92,7 @@ class ApplicationInstance(CreatedUpdatedMixin, Base): consumer_key = sa.Column(sa.Unicode, unique=True, nullable=True) shared_secret = sa.Column(sa.Unicode, nullable=False) - lms_url = sa.Column(sa.Unicode(2048), nullable=False) + lms_url: Mapped[str] = mapped_column(sa.Unicode(2048), nullable=False) requesters_email = sa.Column(sa.Unicode(2048), nullable=False) last_launched = sa.Column(sa.DateTime(), nullable=True) @@ -165,10 +165,8 @@ class ApplicationInstance(CreatedUpdatedMixin, Base): files = sa.orm.relationship("File", back_populates="application_instance") # LTIRegistration this instance belong to - lti_registration_id = sa.Column( - sa.Integer(), - sa.ForeignKey("lti_registration.id", ondelete="cascade"), - nullable=True, + lti_registration_id: Mapped[int | None] = mapped_column( + sa.ForeignKey("lti_registration.id", ondelete="cascade") ) lti_registration = sa.orm.relationship( @@ -176,7 +174,7 @@ class ApplicationInstance(CreatedUpdatedMixin, Base): ) # Unique identifier of this instance per LTIRegistration - deployment_id = sa.Column(sa.UnicodeText, nullable=True) + deployment_id: Mapped[str | None] = mapped_column() role_overrides = sa.orm.relationship( "LTIRoleOverride", back_populates="application_instance" diff --git a/lms/models/assignment.py b/lms/models/assignment.py index 43470c02bd..2992bfdc63 100644 --- a/lms/models/assignment.py +++ b/lms/models/assignment.py @@ -29,7 +29,7 @@ class Assignment(CreatedUpdatedMixin, Base): sa.UniqueConstraint("resource_link_id", "tool_consumer_instance_guid"), ) - id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + id: Mapped[int] = mapped_column(autoincrement=True, primary_key=True) resource_link_id = sa.Column(sa.Unicode, nullable=False) """The resource_link_id launch param of the assignment.""" diff --git a/lms/models/group_info.py b/lms/models/group_info.py index 990ec3e27d..85be916c43 100644 --- a/lms/models/group_info.py +++ b/lms/models/group_info.py @@ -88,12 +88,10 @@ class GroupInfo(Base): custom_canvas_course_id = sa.Column(sa.UnicodeText()) #: A dict of info about this group. - _info: Mapped[MutableDict | None] = mapped_column( - "info", MutableDict.as_mutable(JSONB()) - ) + _info: Mapped[dict | None] = mapped_column("info", MutableDict.as_mutable(JSONB())) @property - def _safe_info(self): + def _safe_info(self) -> dict: if self._info is None: self._info = {} diff --git a/lms/models/grouping.py b/lms/models/grouping.py index 2467ae7df7..58d7f90624 100644 --- a/lms/models/grouping.py +++ b/lms/models/grouping.py @@ -124,7 +124,7 @@ class Type(str, Enum): nullable=False, ) - extra: Mapped[MutableDict] = mapped_column( + extra: Mapped[dict] = mapped_column( MutableDict.as_mutable(JSONB()), server_default=sa.text("'{}'::jsonb"), nullable=False, diff --git a/lms/models/lti_role.py b/lms/models/lti_role.py index 5442bb4f16..ca197267bc 100644 --- a/lms/models/lti_role.py +++ b/lms/models/lti_role.py @@ -5,6 +5,7 @@ import sqlalchemy as sa from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import Mapped, mapped_column, relationship from lms.db import Base, varchar_enum @@ -39,27 +40,25 @@ class LTIRoleOverride(Base): id = sa.Column(sa.Integer(), autoincrement=True, primary_key=True) - lti_role_id = sa.Column( - sa.Integer(), + lti_role_id: Mapped[int] = mapped_column( sa.ForeignKey("lti_role.id", ondelete="cascade"), nullable=True, index=True, ) - lti_role = sa.orm.relationship("LTIRole") + lti_role = relationship("LTIRole") - application_instance_id = sa.Column( - sa.Integer(), + application_instance_id: Mapped[int] = mapped_column( sa.ForeignKey("application_instances.id", ondelete="cascade"), nullable=False, ) - application_instance = sa.orm.relationship( + application_instance = relationship( "ApplicationInstance", back_populates="role_overrides" ) - type = varchar_enum(RoleType) + type: Mapped[RoleScope] = varchar_enum(RoleType) """Our interpretation of the value.""" - scope = varchar_enum(RoleScope, nullable=True) + scope: Mapped[RoleType] = varchar_enum(RoleScope, nullable=True) """Scope where this role applies""" @property diff --git a/lms/product/d2l/_plugin/course_copy.py b/lms/product/d2l/_plugin/course_copy.py index d4c0555077..12db081010 100644 --- a/lms/product/d2l/_plugin/course_copy.py +++ b/lms/product/d2l/_plugin/course_copy.py @@ -5,7 +5,7 @@ class D2LCourseCopyPlugin: """Handle course copy for D2L.""" - file_type = "d2l_file" + file_type: str = "d2l_file" def __init__( self, diff --git a/lms/resources/_js_config/__init__.py b/lms/resources/_js_config/__init__.py index c434223d39..5d8f78f476 100644 --- a/lms/resources/_js_config/__init__.py +++ b/lms/resources/_js_config/__init__.py @@ -577,7 +577,7 @@ def _config(self): return config - def _get_product_info(self): + def _get_product_info(self) -> dict: """Return product (Canvas, BB, D2L..) configuration.""" product = self._request.product @@ -591,7 +591,7 @@ def _get_product_info(self): } if self._request.product.settings.groups_enabled: - product_info["api"]["listGroupSets"] = { + product_info["api"]["listGroupSets"] = { # type: ignore "authUrl": ( self._request.route_url(product.route.oauth2_authorize) if product.route.oauth2_authorize diff --git a/lms/services/canvas.py b/lms/services/canvas.py index 6022c11600..718e1e3fda 100644 --- a/lms/services/canvas.py +++ b/lms/services/canvas.py @@ -1,10 +1,11 @@ +from lms.services.canvas_api.client import CanvasAPIClient from lms.services.exceptions import CanvasAPIPermissionError, FileNotFoundInCourse class CanvasService: """A high level Canvas service.""" - api = None + api: CanvasAPIClient = None # type:ignore def __init__(self, canvas_api, course_copy_plugin): self.api = canvas_api diff --git a/lms/services/canvas_api/_basic.py b/lms/services/canvas_api/_basic.py index 68f334712e..9731779543 100644 --- a/lms/services/canvas_api/_basic.py +++ b/lms/services/canvas_api/_basic.py @@ -4,9 +4,9 @@ from urllib.parse import urlencode import requests -from requests import RequestException, Session +from requests import RequestException, Response, Session -from lms.services import CanvasAPIError, ExternalRequestError +from lms.services.exceptions import CanvasAPIError, ExternalRequestError class BasicClient: @@ -50,7 +50,7 @@ def send( # noqa: PLR0913, PLR0917 params=None, headers=None, url_stub="/api/v1", - ): + ) -> list: """ Make a request to the Canvas API and apply a schema to the response. @@ -97,8 +97,8 @@ def _get_url(self, path, params, url_stub): "?" + urlencode(params) if params else "" ) - def _send_prepared(self, request, schema, timeout, request_depth=1): - response = None + def _send_prepared(self, request, schema, timeout, request_depth=1) -> list: + response: Response = None # type:ignore try: response = self._session.send(request, timeout=timeout) @@ -106,7 +106,7 @@ def _send_prepared(self, request, schema, timeout, request_depth=1): except RequestException as err: CanvasAPIError.raise_from(err, request, response) - result = None + result: list = None # type: ignore try: result = schema(response).parse() except ExternalRequestError as err: diff --git a/lms/services/canvas_api/client.py b/lms/services/canvas_api/client.py index 84d245d0fb..1b7ce84dad 100644 --- a/lms/services/canvas_api/client.py +++ b/lms/services/canvas_api/client.py @@ -516,7 +516,7 @@ class Meta: course_id = fields.Integer(required=True) @classmethod - def _ensure_sections_unique(cls, sections): + def _ensure_sections_unique(cls, sections) -> list: """ Ensure that sections returned by Canvas are unique. @@ -528,7 +528,7 @@ def _ensure_sections_unique(cls, sections): :return: A list of unique sections :raise CanvasAPIError: When duplicate sections have different names """ - sections_by_id = {} + sections_by_id: dict[int, dict] = {} for section in sections: duplicate = sections_by_id.get(section["id"]) diff --git a/lms/services/course.py b/lms/services/course.py index 026a07fd10..4a5ba601e1 100644 --- a/lms/services/course.py +++ b/lms/services/course.py @@ -49,7 +49,7 @@ def any_with_setting(self, group, key, value=True) -> bool: .count() ) - def get_from_launch(self, product, lti_params): + def get_from_launch(self, product, lti_params) -> Course: """Get the course this LTI launch based on the request's params.""" historical_course = None diff --git a/lms/services/d2l_api/client.py b/lms/services/d2l_api/client.py index 1094a4197d..d1beee5c7c 100644 --- a/lms/services/d2l_api/client.py +++ b/lms/services/d2l_api/client.py @@ -121,7 +121,7 @@ def group_set_groups(self, org_unit, group_category_id, user_id=None): return groups - def list_files(self, org_unit) -> list[dict]: + def list_files(self, org_unit: str) -> list[dict]: """Get a nested list of files and folders for the given `org_unit`.""" modules = self._get_course_modules(org_unit) files = list(self._find_files(org_unit, modules)) diff --git a/lms/services/organization.py b/lms/services/organization.py index 756231b833..ed80e53920 100644 --- a/lms/services/organization.py +++ b/lms/services/organization.py @@ -43,7 +43,7 @@ def get_by_id(self, id_: int) -> Organization | None: return self._organization_search_query(id_=id_).one_or_none() - def get_by_public_id(self, public_id: str) -> list | None: + def get_by_public_id(self, public_id: str) -> Organization | None: """ Get an organization by its public_id. diff --git a/lms/validation/_lti_launch_params.py b/lms/validation/_lti_launch_params.py index dc08615689..bed93bb9d1 100644 --- a/lms/validation/_lti_launch_params.py +++ b/lms/validation/_lti_launch_params.py @@ -132,7 +132,7 @@ def handle_error(self, error, data, *, many, **kwargs): # ``err.messages``, but without overwriting any of the existing # error messages already present in ``messages``. for field in err.messages: - messages.setdefault(field, []).extend(err.messages[field]) + messages.setdefault(field, []).extend(err.messages[field]) # type:ignore return_url = None if return_url: diff --git a/lms/views/admin/_schemas.py b/lms/views/admin/_schemas.py index 7b387d2440..77615d548c 100644 --- a/lms/views/admin/_schemas.py +++ b/lms/views/admin/_schemas.py @@ -19,7 +19,7 @@ def deserialize(self, value, attr, data, **kwargs): # pylint:disable=compare-to-empty-string if value == missing or value.strip() == "": return None - return super().deserialize(value, attr, data, **kwargs) + return super().deserialize(value, attr, data, **kwargs) # type:ignore class EmptyStringInt(EmptyStringNoneMixin, fields.Int): # type: ignore diff --git a/lms/views/admin/application_instance/search.py b/lms/views/admin/application_instance/search.py index 73dd5fe5cb..1df154c80c 100644 --- a/lms/views/admin/application_instance/search.py +++ b/lms/views/admin/application_instance/search.py @@ -52,9 +52,7 @@ def search_callback(self): settings = None if settings_key := self.request.params.get("settings_key"): if settings_value := self.request.params.get("settings_value"): - settings_value = SETTINGS_BY_FIELD.get(settings_key).format( - settings_value - ) + settings_value = SETTINGS_BY_FIELD[settings_key].format(settings_value) else: settings_value = ... diff --git a/lms/views/admin/course.py b/lms/views/admin/course.py index 385c44feb6..11fde4bc2a 100644 --- a/lms/views/admin/course.py +++ b/lms/views/admin/course.py @@ -103,6 +103,7 @@ def search(self): if org_public_id := self.request.params.get("org_public_id", "").strip(): try: organization = self.organization_service.get_by_public_id(org_public_id) + assert organization organization_ids = self.organization_service.get_hierarchy_ids( organization.id, include_parents=False ) diff --git a/lms/views/api/blackboard/files.py b/lms/views/api/blackboard/files.py index 2eca388927..e91b4698e8 100644 --- a/lms/views/api/blackboard/files.py +++ b/lms/views/api/blackboard/files.py @@ -76,7 +76,7 @@ def via_url(self): document_url = self.request.params["document_url"] file_id = course.get_mapped_file_id( - DOCUMENT_URL_REGEX.search(document_url)["file_id"] + DOCUMENT_URL_REGEX.search(document_url)["file_id"] # type: ignore ) try: if self.request.lti_user.is_instructor: