diff --git a/src/metakb/load_data.py b/src/metakb/load_data.py index 60327b79..0b4919d0 100644 --- a/src/metakb/load_data.py +++ b/src/metakb/load_data.py @@ -61,14 +61,14 @@ def _add_mappings_and_exts_to_obj(obj: dict, obj_keys: list[str]) -> None: obj_keys.append(f"{name}:${name}") -def _add_method(tx: ManagedTransaction, method: dict, ids_in_studies: set[str]) -> None: +def _add_method(tx: ManagedTransaction, method: dict, ids_in_stmts: set[str]) -> None: """Add Method node and its relationships to DB :param tx: Transaction object provided to transaction functions :param method: CDM method object - :param ids_in_studies: IDs found in studies + :param ids_in_stmts: IDs found in statements """ - if method["id"] not in ids_in_studies: + if method["id"] not in ids_in_stmts: return query = """ @@ -80,7 +80,7 @@ def _add_method(tx: ManagedTransaction, method: dict, ids_in_studies: set[str]) # Method's documents are unique and do not currently have IDs # They also only have one document document = is_reported_in[0] - _add_document(tx, document, ids_in_studies) + _add_document(tx, document, ids_in_stmts) doc_doi = document["doi"] query += f""" MERGE (d:Document {{ doi:'{doc_doi}' }}) @@ -91,16 +91,16 @@ def _add_method(tx: ManagedTransaction, method: dict, ids_in_studies: set[str]) def _add_gene_or_disease( - tx: ManagedTransaction, obj_in: dict, ids_in_studies: set[str] + tx: ManagedTransaction, obj_in: dict, ids_in_stmts: set[str] ) -> None: """Add gene or disease node and its relationships to DB :param tx: Transaction object provided to transaction functions :param obj_in: CDM gene or disease object - :param ids_in_studies: IDs found in studies + :param ids_in_stmts: IDs found in statements :raises TypeError: When `obj_in` is not a disease or gene """ - if obj_in["id"] not in ids_in_studies: + if obj_in["id"] not in ids_in_stmts: return obj = obj_in.copy() @@ -129,16 +129,16 @@ def _add_gene_or_disease( def _add_therapeutic_procedure( tx: ManagedTransaction, therapeutic_procedure: dict, - ids_in_studies: set[str], + ids_in_stmts: set[str], ) -> None: """Add therapeutic procedure node and its relationships :param tx: Transaction object provided to transaction functions :param therapeutic_procedure: Therapeutic procedure CDM object - :param ids_in_studies: IDs found in studies + :param ids_in_stmts: IDs found in statements :raises TypeError: When therapeutic procedure type is invalid """ - if therapeutic_procedure["id"] not in ids_in_studies: + if therapeutic_procedure["id"] not in ids_in_stmts: return tp = therapeutic_procedure.copy() @@ -264,15 +264,15 @@ def _add_variation(tx: ManagedTransaction, variation_in: dict) -> None: def _add_categorical_variant( tx: ManagedTransaction, categorical_variant_in: dict, - ids_in_studies: set[str], + ids_in_stmts: set[str], ) -> None: """Add categorical variant objects to DB. :param tx: Transaction object provided to transaction functions :param categorical_variant_in: Categorical variant CDM object - :param ids_in_studies: IDs found in studies + :param ids_in_stmts: IDs found in statements """ - if categorical_variant_in["id"] not in ids_in_studies: + if categorical_variant_in["id"] not in ids_in_stmts: return cv = categorical_variant_in.copy() @@ -311,19 +311,19 @@ def _add_categorical_variant( def _add_document( - tx: ManagedTransaction, document_in: dict, ids_in_studies: set[str] + tx: ManagedTransaction, document_in: dict, ids_in_stmts: set[str] ) -> None: """Add Document object to DB. :param tx: Transaction object provided to transaction functions :param document: Document CDM object - :param ids_in_studies: IDs found in studies + :param ids_in_stmts: IDs found in statements """ # Not all document's have IDs. These are the fields that can uniquely identify # a document if "id" in document_in: query = "MATCH (n:Document {id:$id}) RETURN n" - if document_in["id"] not in ids_in_studies: + if document_in["id"] not in ids_in_stmts: return elif "doi" in document_in: query = "MATCH (n:Document {doi:$doi}) RETURN n" @@ -351,81 +351,81 @@ def _add_document( tx.run(query, **document) -def _get_ids_from_studies(studies: list[dict]) -> set[str]: - """Get unique IDs from studies +def _get_ids_from_stmts(statements: list[dict]) -> set[str]: + """Get unique IDs from statements - :param studies: List of studies - :return: Set of IDs found in studies + :param statements: List of statements + :return: Set of IDs found in statements """ def _add_obj_id_to_set(obj: dict, ids_set: set[str]) -> None: """Add object id to set of IDs :param obj: Object to get ID for - :param ids_set: IDs found in studies. This will be mutated. + :param ids_set: IDs found in statements. This will be mutated. """ obj_id = obj.get("id") if obj_id: ids_set.add(obj_id) - ids_in_studies = set() + ids_in_stmts = set() - for study in studies: + for statement in statements: for obj in [ - study.get("specifiedBy"), # method - study.get("reportedIn"), - study.get("subjectVariant"), - study.get("objectTherapeutic"), - study.get("conditionQualifier"), - study.get("geneContextQualifier"), + statement.get("specifiedBy"), # method + statement.get("reportedIn"), + statement.get("subjectVariant"), + statement.get("objectTherapeutic"), + statement.get("conditionQualifier"), + statement.get("geneContextQualifier"), ]: if obj: if isinstance(obj, list): for item in obj: - _add_obj_id_to_set(item, ids_in_studies) + _add_obj_id_to_set(item, ids_in_stmts) else: # This is a dictionary - _add_obj_id_to_set(obj, ids_in_studies) + _add_obj_id_to_set(obj, ids_in_stmts) - return ids_in_studies + return ids_in_stmts -def _add_study(tx: ManagedTransaction, study_in: dict) -> None: - """Add study node and its relationships +def _add_statement(tx: ManagedTransaction, statement_in: dict) -> None: + """Add statement node and its relationships :param tx: Transaction object provided to transaction functions - :param study_in: Statement CDM object + :param statement_in: Statement CDM object """ - study = study_in.copy() - study_type = study["type"] - study_keys = _create_parameterized_query( - study, ("id", "description", "direction", "predicate", "type") + statement = statement_in.copy() + statement_type = statement["type"] + statement_keys = _create_parameterized_query( + statement, ("id", "description", "direction", "predicate", "type") ) match_line = "" rel_line = "" - is_reported_in_docs = study.get("reportedIn", []) + is_reported_in_docs = statement.get("reportedIn", []) for ri_doc in is_reported_in_docs: ri_doc_id = ri_doc["id"] name = f"doc_{ri_doc_id.split(':')[-1]}" match_line += f"MERGE ({name} {{ id: '{ri_doc_id}'}})\n" rel_line += f"MERGE (s) -[:IS_REPORTED_IN] -> ({name})\n" - allele_origin = study.get("alleleOriginQualifier") + allele_origin = statement.get("alleleOriginQualifier") if allele_origin: - study["alleleOriginQualifier"] = allele_origin + statement["alleleOriginQualifier"] = allele_origin match_line += "SET s.alleleOriginQualifier=$alleleOriginQualifier\n" - gene_context_id = study.get("geneContextQualifier", {}).get("id") + gene_context_id = statement.get("geneContextQualifier", {}).get("id") if gene_context_id: match_line += f"MERGE (g:Gene {{id: '{gene_context_id}'}})\n" rel_line += "MERGE (s) -[:HAS_GENE_CONTEXT] -> (g)\n" - method_id = study["specifiedBy"]["id"] + method_id = statement["specifiedBy"]["id"] match_line += f"MERGE (m {{ id: '{method_id}' }})\n" rel_line += "MERGE (s) -[:IS_SPECIFIED_BY] -> (m)\n" - coding = study.get("strength") + coding = statement.get("strength") if coding: coding_key_fields = ("code", "label", "system") @@ -435,75 +435,76 @@ def _add_study(tx: ManagedTransaction, study_in: dict) -> None: for k in coding_key_fields: v = coding.get(k) if v: - study[f"coding_{k}"] = v + statement[f"coding_{k}"] = v match_line += f"MERGE (c:Coding {{ {coding_keys} }})\n" rel_line += "MERGE (s) -[:HAS_STRENGTH] -> (c)\n" - variant_id = study["subjectVariant"]["id"] + variant_id = statement["subjectVariant"]["id"] match_line += f"MERGE (v:Variation {{ id: '{variant_id}' }})\n" rel_line += "MERGE (s) -[:HAS_VARIANT] -> (v)\n" - therapeutic_id = study["objectTherapeutic"]["id"] + therapeutic_id = statement["objectTherapeutic"]["id"] match_line += f"MERGE (t:TherapeuticProcedure {{ id: '{therapeutic_id}' }})\n" rel_line += "MERGE (s) -[:HAS_THERAPEUTIC] -> (t)\n" - tumor_type_id = study["conditionQualifier"]["id"] + tumor_type_id = statement["conditionQualifier"]["id"] match_line += f"MERGE (tt:Condition {{ id: '{tumor_type_id}' }})\n" rel_line += "MERGE (s) -[:HAS_TUMOR_TYPE] -> (tt)\n" query = f""" - MERGE (s:{study_type}:StudyStatement:Statement {{ {study_keys} }}) + MERGE (s:{statement_type}:StudyStatement:Statement {{ {statement_keys} }}) {match_line} {rel_line} """ - tx.run(query, **study) + tx.run(query, **statement) def add_transformed_data(driver: Driver, data: dict) -> None: """Add set of data formatted per Common Data Model to DB. :param data: contains key/value pairs for data objects to add to DB, including - studies, variation, therapeutic procedures, conditions, genes, methods, + statements, variation, therapeutic procedures, conditions, genes, methods, documents, etc. """ - # Used to keep track of IDs that are in studies. This is used to prevent adding - # nodes that aren't associated to studies - ids_in_studies = _get_ids_from_studies(data.get("studies", [])) + # Used to keep track of IDs that are in statements. This is used to prevent adding + # nodes that aren't associated to statements + statements = data.get("statements", []) + ids_in_stmts = _get_ids_from_stmts(statements) with driver.session() as session: - loaded_study_count = 0 + loaded_stmt_count = 0 for cv in data.get("categorical_variants", []): - session.execute_write(_add_categorical_variant, cv, ids_in_studies) + session.execute_write(_add_categorical_variant, cv, ids_in_stmts) for doc in data.get("documents", []): - session.execute_write(_add_document, doc, ids_in_studies) + session.execute_write(_add_document, doc, ids_in_stmts) for method in data.get("methods", []): - session.execute_write(_add_method, method, ids_in_studies) + session.execute_write(_add_method, method, ids_in_stmts) for obj_type in {"genes", "conditions"}: for obj in data.get(obj_type, []): - session.execute_write(_add_gene_or_disease, obj, ids_in_studies) + session.execute_write(_add_gene_or_disease, obj, ids_in_stmts) for tp in data.get("therapeutic_procedures", []): - session.execute_write(_add_therapeutic_procedure, tp, ids_in_studies) + session.execute_write(_add_therapeutic_procedure, tp, ids_in_stmts) # This should always be done last - for study in data.get("studies", []): - session.execute_write(_add_study, study) - loaded_study_count += 1 + for statement in statements: + session.execute_write(_add_statement, statement) + loaded_stmt_count += 1 - _logger.info("Successfully loaded %s studies.", loaded_study_count) + _logger.info("Successfully loaded %s statements.", loaded_stmt_count) def load_from_json(src_transformed_cdm: Path, driver: Driver | None = None) -> None: """Load evidence into DB from given CDM JSON file. :param src_transformed_cdm: path to file for a source's transformed data to - common data model containing studies, variation, therapeutic procedures, + common data model containing statements, variation, therapeutic procedures, conditions, genes, methods, documents, etc. :param driver: Neo4j graph driver, if available """ diff --git a/src/metakb/main.py b/src/metakb/main.py index f515bcc5..eda67880 100644 --- a/src/metakb/main.py +++ b/src/metakb/main.py @@ -11,10 +11,10 @@ from metakb.log_handle import configure_logs from metakb.query import PaginationParamError, QueryHandler from metakb.schemas.api import ( - BatchSearchStudiesQuery, - BatchSearchStudiesService, - SearchStudiesQuery, - SearchStudiesService, + BatchSearchStatementsQuery, + BatchSearchStatementsService, + SearchStatementsQuery, + SearchStatementsService, ServiceMeta, ) @@ -64,65 +64,65 @@ def custom_openapi() -> dict: app.openapi = custom_openapi -search_studies_summary = ( - "Get nested studies from queried concepts that match all conditions provided." +search_stmts_summary = ( + "Get nested statements from queried concepts that match all conditions provided." ) -search_studies_descr = ( - "Return nested studies that match the intersection of queried concepts. For " - "example, if `variation` and `therapy` are provided, will return all studies that " - "have both the provided `variation` and `therapy`." +search_stmts_descr = ( + "Return nested statements that match the intersection of queried concepts. For " + "example, if `variation` and `therapy` are provided, will return all statements " + "that have both the provided `variation` and `therapy`." ) v_description = "Variation (subject) to search. Can be free text or VRS Variation ID." d_description = "Disease (object qualifier) to search" t_description = "Therapy (object) to search" g_description = "Gene to search" -s_description = "Study ID to search." +s_description = "Statement ID to search." start_description = "The index of the first result to return. Use for pagination." limit_description = "The maximum number of results to return. Use for pagination." @app.get( - "/api/v2/search/studies", - summary=search_studies_summary, - response_model=SearchStudiesService, + "/api/v2/search/statements", + summary=search_stmts_summary, + response_model=SearchStatementsService, response_model_exclude_none=True, - description=search_studies_descr, + description=search_stmts_descr, ) -async def get_studies( +async def get_statements( variation: Annotated[str | None, Query(description=v_description)] = None, disease: Annotated[str | None, Query(description=d_description)] = None, therapy: Annotated[str | None, Query(description=t_description)] = None, gene: Annotated[str | None, Query(description=g_description)] = None, - study_id: Annotated[str | None, Query(description=s_description)] = None, + statement_id: Annotated[str | None, Query(description=s_description)] = None, start: Annotated[int, Query(description=start_description)] = 0, limit: Annotated[int | None, Query(description=limit_description)] = None, -) -> SearchStudiesService: - """Get nested studies from queried concepts that match all conditions provided. - For example, if `variation` and `therapy` are provided, will return all studies +) -> SearchStatementsService: + """Get nested statements from queried concepts that match all conditions provided. + For example, if `variation` and `therapy` are provided, will return all statements that have both the provided `variation` and `therapy`. :param variation: Variation query (Free text or VRS Variation ID) :param disease: Disease query :param therapy: Therapy query :param gene: Gene query - :param study_id: Study ID query. + :param statement_id: Statement ID query. :param start: The index of the first result to return. Use for pagination. :param limit: The maximum number of results to return. Use for pagination. - :return: SearchStudiesService response containing nested studies and service + :return: SearchStatementsService response containing nested statements and service metadata """ try: - resp = await query.search_studies( - variation, disease, therapy, gene, study_id, start, limit + resp = await query.search_statements( + variation, disease, therapy, gene, statement_id, start, limit ) except PaginationParamError: - resp = SearchStudiesService( - query=SearchStudiesQuery( + resp = SearchStatementsService( + query=SearchStatementsQuery( variation=variation, disease=disease, therapy=therapy, gene=gene, - study_id=study_id, + statement_id=statement_id, ), service_meta_=ServiceMeta(), warnings=["`start` and `limit` params must both be nonnegative"], @@ -131,8 +131,8 @@ async def get_studies( _batch_descr = { - "summary": "Get nested studies for all provided variations.", - "description": "Return nested studies associated with any of the provided variations.", + "summary": "Get nested statements for all provided variations.", + "description": "Return nested statements associated with any of the provided variations.", "arg_variations": "Variations (subject) to search. Can be free text or VRS variation ID.", "arg_start": "The index of the first result to return. Use for pagination.", "arg_limit": "The maximum number of results to return. Use for pagination.", @@ -140,21 +140,21 @@ async def get_studies( @app.get( - "/api/v2/batch_search/studies", + "/api/v2/batch_search/statements", summary=_batch_descr["summary"], - response_model=BatchSearchStudiesService, + response_model=BatchSearchStatementsService, response_model_exclude_none=True, description=_batch_descr["description"], ) -async def batch_get_studies( +async def batch_get_statements( variations: Annotated[ list[str] | None, Query(description=_batch_descr["arg_variations"]), ] = None, start: Annotated[int, Query(description=_batch_descr["arg_start"])] = 0, limit: Annotated[int | None, Query(description=_batch_descr["arg_limit"])] = None, -) -> BatchSearchStudiesService: - """Fetch all studies associated with `any` of the provided variations. +) -> BatchSearchStatementsService: + """Fetch all statements associated with `any` of the provided variations. :param variations: variations to match against :param start: The index of the first result to return. Use for pagination. @@ -162,10 +162,10 @@ async def batch_get_studies( :return: batch response object """ try: - response = await query.batch_search_studies(variations, start, limit) + response = await query.batch_search_statements(variations, start, limit) except PaginationParamError: - response = BatchSearchStudiesService( - query=BatchSearchStudiesQuery(variations=[]), + response = BatchSearchStatementsService( + query=BatchSearchStatementsQuery(variations=[]), service_meta_=ServiceMeta(), warnings=["`start` and `limit` params must both be nonnegative"], ) diff --git a/src/metakb/query.py b/src/metakb/query.py index cb66a2a6..7aab4df1 100644 --- a/src/metakb/query.py +++ b/src/metakb/query.py @@ -30,10 +30,10 @@ ViccNormalizers, ) from metakb.schemas.api import ( - BatchSearchStudiesQuery, - BatchSearchStudiesService, + BatchSearchStatementsQuery, + BatchSearchStatementsService, NormalizedQuery, - SearchStudiesService, + SearchStatementsService, ServiceMeta, ) from metakb.schemas.app import SourceName @@ -109,21 +109,21 @@ def __init__( ... ViccNormalizers("http://localhost:8000") ... ) - ``default_page_limit`` sets the default max number of studies to include in + ``default_page_limit`` sets the default max number of statements to include in query responses: >>> limited_qh = QueryHandler(default_page_limit=10) - >>> response = await limited_qh.batch_search_studies(["BRAF V600E"]) - >>> print(len(response.study_ids)) + >>> response = await limited_qh.batch_search_statements(["BRAF V600E"]) + >>> print(len(response.statement_ids)) 10 This value is overruled by an explicit ``limit`` parameter: - >>> response = await limited_qh.batch_search_studies( + >>> response = await limited_qh.batch_search_statements( ... ["BRAF V600E"], ... limit=2 ... ) - >>> print(len(response.study_ids)) + >>> print(len(response.statement_ids)) 2 :param driver: driver instance for graph connection @@ -139,26 +139,26 @@ def __init__( self.vicc_normalizers = normalizers self._default_page_limit = default_page_limit - async def search_studies( + async def search_statements( self, variation: str | None = None, disease: str | None = None, therapy: str | None = None, gene: str | None = None, - study_id: str | None = None, + statement_id: str | None = None, start: int = 0, limit: int | None = None, - ) -> SearchStudiesService: - """Get nested studies from queried concepts that match all conditions provided. - For example, if ``variation`` and ``therapy`` are provided, will return all studies - that have both the provided ``variation`` and ``therapy``. + ) -> SearchStatementsService: + """Get nested statements from queried concepts that match all conditions provided. + For example, if ``variation`` and ``therapy`` are provided, will return all + statements that have both the provided ``variation`` and ``therapy``. >>> from metakb.query import QueryHandler >>> qh = QueryHandler() - >>> result = qh.search_studies("BRAF V600E") - >>> result.study_ids[:3] + >>> result = qh.search_statements("BRAF V600E") + >>> result.statement_ids[:3] ['moa.assertion:944', 'moa.assertion:911', 'moa.assertion:865'] - >>> result.studies[0].reportedIn[0].urls[0] + >>> result.statements[0].reportedIn[0].urls[0] 'https://www.accessdata.fda.gov/drugsatfda_docs/label/2020/202429s019lbl.pdf' Variation, disease, therapy, and gene terms are resolved via their respective @@ -174,11 +174,12 @@ async def search_studies( ``"GLEEVEC"``, or concept URI, e.g. ``"chembl:CHEMBL941"``. Case-insensitive. :param gene: Gene query. Common shorthand name, e.g. ``"NTRK1"``, or compact URI, e.g. ``"ensembl:ENSG00000198400"``. - :param study_id: Study ID query provided by source, e.g. ``"civic.eid:3017"``. + :param statement_id: Statement ID query provided by source, e.g. ``"civic.eid:3017"``. :param start: Index of first result to fetch. Must be nonnegative. :param limit: Max number of results to fetch. Must be nonnegative. Revert to default defined at class initialization if not given. - :return: Service response object containing nested studies and service metadata. + :return: Service response object containing nested statements and service + metadata. """ if start < 0: msg = "Can't start from an index of less than 0." @@ -193,35 +194,35 @@ async def search_studies( "disease": None, "therapy": None, "gene": None, - "study_id": None, + "statement_id": None, }, "warnings": [], - "study_ids": [], - "studies": [], + "statement_ids": [], + "statements": [], "service_meta_": ServiceMeta(), } normalized_terms = await self._get_normalized_terms( - variation, disease, therapy, gene, study_id, response + variation, disease, therapy, gene, statement_id, response ) if normalized_terms is None: - return SearchStudiesService(**response) + return SearchStatementsService(**response) ( normalized_variation, normalized_disease, normalized_therapy, normalized_gene, - study, - valid_study_id, + statement, + valid_statement_id, ) = normalized_terms - if valid_study_id: - study_nodes = [study] - response["study_ids"].append(study["id"]) + if valid_statement_id: + statement_nodes = [statement] + response["statement_ids"].append(statement["id"]) else: - study_nodes = self._get_studies( + statement_nodes = self._get_statements( normalized_variation=normalized_variation, normalized_therapy=normalized_therapy, normalized_disease=normalized_disease, @@ -229,16 +230,16 @@ async def search_studies( start=start, limit=limit, ) - response["study_ids"] = [s["id"] for s in study_nodes] + response["statement_ids"] = [s["id"] for s in statement_nodes] - response["studies"] = self._get_nested_studies(study_nodes) + response["statements"] = self._get_nested_stmts(statement_nodes) - if not response["studies"]: + if not response["statements"]: response["warnings"].append( - "No studies found with the provided query parameters." + "No statements found with the provided query parameters." ) - return SearchStudiesService(**response) + return SearchStatementsService(**response) async def _get_normalized_terms( self, @@ -246,7 +247,7 @@ async def _get_normalized_terms( disease: str | None, therapy: str | None, gene: str | None, - study_id: str | None, + statement_id: str | None, response: dict, ) -> tuple | None: """Find normalized terms for queried concepts. @@ -255,11 +256,11 @@ async def _get_normalized_terms( :param disease: Disease (object_qualifier) query :param therapy: Therapy (object) query :param gene: Gene query - :param study_id: Study ID query + :param statement_id: Statement ID query :param response: The response for the query :return: A tuple containing the normalized concepts """ - if not any((variation, disease, therapy, gene, study_id)): + if not any((variation, disease, therapy, gene, statement_id)): response["warnings"].append("No query parameters were provided.") return None @@ -291,16 +292,18 @@ async def _get_normalized_terms( else: normalized_gene = None - # Check that queried study_id is valid - valid_study_id = None - study = None - if study_id: - response["query"]["study_id"] = study_id - study = self._get_study_by_id(study_id) - if study: - valid_study_id = study.get("id") + # Check that queried statement_id is valid + valid_statement_id = None + statement = None + if statement_id: + response["query"]["statement_id"] = statement_id + statement = self._get_stmt_by_id(statement_id) + if statement: + valid_statement_id = statement.get("id") else: - response["warnings"].append(f"Study: {study_id} does not exist.") + response["warnings"].append( + f"Statement: {statement_id} does not exist." + ) # If queried concept is given check that it is normalized / valid if ( @@ -308,7 +311,7 @@ async def _get_normalized_terms( or (therapy and not normalized_therapy) or (disease and not normalized_disease) or (gene and not normalized_gene) - or (study_id and not valid_study_id) + or (statement_id and not valid_statement_id) ): return None @@ -317,8 +320,8 @@ async def _get_normalized_terms( normalized_disease, normalized_therapy, normalized_gene, - study, - valid_study_id, + statement, + valid_statement_id, ) def _get_normalized_therapy(self, therapy: str, warnings: list[str]) -> str | None: @@ -381,23 +384,23 @@ def _get_normalized_gene(self, gene: str, warnings: list[str]) -> str | None: warnings.append(f"Gene Normalizer unable to normalize: {gene}") return normalized_gene_id - def _get_study_by_id(self, study_id: str) -> Node | None: - """Get a Study node by ID. + def _get_stmt_by_id(self, statement_id: str) -> Node | None: + """Get a Statement node by ID. - :param study_id: Study ID to retrieve - :return: Study node if successful + :param statement_id: Statement ID to retrieve + :return: Statement node if successful """ query = """ MATCH (s:Statement) - WHERE toLower(s.id) = toLower($study_id) + WHERE toLower(s.id) = toLower($statement_id) RETURN s """ - records = self.driver.execute_query(query, study_id=study_id).records + records = self.driver.execute_query(query, statement_id=statement_id).records if not records: return None return records[0]["s"] - def _get_studies( + def _get_statements( self, start: int, limit: int | None, @@ -406,7 +409,7 @@ def _get_studies( normalized_disease: str | None = None, normalized_gene: str | None = None, ) -> list[Node]: - """Get studies that match the intersection of provided concepts. + """Get statements that match the intersection of provided concepts. :param start: Index of first result to fetch. Calling context should've already checked that it's nonnegative. @@ -416,7 +419,8 @@ def _get_studies( :param normalized_therapy: normalized therapy concept ID :param normalized_disease: normalized disease concept ID :param normalized_gene: normalized gene concept ID - :return: List of Study nodes that match the intersection of the given parameters + :return: List of Statement nodes that match the intersection of the given + parameters """ query = "MATCH (s:Statement)" params: dict[str, str | int] = {} @@ -464,36 +468,36 @@ def _get_studies( return [s[0] for s in self.driver.execute_query(query, params).records] - def _get_nested_studies(self, study_nodes: list[Node]) -> list[dict]: - """Get a list of nested studies. + def _get_nested_stmts(self, statement_nodes: list[Node]) -> list[dict]: + """Get a list of nested statements. - :param study_nodes: A list of Study Nodes - :return: A list of nested studies + :param statement_nodes: A list of Statement Nodes + :return: A list of nested statements """ - nested_studies = [] - added_studies = set() - for s in study_nodes: + nested_stmts = [] + added_stmts = set() + for s in statement_nodes: s_id = s.get("id") - if s_id not in added_studies: + if s_id not in added_stmts: try: - nested_study = self._get_nested_study(s) + nested_stmt = self._get_nested_stmt(s) except ValidationError as e: logger.error("%s: %s", s_id, e) else: - if nested_study: - nested_studies.append(nested_study) - added_studies.add(s_id) + if nested_stmt: + nested_stmts.append(nested_stmt) + added_stmts.add(s_id) - return nested_studies + return nested_stmts - def _get_nested_study(self, study_node: Node) -> dict: - """Get information related to a study + def _get_nested_stmt(self, stmt_node: Node) -> dict: + """Get information related to a statement Only VariantTherapeuticResponseStudyStatement are supported at the moment - :param study_node: Neo4j graph node for study - :return: Nested study + :param stmt_node: Neo4j graph node for statement + :return: Nested statement """ - if study_node["type"] != "VariantTherapeuticResponseStudyStatement": + if stmt_node["type"] != "VariantTherapeuticResponseStudyStatement": return {} params = { @@ -503,16 +507,18 @@ def _get_nested_study(self, study_node: Node) -> dict: "reportedIn": [], "specifiedBy": None, } - params.update(study_node) - study_id = study_node["id"] + params.update(stmt_node) + statement_id = stmt_node["id"] - # Get relationship and nodes for a study + # Get relationship and nodes for a statement query = """ - MATCH (s:Statement { id: $study_id }) + MATCH (s:Statement { id: $statement_id }) OPTIONAL MATCH (s)-[r]-(n) RETURN type(r) as r_type, n; """ - nodes_and_rels = self.driver.execute_query(query, study_id=study_id).records + nodes_and_rels = self.driver.execute_query( + query, statement_id=statement_id + ).records for item in nodes_and_rels: data = item.data() @@ -525,11 +531,9 @@ def _get_nested_study(self, study_node: Node) -> dict: params["subjectVariant"] = self._get_cat_var(node) elif rel_type == "HAS_GENE_CONTEXT": params["geneContextQualifier"] = self._get_gene_context_qualifier( - study_id - ) - params["alleleOriginQualifier"] = study_node.get( - "alleleOriginQualifier" + statement_id ) + params["alleleOriginQualifier"] = stmt_node.get("alleleOriginQualifier") elif rel_type == "IS_SPECIFIED_BY": node["reportedIn"] = [self._get_method_document(node["id"])] params["specifiedBy"] = Method(**node) @@ -660,28 +664,28 @@ def _get_cat_var(self, node: dict) -> CategoricalVariant: ) return CategoricalVariant(**node) - def _get_gene_context_qualifier(self, study_id: str) -> Gene | None: - """Get gene context qualifier data for a study + def _get_gene_context_qualifier(self, statement_id: str) -> Gene | None: + """Get gene context qualifier data for a statement - :param study_id: ID of study node + :param statement_id: ID of statement node :return Gene context qualifier data """ query = """ - MATCH (s:Statement { id: $study_id }) -[:HAS_GENE_CONTEXT] -> (g:Gene) + MATCH (s:Statement { id: $statement_id }) -[:HAS_GENE_CONTEXT] -> (g:Gene) RETURN g """ - results = self.driver.execute_query(query, study_id=study_id) + results = self.driver.execute_query(query, statement_id=statement_id) if not results.records: logger.error( - "Unable to complete oncogenicity study qualifier lookup for study_id %s", - study_id, + "Unable to complete gene context qualifier lookup for statement_id %s", + statement_id, ) return None if len(results.records) > 1: - # TODO should this be an error? can studies have multiple gene contexts? + # TODO should this be an error? can statements have multiple gene contexts? logger.error( - "Encountered multiple matches for oncogenicity study qualifier lookup for study_id %s", - study_id, + "Encountered multiple matches for gene context qualifier lookup for statement_id %s", + statement_id, ) return None @@ -813,13 +817,13 @@ def _get_therapeutic_agent(self, in_ta_params: dict) -> TherapeuticAgent: ta_params["extensions"] = extensions return TherapeuticAgent(**ta_params) - async def batch_search_studies( + async def batch_search_statements( self, variations: list[str] | None = None, start: int = 0, limit: int | None = None, - ) -> BatchSearchStudiesService: - """Fetch all studies associated with any of the provided variation description + ) -> BatchSearchStatementsService: + """Fetch all statements associated with any of the provided variation description strings. Because this method could be expanded to include other kinds of search terms, @@ -827,16 +831,16 @@ async def batch_search_studies( >>> from metakb.query import QueryHandler >>> qh = QueryHandler() - >>> response = await qh.batch_search_studies(["EGFR L858R"]) - >>> response.study_ids[:3] + >>> response = await qh.batch_search_statements(["EGFR L858R"]) + >>> response.statement_ids[:3] ['civic.eid:229', 'civic.eid:3811', 'moa.assertion:268'] All terms are normalized, so redundant terms don't alter search results: - >>> redundant_response = await qh.batch_search_studies( + >>> redundant_response = await qh.batch_search_statements( ... ["EGFR L858R", "NP_005219.2:p.Leu858Arg"] ... ) - >>> len(response.study_ids) == len(redundant_response.study_ids) + >>> len(response.statement_ids) == len(redundant_response.statement_ids) True :param variations: a list of variation description strings, e.g. @@ -844,7 +848,7 @@ async def batch_search_studies( :param start: Index of first result to fetch. Must be nonnegative. :param limit: Max number of results to fetch. Must be nonnegative. Revert to default defined at class initialization if not given. - :return: response object including all matching studies + :return: response object including all matching statements :raise ValueError: if ``start`` or ``limit`` are nonnegative """ if start < 0: @@ -854,8 +858,8 @@ async def batch_search_studies( msg = "Can't limit results to less than a negative number." raise ValueError(msg) - response = BatchSearchStudiesService( - query=BatchSearchStudiesQuery(variations=[]), + response = BatchSearchStatementsService( + query=BatchSearchStatementsQuery(variations=[]), service_meta_=ServiceMeta(), warnings=[], ) @@ -897,10 +901,10 @@ async def batch_search_studies( """ with self.driver.session() as session: result = session.run(query, v_ids=variation_ids, skip=start, limit=limit) - study_nodes = [r[0] for r in result] - response.study_ids = [n["id"] for n in study_nodes] - studies = self._get_nested_studies(study_nodes) - response.studies = [ - VariantTherapeuticResponseStudyStatement(**s) for s in studies + statement_nodes = [r[0] for r in result] + response.statement_ids = [n["id"] for n in statement_nodes] + stmts = self._get_nested_stmts(statement_nodes) + response.statements = [ + VariantTherapeuticResponseStudyStatement(**s) for s in stmts ] return response diff --git a/src/metakb/schemas/api.py b/src/metakb/schemas/api.py index 383f21cc..9b1fc668 100644 --- a/src/metakb/schemas/api.py +++ b/src/metakb/schemas/api.py @@ -30,23 +30,23 @@ class ServiceMeta(BaseModel): ) -class SearchStudiesQuery(BaseModel): - """Queries for the Search Studies Endpoint.""" +class SearchStatementsQuery(BaseModel): + """Queries for the Search Statements Endpoint.""" variation: StrictStr | None = None disease: StrictStr | None = None therapy: StrictStr | None = None gene: StrictStr | None = None - study_id: StrictStr | None = None + statement_id: StrictStr | None = None -class SearchStudiesService(BaseModel): - """Define model for Search Studies Endpoint Response.""" +class SearchStatementsService(BaseModel): + """Define model for Search Statements Endpoint Response.""" - query: SearchStudiesQuery + query: SearchStatementsQuery warnings: list[StrictStr] = [] - study_ids: list[StrictStr] = [] - studies: list[VariantTherapeuticResponseStudyStatement] = [] + statement_ids: list[StrictStr] = [] + statements: list[VariantTherapeuticResponseStudyStatement] = [] service_meta_: ServiceMeta @@ -57,17 +57,17 @@ class NormalizedQuery(BaseModel): normalized_id: StrictStr | None = None -class BatchSearchStudiesQuery(BaseModel): - """Define query as reported in batch search studies endpoint.""" +class BatchSearchStatementsQuery(BaseModel): + """Define query as reported in batch search statements endpoint.""" variations: list[NormalizedQuery] = [] -class BatchSearchStudiesService(BaseModel): - """Define response model for batch search studies endpoint response.""" +class BatchSearchStatementsService(BaseModel): + """Define response model for batch search statements endpoint response.""" - query: BatchSearchStudiesQuery + query: BatchSearchStatementsQuery warnings: list[StrictStr] = [] - study_ids: list[StrictStr] = [] - studies: list[VariantTherapeuticResponseStudyStatement] = [] + statement_ids: list[StrictStr] = [] + statements: list[VariantTherapeuticResponseStudyStatement] = [] service_meta_: ServiceMeta diff --git a/src/metakb/transformers/base.py b/src/metakb/transformers/base.py index a583dc37..39bb4c0f 100644 --- a/src/metakb/transformers/base.py +++ b/src/metakb/transformers/base.py @@ -109,7 +109,7 @@ class ViccConceptVocab(BaseModel): class TransformedData(BaseModel): """Define model for transformed data""" - studies: list[VariantTherapeuticResponseStudyStatement] = [] + statements: list[VariantTherapeuticResponseStudyStatement] = [] categorical_variants: list[CategoricalVariant] = [] variations: list[Allele] = [] genes: list[Gene] = [] diff --git a/src/metakb/transformers/civic.py b/src/metakb/transformers/civic.py index d7d25360..604638de 100644 --- a/src/metakb/transformers/civic.py +++ b/src/metakb/transformers/civic.py @@ -211,16 +211,14 @@ async def transform(self, harvested_data: CivicHarvestedData) -> None: ] self._add_categorical_variants(mps, mp_id_to_v_id_mapping) - # Add variant therapeutic response study data. Will update `studies` - self._add_variant_therapeutic_response_studies( - evidence_items, mp_id_to_v_id_mapping - ) + # Add variant therapeutic response study statement data. Will update `statements` + self._add_variant_tr_study_stmts(evidence_items, mp_id_to_v_id_mapping) - def _add_variant_therapeutic_response_studies( + def _add_variant_tr_study_stmts( self, records: list[dict], mp_id_to_v_id_mapping: dict ) -> None: - """Create Variant Therapeutic Response Studies from CIViC Evidence Items. - Will add associated values to ``processed_data`` instance variable + """Create Variant Therapeutic Response Study Statements from CIViC Evidence + Items. Will add associated values to ``processed_data`` instance variable (``therapeutic_procedures``, ``conditions``, and ``documents``). ``able_to_normalize`` and ``unable_to_normalize`` will also be mutated for associated therapeutic_procedures and conditions. @@ -333,7 +331,7 @@ def _add_variant_therapeutic_response_studies( specifiedBy=self.processed_data.methods[0], reportedIn=[document], ) - self.processed_data.studies.append(statement) + self.processed_data.statements.append(statement) def _get_evidence_direction(self, direction: str) -> Direction | None: """Get the normalized evidence direction diff --git a/src/metakb/transformers/moa.py b/src/metakb/transformers/moa.py index ee64b538..ce918a83 100644 --- a/src/metakb/transformers/moa.py +++ b/src/metakb/transformers/moa.py @@ -84,15 +84,13 @@ async def transform(self, harvested_data: MoaHarvestedData) -> None: await self._add_categorical_variants(harvested_data.variants) self._add_documents(harvested_data.sources) - # Add variant therapeutic response study data. Will update `studies` - await self._add_variant_therapeutic_response_studies(harvested_data.assertions) + # Add variant therapeutic response study statement data. Will update `statements` + await self._add_variant_tr_study_stmts(harvested_data.assertions) - async def _add_variant_therapeutic_response_studies( - self, assertions: list[dict] - ) -> None: - """Create Variant Therapeutic Response Studies from MOA assertions. + async def _add_variant_tr_study_stmts(self, assertions: list[dict]) -> None: + """Create Variant Therapeutic Response Study Statements from MOA assertions. Will add associated values to ``processed_data`` instance variable - (``therapeutic_procedures``, ``conditions``, and ``studies``). + (``therapeutic_procedures``, ``conditions``, and ``statements``). ``able_to_normalize`` and ``unable_to_normalize`` will also be mutated for associated therapeutic_procedures and conditions. @@ -213,7 +211,7 @@ async def _add_variant_therapeutic_response_studies( specifiedBy=self.processed_data.methods[0], reportedIn=[document], ) - self.processed_data.studies.append(statement) + self.processed_data.statements.append(statement) async def _add_categorical_variants(self, variants: list[dict]) -> None: """Create Categorical Variant objects for all MOA variant records. diff --git a/tests/conftest.py b/tests/conftest.py index 7a45c111..c5dbee10 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -309,7 +309,7 @@ def civic_source592(): @pytest.fixture(scope="session") -def civic_eid2997_study( +def civic_eid2997_study_stmt( civic_mpid33, civic_tid146, civic_did8, @@ -317,7 +317,7 @@ def civic_eid2997_study( civic_method, civic_source592, ): - """Create CIVIC EID2997 Statement test fixture. Uses TherapeuticAgent.""" + """Create CIVIC EID2997 Study Statement test fixture. Uses TherapeuticAgent.""" return { "id": "civic.eid:2997", "type": "VariantTherapeuticResponseStudyStatement", @@ -876,8 +876,10 @@ def civic_did11(): @pytest.fixture(scope="session") -def civic_eid816_study(civic_mpid12, civic_tsg, civic_did11, civic_gid5, civic_method): - """Create CIVIC EID816 study test fixture. Uses TherapeuticSubstituteGroup.""" +def civic_eid816_study_stmt( + civic_mpid12, civic_tsg, civic_did11, civic_gid5, civic_method +): + """Create CIVIC EID816 study statement test fixture. Uses TherapeuticSubstituteGroup.""" return { "id": "civic.eid:816", "type": "VariantTherapeuticResponseStudyStatement", @@ -908,14 +910,14 @@ def civic_eid816_study(civic_mpid12, civic_tsg, civic_did11, civic_gid5, civic_m @pytest.fixture(scope="session") -def civic_eid9851_study( +def civic_eid9851_study_stmt( civic_mpid12, civic_ct, civic_did11, civic_gid5, civic_method, ): - """Create CIVIC EID9851 study test fixture. Uses CombinationTherapy.""" + """Create CIVIC EID9851 study statement test fixture. Uses CombinationTherapy.""" return { "id": "civic.eid:9851", "type": "VariantTherapeuticResponseStudyStatement", @@ -1726,7 +1728,7 @@ def pmid_27819322(): @pytest.fixture(scope="session") -def moa_aid66_study( +def moa_aid66_study_stmt( moa_vid66, moa_abl1, moa_imatinib, @@ -1734,7 +1736,7 @@ def moa_aid66_study( moa_method, moa_source45, ): - """Create a Variant Therapeutic Response Study test fixture for MOA Assertion 66.""" + """Create a Variant Therapeutic Response Study Statement test fixture for MOA Assertion 66.""" return { "id": "moa.assertion:66", "description": "T315I mutant ABL1 in p210 BCR-ABL cells resulted in retained high levels of phosphotyrosine at increasing concentrations of inhibitor STI-571, whereas wildtype appropriately received inhibition.", @@ -2111,9 +2113,9 @@ def _check(actual_data: list, test_data: list, is_cdm: bool = False) -> None: def check_transformed_cdm(assertion_checks): """Test fixture to compare CDM transformations.""" - def check_transformed_cdm(data, studies, transformed_file): + def check_transformed_cdm(data, statements, transformed_file): """Test that transform to CDM works correctly.""" - assertion_checks(data["studies"], studies, is_cdm=True) + assertion_checks(data["statements"], statements, is_cdm=True) transformed_file.unlink() return check_transformed_cdm diff --git a/tests/unit/database/test_database.py b/tests/unit/database/test_database.py index 85a9c0db..9ff2b282 100644 --- a/tests/unit/database/test_database.py +++ b/tests/unit/database/test_database.py @@ -89,8 +89,8 @@ def _check_function( @pytest.fixture(scope="module") -def check_study_relation(driver: Driver): - """Check that node is used in a study.""" +def check_statement_relation(driver: Driver): + """Check that node is used in a statement.""" def _check_function(value_label: str): query = f""" @@ -415,8 +415,8 @@ def test_therapeutic_procedure_rules( ): """Verify property and relationship rules for Therapeutic Procedure nodes.""" check_unique_property("TherapeuticProcedure", "id") - # min_rels is 0 because TherapeuticAgent may not be attached to study directly, but - # through CombinationTherapy and TherapeuticSubstituteGroup + # min_rels is 0 because TherapeuticAgent may not be attached to statement directly, + # but through CombinationTherapy and TherapeuticSubstituteGroup check_relation_count( "TherapeuticProcedure", "Statement", @@ -528,13 +528,13 @@ def test_condition_rules( check_node_props(disease, civic_did8, expected_keys, extension_names) -def test_study_rules( +def test_statement_rules( driver: Driver, check_unique_property, check_relation_count, check_node_labels, get_node_by_id, - civic_eid2997_study, + civic_eid2997_study_stmt, check_node_props, ): """Verify property and relationship rules for Statement nodes.""" @@ -563,7 +563,7 @@ def test_study_rules( record = s.run(cite_query).single() assert record.values()[0] == 0 - study = get_node_by_id(civic_eid2997_study["id"]) + statement = get_node_by_id(civic_eid2997_study_stmt["id"]) expected_keys = { "id", "description", @@ -572,11 +572,11 @@ def test_study_rules( "alleleOriginQualifier", "type", } - civic_eid2997_study_cp = civic_eid2997_study.copy() - civic_eid2997_study_cp["alleleOriginQualifier"] = civic_eid2997_study_cp[ + civic_eid2997_ss_cp = civic_eid2997_study_stmt.copy() + civic_eid2997_ss_cp["alleleOriginQualifier"] = civic_eid2997_ss_cp[ "alleleOriginQualifier" ] - check_node_props(study, civic_eid2997_study_cp, expected_keys) + check_node_props(statement, civic_eid2997_ss_cp, expected_keys) def test_document_rules( diff --git a/tests/unit/search/test_batch_search_statements.py b/tests/unit/search/test_batch_search_statements.py new file mode 100644 index 00000000..512ac320 --- /dev/null +++ b/tests/unit/search/test_batch_search_statements.py @@ -0,0 +1,137 @@ +"""Test batch search function.""" + +import pytest + +from metakb.query import QueryHandler +from metakb.schemas.api import NormalizedQuery + +from .utils import assert_no_match, find_and_check_stmt + + +@pytest.mark.asyncio(scope="module") +async def test_batch_search( + query_handler: QueryHandler, + assertion_checks, + civic_eid2997_study_stmt, + civic_eid816_study_stmt, +): + """Test batch search statements method.""" + resp = await query_handler.batch_search_statements([]) + assert resp.statements == resp.statement_ids == [] + assert resp.warnings == [] + + assert_no_match(await query_handler.batch_search_statements(["gibberish variant"])) + + braf_va_id = "ga4gh:VA.Otc5ovrw906Ack087o1fhegB4jDRqCAe" + braf_response = await query_handler.batch_search_statements([braf_va_id]) + assert braf_response.query.variations == [ + NormalizedQuery( + term=braf_va_id, + normalized_id=braf_va_id, + ) + ] + find_and_check_stmt(braf_response, civic_eid816_study_stmt, assertion_checks) + + redundant_braf_response = await query_handler.batch_search_statements( + [braf_va_id, "NC_000007.13:g.140453136A>T"] + ) + assert len(redundant_braf_response.query.variations) == 2 + assert ( + NormalizedQuery( + term=braf_va_id, + normalized_id=braf_va_id, + ) + in redundant_braf_response.query.variations + ) + assert ( + NormalizedQuery( + term="NC_000007.13:g.140453136A>T", + normalized_id=braf_va_id, + ) + in redundant_braf_response.query.variations + ) + + find_and_check_stmt( + redundant_braf_response, civic_eid816_study_stmt, assertion_checks + ) + assert len(braf_response.statement_ids) == len( + redundant_braf_response.statement_ids + ) + + braf_egfr_response = await query_handler.batch_search_statements( + [braf_va_id, "EGFR L858R"] + ) + find_and_check_stmt(braf_egfr_response, civic_eid816_study_stmt, assertion_checks) + find_and_check_stmt(braf_egfr_response, civic_eid2997_study_stmt, assertion_checks) + assert len(braf_egfr_response.statement_ids) > len(braf_response.statement_ids) + + +@pytest.mark.asyncio(scope="module") +async def test_paginate(query_handler: QueryHandler, normalizers): + """Test pagination parameters.""" + braf_va_id = "ga4gh:VA.Otc5ovrw906Ack087o1fhegB4jDRqCAe" + full_response = await query_handler.batch_search_statements([braf_va_id]) + paged_response = await query_handler.batch_search_statements([braf_va_id], start=1) + # should be almost the same, just off by 1 + assert len(paged_response.statement_ids) == len(full_response.statement_ids) - 1 + assert paged_response.statement_ids == full_response.statement_ids[1:] + + # check that page limit > response doesn't affect response + huge_page_response = await query_handler.batch_search_statements( + [braf_va_id], limit=1000 + ) + assert len(huge_page_response.statement_ids) == len(full_response.statement_ids) + assert huge_page_response.statement_ids == full_response.statement_ids + + # get last item + last_response = await query_handler.batch_search_statements( + [braf_va_id], start=len(full_response.statement_ids) - 1 + ) + assert len(last_response.statement_ids) == 1 + assert last_response.statement_ids[0] == full_response.statement_ids[-1] + + # test limit + min_response = await query_handler.batch_search_statements([braf_va_id], limit=1) + assert min_response.statement_ids[0] == full_response.statement_ids[0] + + # test limit and start + other_min_response = await query_handler.batch_search_statements( + [braf_va_id], start=1, limit=1 + ) + assert other_min_response.statement_ids[0] == full_response.statement_ids[1] + + # test limit of 0 + empty_response = await query_handler.batch_search_statements([braf_va_id], limit=0) + assert len(empty_response.statement_ids) == 0 + + # test raises exceptions + with pytest.raises(ValueError, match="Can't start from an index of less than 0."): + await query_handler.batch_search_statements([braf_va_id], start=-1) + with pytest.raises( + ValueError, match="Can't limit results to less than a negative number." + ): + await query_handler.batch_search_statements([braf_va_id], limit=-1) + + # test default limit + limited_query_handler = QueryHandler(normalizers=normalizers, default_page_limit=1) + default_limited_response = await limited_query_handler.batch_search_statements( + [braf_va_id] + ) + assert len(default_limited_response.statement_ids) == 1 + assert default_limited_response.statement_ids[0] == full_response.statement_ids[0] + + # test overrideable + less_limited_response = await limited_query_handler.batch_search_statements( + [braf_va_id], limit=2 + ) + assert len(less_limited_response.statement_ids) == 2 + assert less_limited_response.statement_ids == full_response.statement_ids[:2] + + # test default limit and skip + skipped_limited_response = await limited_query_handler.batch_search_statements( + [braf_va_id], start=1 + ) + assert len(skipped_limited_response.statement_ids) == 1 + assert skipped_limited_response.statement_ids[0] == full_response.statement_ids[1] + + limited_query_handler.driver.close() diff --git a/tests/unit/search/test_batch_search_studies.py b/tests/unit/search/test_batch_search_studies.py deleted file mode 100644 index 7af71357..00000000 --- a/tests/unit/search/test_batch_search_studies.py +++ /dev/null @@ -1,133 +0,0 @@ -"""Test batch search function.""" - -import pytest - -from metakb.query import QueryHandler -from metakb.schemas.api import NormalizedQuery - -from .utils import assert_no_match, find_and_check_study - - -@pytest.mark.asyncio(scope="module") -async def test_batch_search( - query_handler: QueryHandler, - assertion_checks, - civic_eid2997_study, - civic_eid816_study, -): - """Test batch search studies method.""" - resp = await query_handler.batch_search_studies([]) - assert resp.studies == resp.study_ids == [] - assert resp.warnings == [] - - assert_no_match(await query_handler.batch_search_studies(["gibberish variant"])) - - braf_va_id = "ga4gh:VA.Otc5ovrw906Ack087o1fhegB4jDRqCAe" - braf_response = await query_handler.batch_search_studies([braf_va_id]) - assert braf_response.query.variations == [ - NormalizedQuery( - term=braf_va_id, - normalized_id=braf_va_id, - ) - ] - find_and_check_study(braf_response, civic_eid816_study, assertion_checks) - - redundant_braf_response = await query_handler.batch_search_studies( - [braf_va_id, "NC_000007.13:g.140453136A>T"] - ) - assert len(redundant_braf_response.query.variations) == 2 - assert ( - NormalizedQuery( - term=braf_va_id, - normalized_id=braf_va_id, - ) - in redundant_braf_response.query.variations - ) - assert ( - NormalizedQuery( - term="NC_000007.13:g.140453136A>T", - normalized_id=braf_va_id, - ) - in redundant_braf_response.query.variations - ) - - find_and_check_study(redundant_braf_response, civic_eid816_study, assertion_checks) - assert len(braf_response.study_ids) == len(redundant_braf_response.study_ids) - - braf_egfr_response = await query_handler.batch_search_studies( - [braf_va_id, "EGFR L858R"] - ) - find_and_check_study(braf_egfr_response, civic_eid816_study, assertion_checks) - find_and_check_study(braf_egfr_response, civic_eid2997_study, assertion_checks) - assert len(braf_egfr_response.study_ids) > len(braf_response.study_ids) - - -@pytest.mark.asyncio(scope="module") -async def test_paginate(query_handler: QueryHandler, normalizers): - """Test pagination parameters.""" - braf_va_id = "ga4gh:VA.Otc5ovrw906Ack087o1fhegB4jDRqCAe" - full_response = await query_handler.batch_search_studies([braf_va_id]) - paged_response = await query_handler.batch_search_studies([braf_va_id], start=1) - # should be almost the same, just off by 1 - assert len(paged_response.study_ids) == len(full_response.study_ids) - 1 - assert paged_response.study_ids == full_response.study_ids[1:] - - # check that page limit > response doesn't affect response - huge_page_response = await query_handler.batch_search_studies( - [braf_va_id], limit=1000 - ) - assert len(huge_page_response.study_ids) == len(full_response.study_ids) - assert huge_page_response.study_ids == full_response.study_ids - - # get last item - last_response = await query_handler.batch_search_studies( - [braf_va_id], start=len(full_response.study_ids) - 1 - ) - assert len(last_response.study_ids) == 1 - assert last_response.study_ids[0] == full_response.study_ids[-1] - - # test limit - min_response = await query_handler.batch_search_studies([braf_va_id], limit=1) - assert min_response.study_ids[0] == full_response.study_ids[0] - - # test limit and start - other_min_response = await query_handler.batch_search_studies( - [braf_va_id], start=1, limit=1 - ) - assert other_min_response.study_ids[0] == full_response.study_ids[1] - - # test limit of 0 - empty_response = await query_handler.batch_search_studies([braf_va_id], limit=0) - assert len(empty_response.study_ids) == 0 - - # test raises exceptions - with pytest.raises(ValueError, match="Can't start from an index of less than 0."): - await query_handler.batch_search_studies([braf_va_id], start=-1) - with pytest.raises( - ValueError, match="Can't limit results to less than a negative number." - ): - await query_handler.batch_search_studies([braf_va_id], limit=-1) - - # test default limit - limited_query_handler = QueryHandler(normalizers=normalizers, default_page_limit=1) - default_limited_response = await limited_query_handler.batch_search_studies( - [braf_va_id] - ) - assert len(default_limited_response.study_ids) == 1 - assert default_limited_response.study_ids[0] == full_response.study_ids[0] - - # test overrideable - less_limited_response = await limited_query_handler.batch_search_studies( - [braf_va_id], limit=2 - ) - assert len(less_limited_response.study_ids) == 2 - assert less_limited_response.study_ids == full_response.study_ids[:2] - - # test default limit and skip - skipped_limited_response = await limited_query_handler.batch_search_studies( - [braf_va_id], start=1 - ) - assert len(skipped_limited_response.study_ids) == 1 - assert skipped_limited_response.study_ids[0] == full_response.study_ids[1] - - limited_query_handler.driver.close() diff --git a/tests/unit/search/test_search_statements.py b/tests/unit/search/test_search_statements.py new file mode 100644 index 00000000..5e8aaac5 --- /dev/null +++ b/tests/unit/search/test_search_statements.py @@ -0,0 +1,324 @@ +"""Test search statement methods""" + +import pytest +from ga4gh.core.entity_models import Extension + +from metakb.normalizers import VICC_NORMALIZER_DATA +from metakb.query import QueryHandler + +from .utils import assert_no_match, find_and_check_stmt + + +def _get_normalizer_id(extensions: list[Extension]) -> str | None: + """Get normalized ID from list of extensions + + :param extensions: List of extensions + :return: Normalized concept ID if found in extensions + """ + normalizer_id = None + for ext in extensions: + if ext.name == VICC_NORMALIZER_DATA: + normalizer_id = ext.value["id"] + break + return normalizer_id + + +def assert_general_search_stmts(response): + """Check that general search_statements queries return a valid response""" + len_stmt_id_matches = len(response.statement_ids) + assert len_stmt_id_matches > 0 + len_stmts = len(response.statements) + assert len_stmts > 0 + assert len_stmt_id_matches == len_stmts + + +@pytest.mark.asyncio(scope="module") +async def test_civic_eid2997(query_handler, civic_eid2997_study_stmt, assertion_checks): + """Test that search_statements method works correctly for CIViC EID2997""" + resp = await query_handler.search_statements( + statement_id=civic_eid2997_study_stmt["id"] + ) + assert resp.statement_ids == [civic_eid2997_study_stmt["id"]] + resp_stmts = [s.model_dump(exclude_none=True) for s in resp.statements] + assertion_checks(resp_stmts, [civic_eid2997_study_stmt]) + assert resp.warnings == [] + + resp = await query_handler.search_statements(variation="EGFR L858R") + find_and_check_stmt(resp, civic_eid2997_study_stmt, assertion_checks) + + resp = await query_handler.search_statements( + variation="ga4gh:VA.S41CcMJT2bcd8R4-qXZWH1PoHWNtG2PZ" + ) + find_and_check_stmt(resp, civic_eid2997_study_stmt, assertion_checks) + + # genomic query + resp = await query_handler.search_statements(variation="7-55259515-T-G") + find_and_check_stmt(resp, civic_eid2997_study_stmt, assertion_checks) + + resp = await query_handler.search_statements(therapy="ncit:C66940") + find_and_check_stmt(resp, civic_eid2997_study_stmt, assertion_checks) + + resp = await query_handler.search_statements(gene="EGFR") + find_and_check_stmt(resp, civic_eid2997_study_stmt, assertion_checks) + + resp = await query_handler.search_statements(disease="nsclc") + find_and_check_stmt(resp, civic_eid2997_study_stmt, assertion_checks) + + # We should not find CIViC EID2997 using these queries + resp = await query_handler.search_statements(statement_id="civic.eid:3017") + find_and_check_stmt(resp, civic_eid2997_study_stmt, assertion_checks, False) + + resp = await query_handler.search_statements(variation="BRAF V600E") + find_and_check_stmt(resp, civic_eid2997_study_stmt, assertion_checks, False) + + resp = await query_handler.search_statements(therapy="imatinib") + find_and_check_stmt(resp, civic_eid2997_study_stmt, assertion_checks, False) + + resp = await query_handler.search_statements(gene="BRAF") + find_and_check_stmt(resp, civic_eid2997_study_stmt, assertion_checks, False) + + resp = await query_handler.search_statements(disease="DOID:9253") + find_and_check_stmt(resp, civic_eid2997_study_stmt, assertion_checks, False) + + +@pytest.mark.asyncio(scope="module") +async def test_civic816(query_handler, civic_eid816_study_stmt, assertion_checks): + """Test that search_statements method works correctly for CIViC EID816""" + resp = await query_handler.search_statements( + statement_id=civic_eid816_study_stmt["id"] + ) + assert resp.statement_ids == [civic_eid816_study_stmt["id"]] + resp_stmts = [s.model_dump(exclude_none=True) for s in resp.statements] + assertion_checks(resp_stmts, [civic_eid816_study_stmt]) + assert resp.warnings == [] + + # Try querying based on therapies in substitutes + resp = await query_handler.search_statements(therapy="Cetuximab") + find_and_check_stmt(resp, civic_eid816_study_stmt, assertion_checks) + + resp = await query_handler.search_statements(therapy="Panitumumab") + find_and_check_stmt(resp, civic_eid816_study_stmt, assertion_checks) + + +@pytest.mark.asyncio(scope="module") +async def test_civic9851(query_handler, civic_eid9851_study_stmt, assertion_checks): + """Test that search_statements method works correctly for CIViC EID9851""" + resp = await query_handler.search_statements( + statement_id=civic_eid9851_study_stmt["id"] + ) + assert resp.statement_ids == [civic_eid9851_study_stmt["id"]] + resp_stmts = [s.model_dump(exclude_none=True) for s in resp.statements] + assertion_checks(resp_stmts, [civic_eid9851_study_stmt]) + assert resp.warnings == [] + + # Try querying based on therapies in components + resp = await query_handler.search_statements(therapy="Encorafenib") + find_and_check_stmt(resp, civic_eid9851_study_stmt, assertion_checks) + + resp = await query_handler.search_statements(therapy="Cetuximab") + find_and_check_stmt(resp, civic_eid9851_study_stmt, assertion_checks) + + +@pytest.mark.asyncio(scope="module") +async def test_moa_66(query_handler, moa_aid66_study_stmt, assertion_checks): + """Test that search_statements method works correctly for MOA Assertion 66""" + resp = await query_handler.search_statements( + statement_id=moa_aid66_study_stmt["id"] + ) + assert resp.statement_ids == [moa_aid66_study_stmt["id"]] + resp_stmts = [s.model_dump(exclude_none=True) for s in resp.statements] + assertion_checks(resp_stmts, [moa_aid66_study_stmt]) + assert resp.warnings == [] + + resp = await query_handler.search_statements(variation="ABL1 Thr315Ile") + find_and_check_stmt(resp, moa_aid66_study_stmt, assertion_checks) + + resp = await query_handler.search_statements( + variation="ga4gh:VA.D6NzpWXKqBnbcZZrXNSXj4tMUwROKbsQ" + ) + find_and_check_stmt(resp, moa_aid66_study_stmt, assertion_checks) + + resp = await query_handler.search_statements(therapy="rxcui:282388") + find_and_check_stmt(resp, moa_aid66_study_stmt, assertion_checks) + + resp = await query_handler.search_statements(gene="ncbigene:25") + find_and_check_stmt(resp, moa_aid66_study_stmt, assertion_checks) + + resp = await query_handler.search_statements(disease="CML") + find_and_check_stmt(resp, moa_aid66_study_stmt, assertion_checks) + + # We should not find MOA Assertion 67 using these queries + resp = await query_handler.search_statements(statement_id="moa.assertion:71") + find_and_check_stmt(resp, moa_aid66_study_stmt, assertion_checks, False) + + resp = await query_handler.search_statements(variation="BRAF V600E") + find_and_check_stmt(resp, moa_aid66_study_stmt, assertion_checks, False) + + resp = await query_handler.search_statements(therapy="Afatinib") + find_and_check_stmt(resp, moa_aid66_study_stmt, assertion_checks, False) + + resp = await query_handler.search_statements(gene="ABL2") + find_and_check_stmt(resp, moa_aid66_study_stmt, assertion_checks, False) + + resp = await query_handler.search_statements(disease="ncit:C2926") + find_and_check_stmt(resp, moa_aid66_study_stmt, assertion_checks, False) + + +@pytest.mark.asyncio(scope="module") +async def test_general_search_statements(query_handler): + """Test that queries do not return errors""" + resp = await query_handler.search_statements(variation="BRAF V600E") + assert_general_search_stmts(resp) + + resp = await query_handler.search_statements(variation="EGFR L858R") + assert_general_search_stmts(resp) + + resp = await query_handler.search_statements(disease="cancer") + assert_general_search_stmts(resp) + + # Case: Handling therapy for single therapeutic agent / combination / substitutes + resp = await query_handler.search_statements(therapy="Cetuximab") + assert_general_search_stmts(resp) + expected_therapy_id = "rxcui:318341" + for statement in resp.statements: + tp = statement.objectTherapeutic.root + if tp.type == "TherapeuticAgent": + assert _get_normalizer_id(tp.extensions) == expected_therapy_id + else: + therapeutics = ( + tp.components if tp.type == "CombinationTherapy" else tp.substitutes + ) + + found_expected = False + for therapeutic in therapeutics: + if _get_normalizer_id(therapeutic.extensions) == expected_therapy_id: + found_expected = True + break + assert found_expected + + resp = await query_handler.search_statements(gene="VHL") + assert_general_search_stmts(resp) + + # Case: multiple concepts provided + expected_variation_id = "ga4gh:VA._8jTS8nAvWwPZGOadQuD1o-tbbTQ5g3H" + expected_disease_id = "ncit:C2926" + expected_therapy_id = "ncit:C104732" + resp = await query_handler.search_statements( + variation=expected_variation_id, + disease=expected_disease_id, + therapy=expected_therapy_id, # Single Therapeutic Agent + ) + assert_general_search_stmts(resp) + + for statement in resp.statements: + assert ( + statement.subjectVariant.constraints[0].root.definingContext.root.id + == expected_variation_id + ) + assert ( + _get_normalizer_id(statement.objectTherapeutic.root.extensions) + == expected_therapy_id + ) + assert ( + _get_normalizer_id(statement.conditionQualifier.root.extensions) + == expected_disease_id + ) + + +@pytest.mark.asyncio(scope="module") +async def test_no_matches(query_handler): + """Test invalid queries""" + # invalid vrs variation prefix (digest is correct) + resp = await query_handler.search_statements( + variation="ga4gh:variation.TAARa2cxRHmOiij9UBwvW-noMDoOq2x9" + ) + assert_no_match(resp) + + # invalid id + resp = await query_handler.search_statements( + disease="ncit:C292632425235321524352435623462" + ) + assert_no_match(resp) + + # empty query + resp = await query_handler.search_statements() + assert_no_match(resp) + + # valid queries, but no matches with combination + resp = await query_handler.search_statements(variation="BRAF V600E", gene="EGFR") + assert_no_match(resp) + + +@pytest.mark.asyncio(scope="module") +async def test_paginate(query_handler: QueryHandler, normalizers): + """Test pagination parameters.""" + braf_va_id = "ga4gh:VA.Otc5ovrw906Ack087o1fhegB4jDRqCAe" + full_response = await query_handler.search_statements(variation=braf_va_id) + paged_response = await query_handler.search_statements( + variation=braf_va_id, start=1 + ) + # should be almost the same, just off by 1 + assert len(paged_response.statement_ids) == len(full_response.statement_ids) - 1 + assert paged_response.statement_ids == full_response.statement_ids[1:] + + # check that page limit > response doesn't affect response + huge_page_response = await query_handler.search_statements( + variation=braf_va_id, limit=1000 + ) + assert len(huge_page_response.statement_ids) == len(full_response.statement_ids) + assert huge_page_response.statement_ids == full_response.statement_ids + + # get last item + last_response = await query_handler.search_statements( + variation=braf_va_id, start=len(full_response.statement_ids) - 1 + ) + assert len(last_response.statement_ids) == 1 + assert last_response.statement_ids[0] == full_response.statement_ids[-1] + + # test limit + min_response = await query_handler.search_statements(variation=braf_va_id, limit=1) + assert min_response.statement_ids[0] == full_response.statement_ids[0] + + # test limit and start + other_min_response = await query_handler.search_statements( + variation=braf_va_id, start=1, limit=1 + ) + assert other_min_response.statement_ids[0] == full_response.statement_ids[1] + + # test limit of 0 + empty_response = await query_handler.search_statements( + variation=braf_va_id, limit=0 + ) + assert len(empty_response.statement_ids) == 0 + + # test raises exceptions + with pytest.raises(ValueError, match="Can't start from an index of less than 0."): + await query_handler.search_statements(variation=braf_va_id, start=-1) + with pytest.raises( + ValueError, match="Can't limit results to less than a negative number." + ): + await query_handler.search_statements(variation=braf_va_id, limit=-1) + + # test default limit + limited_query_handler = QueryHandler(normalizers=normalizers, default_page_limit=1) + default_limited_response = await limited_query_handler.search_statements( + variation=braf_va_id + ) + assert len(default_limited_response.statement_ids) == 1 + assert default_limited_response.statement_ids[0] == full_response.statement_ids[0] + + # test overrideable + less_limited_response = await limited_query_handler.search_statements( + variation=braf_va_id, limit=2 + ) + assert len(less_limited_response.statement_ids) == 2 + assert less_limited_response.statement_ids == full_response.statement_ids[:2] + + # test default limit and skip + skipped_limited_response = await limited_query_handler.search_statements( + variation=braf_va_id, start=1 + ) + assert len(skipped_limited_response.statement_ids) == 1 + assert skipped_limited_response.statement_ids[0] == full_response.statement_ids[1] + + limited_query_handler.driver.close() diff --git a/tests/unit/search/test_search_studies.py b/tests/unit/search/test_search_studies.py deleted file mode 100644 index 7e1694eb..00000000 --- a/tests/unit/search/test_search_studies.py +++ /dev/null @@ -1,312 +0,0 @@ -"""Test search study methods""" - -import pytest -from ga4gh.core.entity_models import Extension - -from metakb.normalizers import VICC_NORMALIZER_DATA -from metakb.query import QueryHandler - -from .utils import assert_no_match, find_and_check_study - - -def _get_normalizer_id(extensions: list[Extension]) -> str | None: - """Get normalized ID from list of extensions - - :param extensions: List of extensions - :return: Normalized concept ID if found in extensions - """ - normalizer_id = None - for ext in extensions: - if ext.name == VICC_NORMALIZER_DATA: - normalizer_id = ext.value["id"] - break - return normalizer_id - - -def assert_general_search_studies(response): - """Check that general search_studies queries return a valid response""" - len_study_id_matches = len(response.study_ids) - assert len_study_id_matches > 0 - len_studies = len(response.studies) - assert len_studies > 0 - assert len_study_id_matches == len_studies - - -@pytest.mark.asyncio(scope="module") -async def test_civic_eid2997(query_handler, civic_eid2997_study, assertion_checks): - """Test that search_studies method works correctly for CIViC EID2997""" - resp = await query_handler.search_studies(study_id=civic_eid2997_study["id"]) - assert resp.study_ids == [civic_eid2997_study["id"]] - resp_studies = [s.model_dump(exclude_none=True) for s in resp.studies] - assertion_checks(resp_studies, [civic_eid2997_study]) - assert resp.warnings == [] - - resp = await query_handler.search_studies(variation="EGFR L858R") - find_and_check_study(resp, civic_eid2997_study, assertion_checks) - - resp = await query_handler.search_studies( - variation="ga4gh:VA.S41CcMJT2bcd8R4-qXZWH1PoHWNtG2PZ" - ) - find_and_check_study(resp, civic_eid2997_study, assertion_checks) - - # genomic query - resp = await query_handler.search_studies(variation="7-55259515-T-G") - find_and_check_study(resp, civic_eid2997_study, assertion_checks) - - resp = await query_handler.search_studies(therapy="ncit:C66940") - find_and_check_study(resp, civic_eid2997_study, assertion_checks) - - resp = await query_handler.search_studies(gene="EGFR") - find_and_check_study(resp, civic_eid2997_study, assertion_checks) - - resp = await query_handler.search_studies(disease="nsclc") - find_and_check_study(resp, civic_eid2997_study, assertion_checks) - - # We should not find CIViC EID2997 using these queries - resp = await query_handler.search_studies(study_id="civic.eid:3017") - find_and_check_study(resp, civic_eid2997_study, assertion_checks, False) - - resp = await query_handler.search_studies(variation="BRAF V600E") - find_and_check_study(resp, civic_eid2997_study, assertion_checks, False) - - resp = await query_handler.search_studies(therapy="imatinib") - find_and_check_study(resp, civic_eid2997_study, assertion_checks, False) - - resp = await query_handler.search_studies(gene="BRAF") - find_and_check_study(resp, civic_eid2997_study, assertion_checks, False) - - resp = await query_handler.search_studies(disease="DOID:9253") - find_and_check_study(resp, civic_eid2997_study, assertion_checks, False) - - -@pytest.mark.asyncio(scope="module") -async def test_civic816(query_handler, civic_eid816_study, assertion_checks): - """Test that search_studies method works correctly for CIViC EID816""" - resp = await query_handler.search_studies(study_id=civic_eid816_study["id"]) - assert resp.study_ids == [civic_eid816_study["id"]] - resp_studies = [s.model_dump(exclude_none=True) for s in resp.studies] - assertion_checks(resp_studies, [civic_eid816_study]) - assert resp.warnings == [] - - # Try querying based on therapies in substitutes - resp = await query_handler.search_studies(therapy="Cetuximab") - find_and_check_study(resp, civic_eid816_study, assertion_checks) - - resp = await query_handler.search_studies(therapy="Panitumumab") - find_and_check_study(resp, civic_eid816_study, assertion_checks) - - -@pytest.mark.asyncio(scope="module") -async def test_civic9851(query_handler, civic_eid9851_study, assertion_checks): - """Test that search_studies method works correctly for CIViC EID9851""" - resp = await query_handler.search_studies(study_id=civic_eid9851_study["id"]) - assert resp.study_ids == [civic_eid9851_study["id"]] - resp_studies = [s.model_dump(exclude_none=True) for s in resp.studies] - assertion_checks(resp_studies, [civic_eid9851_study]) - assert resp.warnings == [] - - # Try querying based on therapies in components - resp = await query_handler.search_studies(therapy="Encorafenib") - find_and_check_study(resp, civic_eid9851_study, assertion_checks) - - resp = await query_handler.search_studies(therapy="Cetuximab") - find_and_check_study(resp, civic_eid9851_study, assertion_checks) - - -@pytest.mark.asyncio(scope="module") -async def test_moa_66(query_handler, moa_aid66_study, assertion_checks): - """Test that search_studies method works correctly for MOA Assertion 66""" - resp = await query_handler.search_studies(study_id=moa_aid66_study["id"]) - assert resp.study_ids == [moa_aid66_study["id"]] - resp_studies = [s.model_dump(exclude_none=True) for s in resp.studies] - assertion_checks(resp_studies, [moa_aid66_study]) - assert resp.warnings == [] - - resp = await query_handler.search_studies(variation="ABL1 Thr315Ile") - find_and_check_study(resp, moa_aid66_study, assertion_checks) - - resp = await query_handler.search_studies( - variation="ga4gh:VA.D6NzpWXKqBnbcZZrXNSXj4tMUwROKbsQ" - ) - find_and_check_study(resp, moa_aid66_study, assertion_checks) - - resp = await query_handler.search_studies(therapy="rxcui:282388") - find_and_check_study(resp, moa_aid66_study, assertion_checks) - - resp = await query_handler.search_studies(gene="ncbigene:25") - find_and_check_study(resp, moa_aid66_study, assertion_checks) - - resp = await query_handler.search_studies(disease="CML") - find_and_check_study(resp, moa_aid66_study, assertion_checks) - - # We should not find MOA Assertion 67 using these queries - resp = await query_handler.search_studies(study_id="moa.assertion:71") - find_and_check_study(resp, moa_aid66_study, assertion_checks, False) - - resp = await query_handler.search_studies(variation="BRAF V600E") - find_and_check_study(resp, moa_aid66_study, assertion_checks, False) - - resp = await query_handler.search_studies(therapy="Afatinib") - find_and_check_study(resp, moa_aid66_study, assertion_checks, False) - - resp = await query_handler.search_studies(gene="ABL2") - find_and_check_study(resp, moa_aid66_study, assertion_checks, False) - - resp = await query_handler.search_studies(disease="ncit:C2926") - find_and_check_study(resp, moa_aid66_study, assertion_checks, False) - - -@pytest.mark.asyncio(scope="module") -async def test_general_search_studies(query_handler): - """Test that queries do not return errors""" - resp = await query_handler.search_studies(variation="BRAF V600E") - assert_general_search_studies(resp) - - resp = await query_handler.search_studies(variation="EGFR L858R") - assert_general_search_studies(resp) - - resp = await query_handler.search_studies(disease="cancer") - assert_general_search_studies(resp) - - # Case: Handling therapy for single therapeutic agent / combination / substitutes - resp = await query_handler.search_studies(therapy="Cetuximab") - assert_general_search_studies(resp) - expected_therapy_id = "rxcui:318341" - for study in resp.studies: - tp = study.objectTherapeutic.root - if tp.type == "TherapeuticAgent": - assert _get_normalizer_id(tp.extensions) == expected_therapy_id - else: - therapeutics = ( - tp.components if tp.type == "CombinationTherapy" else tp.substitutes - ) - - found_expected = False - for therapeutic in therapeutics: - if _get_normalizer_id(therapeutic.extensions) == expected_therapy_id: - found_expected = True - break - assert found_expected - - resp = await query_handler.search_studies(gene="VHL") - assert_general_search_studies(resp) - - # Case: multiple concepts provided - expected_variation_id = "ga4gh:VA._8jTS8nAvWwPZGOadQuD1o-tbbTQ5g3H" - expected_disease_id = "ncit:C2926" - expected_therapy_id = "ncit:C104732" - resp = await query_handler.search_studies( - variation=expected_variation_id, - disease=expected_disease_id, - therapy=expected_therapy_id, # Single Therapeutic Agent - ) - assert_general_search_studies(resp) - - for study in resp.studies: - assert ( - study.subjectVariant.constraints[0].root.definingContext.root.id - == expected_variation_id - ) - assert ( - _get_normalizer_id(study.objectTherapeutic.root.extensions) - == expected_therapy_id - ) - assert ( - _get_normalizer_id(study.conditionQualifier.root.extensions) - == expected_disease_id - ) - - -@pytest.mark.asyncio(scope="module") -async def test_no_matches(query_handler): - """Test invalid queries""" - # invalid vrs variation prefix (digest is correct) - resp = await query_handler.search_studies( - variation="ga4gh:variation.TAARa2cxRHmOiij9UBwvW-noMDoOq2x9" - ) - assert_no_match(resp) - - # invalid id - resp = await query_handler.search_studies( - disease="ncit:C292632425235321524352435623462" - ) - assert_no_match(resp) - - # empty query - resp = await query_handler.search_studies() - assert_no_match(resp) - - # valid queries, but no matches with combination - resp = await query_handler.search_studies(variation="BRAF V600E", gene="EGFR") - assert_no_match(resp) - - -@pytest.mark.asyncio(scope="module") -async def test_paginate(query_handler: QueryHandler, normalizers): - """Test pagination parameters.""" - braf_va_id = "ga4gh:VA.Otc5ovrw906Ack087o1fhegB4jDRqCAe" - full_response = await query_handler.search_studies(variation=braf_va_id) - paged_response = await query_handler.search_studies(variation=braf_va_id, start=1) - # should be almost the same, just off by 1 - assert len(paged_response.study_ids) == len(full_response.study_ids) - 1 - assert paged_response.study_ids == full_response.study_ids[1:] - - # check that page limit > response doesn't affect response - huge_page_response = await query_handler.search_studies( - variation=braf_va_id, limit=1000 - ) - assert len(huge_page_response.study_ids) == len(full_response.study_ids) - assert huge_page_response.study_ids == full_response.study_ids - - # get last item - last_response = await query_handler.search_studies( - variation=braf_va_id, start=len(full_response.study_ids) - 1 - ) - assert len(last_response.study_ids) == 1 - assert last_response.study_ids[0] == full_response.study_ids[-1] - - # test limit - min_response = await query_handler.search_studies(variation=braf_va_id, limit=1) - assert min_response.study_ids[0] == full_response.study_ids[0] - - # test limit and start - other_min_response = await query_handler.search_studies( - variation=braf_va_id, start=1, limit=1 - ) - assert other_min_response.study_ids[0] == full_response.study_ids[1] - - # test limit of 0 - empty_response = await query_handler.search_studies(variation=braf_va_id, limit=0) - assert len(empty_response.study_ids) == 0 - - # test raises exceptions - with pytest.raises(ValueError, match="Can't start from an index of less than 0."): - await query_handler.search_studies(variation=braf_va_id, start=-1) - with pytest.raises( - ValueError, match="Can't limit results to less than a negative number." - ): - await query_handler.search_studies(variation=braf_va_id, limit=-1) - - # test default limit - limited_query_handler = QueryHandler(normalizers=normalizers, default_page_limit=1) - default_limited_response = await limited_query_handler.search_studies( - variation=braf_va_id - ) - assert len(default_limited_response.study_ids) == 1 - assert default_limited_response.study_ids[0] == full_response.study_ids[0] - - # test overrideable - less_limited_response = await limited_query_handler.search_studies( - variation=braf_va_id, limit=2 - ) - assert len(less_limited_response.study_ids) == 2 - assert less_limited_response.study_ids == full_response.study_ids[:2] - - # test default limit and skip - skipped_limited_response = await limited_query_handler.search_studies( - variation=braf_va_id, start=1 - ) - assert len(skipped_limited_response.study_ids) == 1 - assert skipped_limited_response.study_ids[0] == full_response.study_ids[1] - - limited_query_handler.driver.close() diff --git a/tests/unit/search/utils.py b/tests/unit/search/utils.py index 8eaf78c2..e513b1aa 100644 --- a/tests/unit/search/utils.py +++ b/tests/unit/search/utils.py @@ -1,33 +1,35 @@ -from metakb.schemas.api import BatchSearchStudiesService, SearchStudiesService +from metakb.schemas.api import BatchSearchStatementsService, SearchStatementsService def assert_no_match(response): - """No match assertions for queried concepts in search_studies.""" - assert response.studies == response.study_ids == [] + """No match assertions for queried concepts in search_statements.""" + assert response.statements == response.statement_ids == [] assert len(response.warnings) > 0 -def find_and_check_study( - resp: SearchStudiesService | BatchSearchStudiesService, - expected_study: dict, +def find_and_check_stmt( + resp: SearchStatementsService | BatchSearchStatementsService, + expected_stmt: dict, assertion_checks: callable, should_find_match: bool = True, ): - """Check that expected study is or is not in response""" + """Check that expected statement is or is not in response""" if should_find_match: - assert expected_study["id"] in resp.study_ids + assert expected_stmt["id"] in resp.statement_ids else: - assert expected_study["id"] not in resp.study_ids + assert expected_stmt["id"] not in resp.statement_ids - actual_study = None - for study in resp.studies: - if study.id == expected_study["id"]: - actual_study = study + actual_stmt = None + for stmt in resp.statements: + if stmt.id == expected_stmt["id"]: + actual_stmt = stmt break if should_find_match: - assert actual_study, f"Did not find study ID {expected_study['id']} in studies" - resp_studies = [actual_study.model_dump(exclude_none=True)] - assertion_checks(resp_studies, [expected_study]) + assert ( + actual_stmt + ), f"Did not find statement ID {expected_stmt['id']} in statements" + resp_stmts = [actual_stmt.model_dump(exclude_none=True)] + assertion_checks(resp_stmts, [expected_stmt]) else: - assert actual_study is None + assert actual_stmt is None diff --git a/tests/unit/transformers/test_civic_transformer_therapeutic.py b/tests/unit/transformers/test_civic_transformer_therapeutic.py index 2292455b..174179d8 100644 --- a/tests/unit/transformers/test_civic_transformer_therapeutic.py +++ b/tests/unit/transformers/test_civic_transformer_therapeutic.py @@ -27,11 +27,13 @@ async def data(normalizers): @pytest.fixture(scope="module") -def studies(civic_eid2997_study, civic_eid816_study, civic_eid9851_study): - """Create test fixture for CIViC therapeutic studies.""" - return [civic_eid2997_study, civic_eid816_study, civic_eid9851_study] +def statements( + civic_eid2997_study_stmt, civic_eid816_study_stmt, civic_eid9851_study_stmt +): + """Create test fixture for CIViC therapeutic statements.""" + return [civic_eid2997_study_stmt, civic_eid816_study_stmt, civic_eid9851_study_stmt] -def test_civic_cdm(data, studies, check_transformed_cdm): +def test_civic_cdm(data, statements, check_transformed_cdm): """Test that civic transformation works correctly.""" - check_transformed_cdm(data, studies, DATA_DIR / FILENAME) + check_transformed_cdm(data, statements, DATA_DIR / FILENAME) diff --git a/tests/unit/transformers/test_moa_transformer.py b/tests/unit/transformers/test_moa_transformer.py index e86ba3c6..f0c628da 100644 --- a/tests/unit/transformers/test_moa_transformer.py +++ b/tests/unit/transformers/test_moa_transformer.py @@ -119,8 +119,8 @@ def moa_encorafenib(encorafenib_extensions): @pytest.fixture(scope="module") -def moa_aid155_study(moa_vid145, moa_cetuximab, moa_encorafenib, moa_method): - """Create MOA AID 155 study test fixture. Uses CombinationTherapy.""" +def moa_aid155_study_stmt(moa_vid145, moa_cetuximab, moa_encorafenib, moa_method): + """Create MOA AID 155 study statement test fixture. Uses CombinationTherapy.""" return { "id": "moa.assertion:155", "type": "VariantTherapeuticResponseStudyStatement", @@ -196,11 +196,11 @@ def moa_aid155_study(moa_vid145, moa_cetuximab, moa_encorafenib, moa_method): @pytest.fixture(scope="module") -def studies(moa_aid66_study, moa_aid155_study): - """Create test fixture for MOA therapeutic studies.""" - return [moa_aid66_study, moa_aid155_study] +def statements(moa_aid66_study_stmt, moa_aid155_study_stmt): + """Create test fixture for MOA therapeutic statements.""" + return [moa_aid66_study_stmt, moa_aid155_study_stmt] -def test_moa_cdm(data, studies, check_transformed_cdm): +def test_moa_cdm(data, statements, check_transformed_cdm): """Test that moa transformation works correctly.""" - check_transformed_cdm(data, studies, TEST_TRANSFORMERS_DIR / FILENAME) + check_transformed_cdm(data, statements, TEST_TRANSFORMERS_DIR / FILENAME)