Skip to content

Commit

Permalink
1. Added keyword search functionality.
Browse files Browse the repository at this point in the history
2. Added filters to count and download.
3. Fixed aggregations not returning proper data.
  • Loading branch information
AyaanKakkar committed Sep 23, 2024
1 parent f80cd9b commit 836fad2
Show file tree
Hide file tree
Showing 5 changed files with 239 additions and 95 deletions.
21 changes: 18 additions & 3 deletions src/graphql/resolvers/count_resolver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ...config.es import es
from ...config.settings import settings
from src.graphql.models.annotation_model import FilterArgs, PageArgs
from .helper_resolver import IDs_query, annotation_query, chromosome_query, convert_hits, gene_query, get_aggregation_query, rsID_query, rsIDs_query
from src.graphql.models.annotation_model import FilterArgs
from .helper_resolver import IDs_query, annotation_query, chromosome_query, gene_query, keyword_query, rsID_query, rsIDs_query


async def get_annotations_count():
Expand Down Expand Up @@ -106,4 +106,19 @@ async def count_by_gene(gene:str, filter_args=FilterArgs):
)
return resp['count']

return 0
return 0

async def count_by_keyword(keyword: str):
"""
Query for getting count of annotation by keyword
Params:
keyword: Keyword to search
Returns: integer for count of annotations
"""
resp = await es.count(
index = settings.ES_INDEX,
query = keyword_query(keyword)
)
return resp['count']
146 changes: 97 additions & 49 deletions src/graphql/resolvers/helper_resolver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import inspect
import json
from typing import Dict
from src.graphql.gene_pos import get_pos_from_gene_id, map_gene, chromosomal_location_dic
from src.graphql.models.snp_model import ScrollSnp, Snp, SnpAggs
from src.graphql.models.annotation_model import AggregationItem, Bucket, DocCount, Histogram
from src.graphql.models.annotation_model import AggregationItem, Bucket, DocCount, FilterArgs, Histogram

from src.utils import clean_field_name

Expand Down Expand Up @@ -224,61 +226,107 @@ def gene_query(gene, filter_args=None):

return None

async def get_aggregation_query(es_fields: list[str], histogram: Histogram):
def keyword_query(keyword: str):
"""
Query for getting aggregates of annotation
Query for getting annotation by keyword
Params: es_fields: List of fields to be returned in elasticsearch query
histogram: Histogram object for histogram aggregation
Params: keyword: Keyword for search
Returns: Query for elasticsearch
"""
results = dict()
for field in es_fields:

results[f'{field}_doc_count'] = {
"filter" : {
"exists": {
"field": field
}
}
}

results[f'{field}_min'] = {
"min": {
"field": "pos"
searchable_fields = []
with open('./data/anno_tree.json') as f:
data = json.load(f)
searchable_fields = [elt['name'] for elt in data if data.get('keyword_searchable', False)]

query = {
"multi_match": {
"query": keyword,
"fields": searchable_fields
}
}
}

results[f'{field}_max'] = {
"max": {
"field": "pos"
}
}
return query

results[f'{field}_frequency'] = {
"terms": {
"field": "pos",
"min_doc_count": 0,
"size": 20
}
}

results[f'{field}_missing'] = {
"missing": {
"field": "pos"
}
}

results[f'{field}_histogram'] = {
"histogram": {
"field": "pos",
"interval": histogram.interval,
"extended_bounds": {
"min": histogram.min,
"max": histogram.max
}
}
}
async def get_aggregation_query(aggregation_fields: list[tuple[str, list[str]]], histogram: Histogram):
"""
Query for getting aggregates of annotation
Params: aggregation_fields: List of fields for aggregation, along with their subfields
histogram: Histogram object for histogram aggregation
Returns: Query for elasticsearch
"""
results = dict()
for field, subfields in aggregation_fields:

# Check the type of the field. If it is a string, then we have to add .keyword to the field name while querying missing and frequency
# Using the pydantic model Snp, we can check the type of the field
is_text_field = inspect.get_annotations(Snp)[field] == str
textual_suffix = '.keyword' if is_text_field else ''

for subfield in subfields:
if subfield == 'doc_count':
results[f'{field}_doc_count'] = {
"filter" : {
"exists": {
"field": field
}
}
}

elif subfield == 'min':
results[f'{field}_min'] = {
"min": {
"field": field
}
}

elif subfield == 'max':
results[f'{field}_max'] = {
"max": {
"field": field
}
}

elif subfield == 'frequency':
results[f'{field}_frequency'] = {
"terms": {
"field": field + textual_suffix,
"min_doc_count": 0,
"size": 20
}
}

elif subfield == 'missing':
results[f'{field}_missing'] = {
"missing": {
"field": field + textual_suffix
}
}

elif subfield == 'histogram':
results[f'{field}_histogram'] = {
"histogram": {
"field": field,
"interval": histogram.interval,
"extended_bounds": {
"min": histogram.min,
"max": histogram.max
}
}
}

return results


def get_default_aggregation_fields(es_fields: list[str]):
"""
Get default aggregation fields for elasticsearch query
return results
Params: es_fields: List of fields to be returned in elasticsearch query
Returns: List of fields for aggregation
"""
return [(field, ['doc_count']) for field in es_fields]
65 changes: 51 additions & 14 deletions src/graphql/resolvers/snp_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from src.config.settings import settings
from src.graphql.resolvers.download_resolver import download_annotations
from src.graphql.models.annotation_model import FilterArgs, Histogram, PageArgs, QueryType
from src.graphql.resolvers.helper_resolver import IDs_query, annotation_query, chromosome_query, convert_aggs, convert_hits, convert_scroll_hits, gene_query, get_aggregation_query, rsID_query, rsIDs_query
from src.graphql.resolvers.helper_resolver import IDs_query, annotation_query, chromosome_query, convert_aggs, convert_hits, convert_scroll_hits, gene_query, get_aggregation_query, get_default_aggregation_fields, keyword_query, rsID_query, rsIDs_query


async def query_return(query_type, es_fields, resp):
Expand Down Expand Up @@ -30,12 +30,13 @@ async def query_return(query_type, es_fields, resp):
return results


async def get_annotations(es_fields: list[str], query_type: str, histogram=Histogram):
async def get_annotations(es_fields: list[str], query_type: str, aggregation_fields: list[tuple[str, list[str]]]=None, histogram=Histogram):
"""
Query for getting all annotations, no filter, size 20
Params: es_fields: List of fields to be returned in elasticsearch query
query_type: Type of query to be executed
aggregation_fields: List of fields for aggregation, along with their subfields
histogram: Histogram object for aggregation query
Returns: List of Snps
Expand All @@ -44,7 +45,7 @@ async def get_annotations(es_fields: list[str], query_type: str, histogram=Histo
index = settings.ES_INDEX,
source = es_fields,
query = annotation_query(),
aggs = await get_aggregation_query(es_fields, histogram) if query_type == QueryType.AGGS else None,
aggs = await get_aggregation_query(aggregation_fields or get_default_aggregation_fields(es_fields), histogram) if query_type == QueryType.AGGS else None,
size = 20,
scroll = '2m' if query_type == QueryType.DOWNLOAD else None
)
Expand All @@ -68,7 +69,7 @@ async def scroll_annotations_(scroll_id: str):
return results


async def search_by_chromosome(es_fields: list[str], chr: str, start: int, end: int, query_type: str, page_args=PageArgs, filter_args=FilterArgs,
async def search_by_chromosome(es_fields: list[str], chr: str, start: int, end: int, query_type: str, aggregation_fields: list[tuple[str, list[str]]]=None, page_args=PageArgs, filter_args=FilterArgs,
histogram=Histogram):
"""
Query for getting annotation by chromosome with start and end range of pos
Expand All @@ -78,6 +79,7 @@ async def search_by_chromosome(es_fields: list[str], chr: str, start: int, end:
start: Start position
end: End position
query_type: Type of query to be executed
aggregation_fields: List of fields for aggregation, along with their subfields
page_args: PageArgs object for pagination
filter_args: FilterArgs object for field exists filter
histogram: Histogram object for aggregation query
Expand All @@ -96,20 +98,21 @@ async def search_by_chromosome(es_fields: list[str], chr: str, start: int, end:
from_= page_args.from_ if (query_type != QueryType.DOWNLOAD and query_type != QueryType.SCROLL) else None,
size = page_args.size,
query = chromosome_query(chr, start, end, filter_args),
aggs = await get_aggregation_query(es_fields, histogram) if query_type == QueryType.AGGS else None,
aggs = await get_aggregation_query(aggregation_fields or get_default_aggregation_fields(es_fields), histogram) if query_type == QueryType.AGGS else None,
scroll = '2m' if (query_type == QueryType.DOWNLOAD or query_type == QueryType.SCROLL) else None
)

return await query_return(query_type, es_fields, resp)


async def search_by_rsID(es_fields: list[str], rsID:str, query_type: str, page_args=PageArgs, filter_args=FilterArgs, histogram=Histogram):
async def search_by_rsID(es_fields: list[str], rsID:str, query_type: str, aggregation_fields: list[tuple[str, list[str]]]=None, page_args=PageArgs, filter_args=FilterArgs, histogram=Histogram):
"""
Query for getting annotation by rsID
Params: es_fields: List of fields to be returned in elasticsearch query
rsID: rsID of snp
query_type: Type of query to be executed
aggregation_fields: List of fields for aggregation, along with their subfields
page_args: PageArgs object for pagination
filter_args: FilterArgs object for field exists filter
histogram: Histogram object for aggregation query
Expand All @@ -128,20 +131,21 @@ async def search_by_rsID(es_fields: list[str], rsID:str, query_type: str, page_a
from_= page_args.from_ if (query_type != QueryType.DOWNLOAD and query_type != QueryType.SCROLL) else None,
size = page_args.size,
query = rsID_query(rsID, filter_args),
aggs = await get_aggregation_query(es_fields, histogram) if query_type == QueryType.AGGS else None,
aggs = await get_aggregation_query(aggregation_fields or get_default_aggregation_fields(es_fields), histogram) if query_type == QueryType.AGGS else None,
scroll = '2m' if (query_type == QueryType.DOWNLOAD or query_type == QueryType.SCROLL) else None
)

return await query_return(query_type, es_fields, resp)


async def search_by_rsIDs(es_fields: list[str], rsIDs: list[str], query_type: str, page_args=PageArgs, filter_args=FilterArgs, histogram=Histogram):
async def search_by_rsIDs(es_fields: list[str], rsIDs: list[str], query_type: str, aggregation_fields: list[tuple[str, list[str]]]=None, page_args=PageArgs, filter_args=FilterArgs, histogram=Histogram):
"""
Query for getting annotation by list of rsIDs
Params: es_fields: List of fields to be returned in elasticsearch query
rsIDs: List of rsIDs of snps
query_type: Type of query to be executed
aggregation_fields: List of fields for aggregation, along with their subfields
page_args: PageArgs object for pagination
filter_args: FilterArgs object for field exists filter
histogram: Histogram object for aggregation query
Expand All @@ -160,21 +164,22 @@ async def search_by_rsIDs(es_fields: list[str], rsIDs: list[str], query_type: st
from_= page_args.from_ if (query_type != QueryType.DOWNLOAD and query_type != QueryType.SCROLL) else None,
size = page_args.size,
query = rsIDs_query(rsIDs, filter_args),
aggs = await get_aggregation_query(es_fields, histogram) if query_type == QueryType.AGGS else None,
aggs = await get_aggregation_query(aggregation_fields or get_default_aggregation_fields(es_fields), histogram) if query_type == QueryType.AGGS else None,
scroll = '2m' if (query_type == QueryType.DOWNLOAD or query_type == QueryType.SCROLL) else None
)

return await query_return(query_type, es_fields, resp)


# query for VCF file
async def search_by_IDs(es_fields: list[str], ids: list[str], query_type: str, page_args=PageArgs, filter_args=FilterArgs, histogram=Histogram):
async def search_by_IDs(es_fields: list[str], ids: list[str], query_type: str, aggregation_fields: list[tuple[str, list[str]]]=None, page_args=PageArgs, filter_args=FilterArgs, histogram=Histogram):
"""
Query for getting annotation by IDs
Params: es_fields: List of fields to be returned in elasticsearch query
ids: List of IDs of snps
query_type: Type of query to be executed
aggregation_fields: List of fields for aggregation, along with their subfields
page_args: PageArgs object for pagination
filter_args: FilterArgs object for field exists filter
histogram: Histogram object for aggregation query
Expand All @@ -193,21 +198,22 @@ async def search_by_IDs(es_fields: list[str], ids: list[str], query_type: str, p
from_= page_args.from_ if (query_type != QueryType.DOWNLOAD and query_type != QueryType.SCROLL) else None,
size = page_args.size,
query = IDs_query(ids, filter_args),
aggs = await get_aggregation_query(es_fields, histogram) if query_type == QueryType.AGGS else None,
aggs = await get_aggregation_query(aggregation_fields or get_default_aggregation_fields(es_fields), histogram) if query_type == QueryType.AGGS else None,
scroll = '2m' if (query_type == QueryType.DOWNLOAD or query_type == QueryType.SCROLL) else None
)

return await query_return(query_type, es_fields, resp)


async def search_by_gene(es_fields: list[str], gene:str, query_type: str, page_args=PageArgs, filter_args=FilterArgs, histogram=Histogram):
async def search_by_gene(es_fields: list[str], gene:str, query_type: str, aggregation_fields: list[tuple[str, list[str]]]=None, page_args=PageArgs, filter_args=FilterArgs, histogram=Histogram):
"""
Query for getting annotation by gene product
Params: es_fields: List of fields to be returned in elasticsearch query
gene: Gene product
query_type: Type of query to be executed
page_args: PageArgs object for pagination
aggregation_fields: List of fields for aggregation, along with their subfields
filter_args: FilterArgs object for field exists filter
histogram: Histogram object for aggregation query
Expand All @@ -228,8 +234,39 @@ async def search_by_gene(es_fields: list[str], gene:str, query_type: str, page_a
from_= page_args.from_ if (query_type != QueryType.DOWNLOAD and query_type != QueryType.SCROLL) else None,
size = page_args.size,
query = query,
aggs = await get_aggregation_query(es_fields, histogram) if query_type == QueryType.AGGS else None,
aggs = await get_aggregation_query(aggregation_fields or get_default_aggregation_fields(es_fields), histogram) if query_type == QueryType.AGGS else None,
scroll = '2m' if (query_type == QueryType.DOWNLOAD or query_type == QueryType.SCROLL) else None
)

return await query_return(query_type, es_fields, resp)
return await query_return(query_type, es_fields, resp)

async def search_by_keyword(es_fields: list[str], keyword: str, query_type: str, aggregation_fields: list[tuple[str, list[str]]]=None, page_args=PageArgs, histogram=Histogram):
"""
Query for getting annotation by keyword
Params: es_fields: List of fields to be returned in elasticsearch query
keyword: Keyword to be searched
query_type: Type of query to be executed
aggregation_fields: List of fields for aggregation, along with their subfields
page_args: PageArgs object for pagination
histogram: Histogram object for aggregation query
Returns: List of Snps
"""
if page_args is None:
page_args = PageArgs

if histogram is None:
histogram = Histogram

resp = await es.search(
index = settings.ES_INDEX,
source = es_fields,
from_= page_args.from_ if (query_type != QueryType.DOWNLOAD and query_type != QueryType.SCROLL) else None,
size = page_args.size,
query = keyword_query(keyword),
aggs = await get_aggregation_query(aggregation_fields or get_default_aggregation_fields(es_fields), histogram) if query_type == QueryType.AGGS else None,
scroll = '2m' if (query_type == QueryType.DOWNLOAD or query_type == QueryType.SCROLL) else None
)

return await query_return(query_type, es_fields, resp)
Loading

0 comments on commit 836fad2

Please sign in to comment.