refactor: add mindate, maxdate args in query function
EverVino committed Feb 1, 2024
1 parent 73cf191 commit 98fad0d
Showing 4 changed files with 183 additions and 134 deletions.
126 changes: 79 additions & 47 deletions src/pymedx/api.py
@@ -2,7 +2,7 @@
import itertools
import xml.etree.ElementTree as xml

from typing import Union
from typing import Any, Dict, Iterable, List, Union, cast

import requests

@@ -14,54 +14,63 @@
BASE_URL = "https://eutils.ncbi.nlm.nih.gov"


class PubMed(object):
class PubMed:
"""Wrapper around the PubMed API."""

def __init__(
self: object,
self,
tool: str = "my_tool",
email: str = "[email protected]",
) -> None:
"""Initialization of the object.
Parameters
----------
Parameters:
- tool String, name of the tool that is executing the query.
This parameter is not required but kindly requested by
PMC (PubMed Central).
- email String, email of the user of the tool. This parameter
is not required but kindly requested by PMC (PubMed Central).
Returns
-------
Returns:
- None
"""

# Store the input parameters
self.tool = tool
self.email = email

# Keep track of the rate limit
self._rateLimit = 3
self._requestsMade = []

self._rateLimit: int = 3
self._requestsMade: List[datetime.datetime] = []
self.parameters: Dict[str, Union[str, int, List[str]]]
# Define the standard / default query parameters
self.parameters = {"tool": tool, "email": email, "db": "pubmed"}

def query(self: object, query: str, max_results: int = 100):
def query(
self,
query: str,
min_date: str,
max_date: str,
max_results: int = 100,
) -> Iterable[Union[PubMedArticle, PubMedBookArticle]]:
"""Method that executes a query against the PubMed database and
retrieves the matching articles.
Parameters
----------
Parameters:
- query String, the search query to execute against PubMed.
- min_date String, lower bound of the date range (ESearch mindate; YYYY/MM/DD, YYYY/MM or YYYY).
- max_date String, upper bound of the date range (ESearch maxdate, same formats).
- max_results Int, maximum number of articles to retrieve (defaults to 100).
Returns
-------
Returns:
- articles Iterable, chained PubMedArticle / PubMedBookArticle
objects that match the query.
"""

# Retrieve the article IDs for the query
article_ids = self._getArticleIds(query=query, max_results=max_results)
article_ids = self._getArticleIds(
query=query,
min_date=min_date,
max_date=max_date,
max_results=max_results,
)

# Get the articles themselves
articles = list(
@@ -74,17 +83,16 @@ def query(self: object, query: str, max_results: int = 100):
# Chain the batches back together and return the list
return itertools.chain.from_iterable(articles)
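
As a quick usage sketch (not part of the diff): with the new signature, callers now pass the date bounds explicitly. This minimal example assumes the class is imported from pymedx.api, as the file path suggests; the tool, e-mail, search term and dates are illustrative, and the YYYY/MM/DD format follows the E-utilities mindate/maxdate conventions rather than anything stated in this commit.

from pymedx.api import PubMed

# Illustrative values only.
pubmed = PubMed(tool="my_tool", email="someone@example.com")
results = pubmed.query(
    query="occupational asthma",  # any PubMed search expression
    min_date="2023/01/01",        # forwarded to ESearch as mindate
    max_date="2023/12/31",        # forwarded to ESearch as maxdate
    max_results=50,
)
for article in results:  # PubMedArticle / PubMedBookArticle instances
    print(article)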

def getTotalResultsCount(self: object, query: str) -> int:
def getTotalResultsCount(self, query: str) -> int:
"""Helper method that returns the total number of results that match the query.
Parameters
----------
Parameters:
- query String, the query to send to PubMed
Returns
-------
Returns:
- total_results_count Int, total number of results for the query in PubMed
"""

# Get the default parameters
parameters = self.parameters.copy()

@@ -93,7 +101,7 @@ def getTotalResultsCount(self: object, query: str) -> int:
parameters["retmax"] = 1

# Make the request (request a single article ID for this search)
response = self._get(
response: requests.models.Response = self._get(
url="/entrez/eutils/esearch.fcgi", parameters=parameters
)
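
Continuing the sketch above (same assumptions, illustrative search term), the count helper needs no date arguments and returns a plain integer:

# Requests a single ID (retmax=1) and reads the reported hit count.
total = pubmed.getTotalResultsCount("occupational asthma")
print(f"PubMed reports {total} matching records")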

@@ -108,10 +116,10 @@ def getTotalResultsCount(self: object, query: str) -> int:
def _exceededRateLimit(self) -> bool:
"""Helper method to check if we've exceeded the rate limit.
Returns
-------
Returns:
- exceeded Bool, whether or not the rate limit is exceeded.
"""

# Remove requests from the list that are longer than 1 second ago
self._requestsMade = [
requestTime
@@ -124,34 +132,37 @@ def _exceededRateLimit(self) -> bool:
return len(self._requestsMade) > self._rateLimit

def _get(
self: object, url: str, parameters: dict, output: str = "json"
) -> Union[dict, str]:
self,
url: str,
parameters: Dict[Any, Any] = dict(),
output: str = "json",
) -> Union[str, requests.models.Response]:
"""Generic helper method that makes a request to PubMed.
Parameters
----------
Parameters:
- url Str, last part of the URL that is requested (will
be combined with the base url)
- parameters Dict, parameters to use for the request
- output Str, type of output that is requested (defaults to
JSON but can be used to retrieve XML)
Returns
-------
Returns:
- response Dict / str, if the response is valid JSON it will
be parsed before returning, otherwise a string is
returned.
"""

# Make sure the rate limit is not exceeded
while self._exceededRateLimit():
pass

# Set the response mode
parameters["retmode"] = output

if parameters:
parameters["retmode"] = output

# Make the request to PubMed
response = requests.get(f"{BASE_URL}{url}", params=parameters)

# Check for any errors
response.raise_for_status()

@@ -164,17 +175,18 @@ def _get(
else:
return response.text

def _getArticles(self: object, article_ids: list) -> list:
def _getArticles(
self, article_ids: List[str]
) -> Iterable[Union[PubMedArticle, PubMedBookArticle]]:
"""Helper method that batches a list of article IDs and retrieves the content.
Parameters
----------
Parameters:
- article_ids List, article IDs.
Returns
-------
Returns:
- articles List, article objects.
"""

# Get the default parameters
parameters = self.parameters.copy()
parameters["id"] = article_ids
@@ -195,18 +207,23 @@ def _getArticles(self: object, article_ids: list) -> list:
for book in root.iter("PubmedBookArticle"):
yield PubMedBookArticle(xml_element=book)

def _getArticleIds(self: object, query: str, max_results: int) -> list:
def _getArticleIds(
self,
query: str,
min_date: str,
max_date: str,
max_results: int,
) -> List[str]:
"""Helper method to retrieve the article IDs for a query.
Parameters
----------
Parameters:
- query Str, query to be executed against the PubMed database.
- min_date Str, lower bound of the date range, sent to ESearch as mindate.
- max_date Str, upper bound of the date range, sent to ESearch as maxdate.
- max_results Int, the maximum number of results to retrieve.
Returns
-------
Returns:
- article_ids List, article IDs as a list.
"""

# Create a placeholder for the retrieved IDs
article_ids = []

@@ -216,16 +233,29 @@ def _getArticleIds(self: object, query: str, max_results: int) -> list:
# Add specific query parameters
parameters["term"] = query
parameters["retmax"] = 50000
parameters["datetype"] = "edat"
parameters["mindate"] = min_date
parameters["maxdate"] = max_date

retmax: int = cast(int, parameters["retmax"])
# Calculate a cut off point based on the max_results parameter
if max_results < parameters["retmax"]:
if max_results < retmax:
parameters["retmax"] = max_results

# Make the first request to PubMed
response = self._get(
url="/entrez/eutils/esearch.fcgi", parameters=parameters
new_url = (
"/entrez/eutils/esearch.fcgi?"
f"db={parameters['db']}&"
f"term={parameters['term']}&"
f"retmax={parameters['retmax']}&"
f"datetype={parameters['datetype']}&"
f"mindate={parameters['mindate']}&"
f"maxdate={parameters['maxdate']}&"
f"retmode=json"
)

# Make the first request to PubMed
response: requests.models.Response = self._get(url=new_url)

# Add the retrieved IDs to the list
article_ids += response.get("esearchresult", {}).get("idlist", [])
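
To make the hand-assembled request above concrete, this is roughly the path new_url evaluates to for illustrative inputs (query="occupational asthma", min_date="2023/01/01", max_date="2023/12/31", max_results=50); it is wrapped here for readability and is not taken from the commit:

# /entrez/eutils/esearch.fcgi?db=pubmed&term=occupational asthma&retmax=50
#     &datetype=edat&mindate=2023/01/01&maxdate=2023/12/31&retmode=json
#
# _get() prefixes this with BASE_URL before issuing the GET request; since no
# parameters dict is passed here, retmode travels inside the URL itself.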

@@ -245,7 +275,9 @@ def _getArticleIds(self: object, query: str, max_results: int) -> list:
and retrieved_count < max_results
):
# Calculate a cut off point based on the max_results parameter
if (max_results - retrieved_count) < parameters["retmax"]:
if (max_results - retrieved_count) < cast(
int, parameters["retmax"]
):
parameters["retmax"] = max_results - retrieved_count

# Start the collection from the number of already retrieved articles
