Skip to content

Commit

Permalink
Merge pull request #33 from materials-data-facility/forge-dev
Browse files Browse the repository at this point in the history
forge-dev
  • Loading branch information
jgaff authored Nov 28, 2018
2 parents 0c3be38 + 95fd6af commit b9950f8
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 33 deletions.
40 changes: 31 additions & 9 deletions mdf_forge/forge.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import re
from urllib.parse import urlparse

import globus_sdk
Expand Down Expand Up @@ -420,9 +419,6 @@ def match_source_names(self, source_names):
return self
if isinstance(source_names, str):
source_names = [source_names]
# If no version supplied, add * to each source name to match all versions
source_names = [(sn+"*" if re.search(".*_v[0-9]+", sn) is None else sn)
for sn in source_names]
# First source should be in new group and required
self.match_field(field="mdf.source_name", value=source_names[0],
required=True, new_group=True)
Expand Down Expand Up @@ -701,7 +697,7 @@ def get_dataset_version(self, source_name):
int: Version of the dataset in question
"""

hits = self.search("mdf.source_name:{}_v* AND"
hits = self.search("mdf.source_name:{} AND"
" mdf.resource_type:dataset".format(source_name),
advanced=True, limit=2)

Expand Down Expand Up @@ -1255,7 +1251,7 @@ def negate(self):
self.operator("NOT")
return self

def search(self, q=None, index=None, advanced=None, limit=None, info=False):
def search(self, q=None, index=None, advanced=None, limit=None, info=False, retries=3):
"""Execute a search and return the results.
Args:
Expand All @@ -1276,6 +1272,8 @@ def search(self, q=None, index=None, advanced=None, limit=None, info=False):
If **True**, search will return a tuple containing the results list
and other information about the query.
Default **False**.
retries (int): The number of times to retry a Search query if it fails.
Default 3.
Returns:
list (if info=False): The results.
Expand Down Expand Up @@ -1309,20 +1307,43 @@ def search(self, q=None, index=None, advanced=None, limit=None, info=False):
"limit": limit,
"offset": 0
}
res = mdf_toolbox.gmeta_pop(self.__search_client.post_search(uuid_index, qu), info=info)
tries = 0
errors = []
while True:
try:
search_res = self.__search_client.post_search(uuid_index, qu)
except globus_sdk.SearchAPIError as e:
if tries >= retries:
raise
else:
errors.append(repr(e))
except Exception as e:
if tries >= retries:
raise
else:
errors.append(repr(e))
else:
break
tries += 1
res = mdf_toolbox.gmeta_pop(search_res, info=info)
# Add additional info
if info:
res[1]["query"] = qu
res[1]["index"] = index
res[1]["index_uuid"] = uuid_index
res[1]["retries"] = tries
res[1]["errors"] = errors
return res

def aggregate(self, q=None, index=None, scroll_size=SEARCH_LIMIT):
def aggregate(self, q=None, index=None, retries=1, scroll_size=SEARCH_LIMIT):
"""Gather all results that match a specific query
Args:
q (str): The query to execute. Defaults to the current query, if any.
There must be some query to execute.
index (str): The Globus Search index to search on. Required.
retries (int): The number of times to retry a Search query if it fails.
Default 1.
scroll_size (int): Maximum number of records requested per request.
Returns:
Expand Down Expand Up @@ -1366,7 +1387,8 @@ def aggregate(self, q=None, index=None, scroll_size=SEARCH_LIMIT):
while True:
query = "(" + q + ') AND (mdf.scroll_id:>=%d AND mdf.scroll_id:<%d)' % (
scroll_pos, scroll_pos+scroll_width)
results, info = self.search(query, index=index, advanced=True, info=True)
results, info = self.search(query, index=index, advanced=True,
info=True, retries=retries)

# Check to make sure that all the matching records were returned
if info["total_query_matches"] <= len(results):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='mdf_forge',
version='0.6.4',
version='0.6.5',
packages=['mdf_forge'],
description='Materials Data Facility python package',
long_description=("Forge is the Materials Data Facility Python package"
Expand Down
47 changes: 24 additions & 23 deletions tests/test_forge.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,11 @@ def test_query_search(capsys):
# Check default limits
res5 = q.search("Al", index="mdf")
assert len(res5) == 10
res6 = q.search("mdf.source_name:nist_xps_db*", advanced=True, index="mdf")
res6 = q.search("mdf.source_name:nist_xps_db", advanced=True, index="mdf")
assert len(res6) == 10000

# Check limit correction
res7 = q.search("mdf.source_name:nist_xps_db*", advanced=True, index="mdf", limit=20000)
res7 = q.search("mdf.source_name:nist_xps_db", advanced=True, index="mdf", limit=20000)
assert len(res7) == 10000

# Test index translation
Expand Down Expand Up @@ -184,29 +184,29 @@ def test_query_aggregate(capsys):
assert "Error: No index specified" in out

# Basic aggregation
res1 = q.aggregate("mdf.source_name:nist_xps_db*", index="mdf")
res1 = q.aggregate("mdf.source_name:nist_xps_db", index="mdf")
assert len(res1) > 10000
assert isinstance(res1[0], dict)

# Multi-dataset aggregation
res2 = q.aggregate("(mdf.source_name:nist_xps_db* OR mdf.source_name:khazana_vasp*)",
res2 = q.aggregate("(mdf.source_name:nist_xps_db OR mdf.source_name:khazana_vasp)",
index="mdf")
assert len(res2) > 10000
assert len(res2) > len(res1)

# Unnecessary aggregation fallback to .search()
# Check success in Coveralls
assert len(q.aggregate("mdf.source_name:khazana_vasp*")) < 10000
assert len(q.aggregate("mdf.source_name:khazana_vasp")) < 10000


def test_query_chaining():
q = forge.Query(query_search_client)
q.field("source_name", "cip*")
q.field("source_name", "cip")
q.and_join()
q.field("elements", "Al")
res1 = q.search(limit=10000, index="mdf")
res2 = (forge.Query(query_search_client)
.field("source_name", "cip*")
.field("source_name", "cip")
.and_join()
.field("elements", "Al")
.search(limit=10000, index="mdf"))
Expand Down Expand Up @@ -399,7 +399,7 @@ def test_forge_alt_clients():
def test_forge_match_field():
f = forge.Forge(index="mdf")
# Basic usage
f.match_field("mdf.source_name", "khazana_vasp*")
f.match_field("mdf.source_name", "khazana_vasp")
res1 = f.search()
assert check_field(res1, "mdf.source_name", "khazana_vasp") == 0
# Check that query clears
Expand All @@ -417,7 +417,8 @@ def test_forge_exclude_field():
# Basic usage
f.exclude_field("material.elements", "Al")
f.exclude_field("", "")
f.match_field("mdf.source_name", "ab_initio_solute_database*")
f.match_field("mdf.source_name", "ab_initio_solute_database")
f.match_field("mdf.resource_type", "record")
res1 = f.search()
assert check_field(res1, "material.elements", "Al") == -1

Expand Down Expand Up @@ -509,7 +510,7 @@ def test_forge_match_source_names():
def test_forge_match_ids():
# Get a couple IDs
f = forge.Forge(index="mdf")
res0 = f.search("mdf.source_name:khazana_vasp*", advanced=True, limit=2)
res0 = f.search("mdf.source_name:khazana_vasp", advanced=True, limit=2)
id1 = res0[0]["mdf"]["mdf_id"]
id2 = res0[1]["mdf"]["mdf_id"]

Expand Down Expand Up @@ -653,7 +654,7 @@ def test_forge_search(capsys):
assert len(res4) == 3

# Check reset_query
f.match_field("mdf.source_name", "ta_melting*")
f.match_field("mdf.source_name", "ta_melting")
res5 = f.search(reset_query=False)
res6 = f.search()
assert all([r in res6 for r in res5]) and all([r in res5 for r in res6])
Expand Down Expand Up @@ -701,19 +702,19 @@ def test_forge_fetch_datasets_from_results():
# Get some results
f = forge.Forge(index="mdf")
# Record from OQMD
res01 = f.search("mdf.source_name:oqmd* AND mdf.resource_type:record", advanced=True, limit=1)
res01 = f.search("mdf.source_name:oqmd AND mdf.resource_type:record", advanced=True, limit=1)
# Record from OQMD with info
res02 = f.search("mdf.source_name:oqmd* AND mdf.resource_type:record",
res02 = f.search("mdf.source_name:oqmd AND mdf.resource_type:record",
advanced=True, limit=1, info=True)
# Records from Khazana VASP
res03 = f.search("mdf.source_name:khazana_vasp* AND mdf.resource_type:record",
res03 = f.search("mdf.source_name:khazana_vasp AND mdf.resource_type:record",
advanced=True, limit=2)
# Dataset for NIST XPS DB
res04 = f.search("mdf.source_name:nist_xps_db* AND mdf.resource_type:dataset", advanced=True)
res04 = f.search("mdf.source_name:nist_xps_db AND mdf.resource_type:dataset", advanced=True)

# Get the correct dataset entries
oqmd = f.search("mdf.source_name:oqmd* AND mdf.resource_type:dataset", advanced=True)[0]
khazana_vasp = f.search("mdf.source_name:khazana_vasp* AND mdf.resource_type:dataset",
oqmd = f.search("mdf.source_name:oqmd AND mdf.resource_type:dataset", advanced=True)[0]
khazana_vasp = f.search("mdf.source_name:khazana_vasp AND mdf.resource_type:dataset",
advanced=True)[0]

# Fetch single dataset
Expand Down Expand Up @@ -749,7 +750,7 @@ def test_forge_aggregate():
# And returns results
# And respects the reset_query arg
f = forge.Forge(index="mdf")
f.match_field("mdf.source_name", "nist_xps_db*")
f.match_field("mdf.source_name", "nist_xps_db")
res1 = f.aggregate(reset_query=False, index="mdf")
assert len(res1) > 10000
assert check_field(res1, "mdf.source_name", "nist_xps_db") == 0
Expand Down Expand Up @@ -911,10 +912,10 @@ def test_forge_http_stream(capsys):

def test_forge_chaining():
f = forge.Forge(index="mdf")
f.match_field("source_name", "cip*")
f.match_field("source_name", "cip")
f.match_field("material.elements", "Al")
res1 = f.search()
res2 = f.match_field("source_name", "cip*").match_field("material.elements", "Al").search()
res2 = f.match_field("source_name", "cip").match_field("material.elements", "Al").search()
assert all([r in res2 for r in res1]) and all([r in res1 for r in res2])


Expand All @@ -929,11 +930,11 @@ def test_forge_show_fields():
def test_forge_anonymous(capsys):
f = forge.Forge(anonymous=True)
# Test search
assert len(f.search("mdf.source_name:ab_initio_solute_database*",
assert len(f.search("mdf.source_name:ab_initio_solute_database",
advanced=True, limit=300)) == 300

# Test aggregation
assert len(f.aggregate("mdf.source_name:nist_xps_db*")) > 10000
assert len(f.aggregate("mdf.source_name:nist_xps_db")) > 10000

# Error on auth-only functions
# http_download
Expand All @@ -956,7 +957,7 @@ def test_forge_anonymous(capsys):
def test_get_dataset_version():
# Get the version number of the OQMD
f = forge.Forge()
hits = f.search('mdf.source_name:oqmd_v* AND mdf.resource_type:dataset',
hits = f.search('mdf.source_name:oqmd AND mdf.resource_type:dataset',
advanced=True, limit=1)
assert hits[0]['mdf']['version'] == f.get_dataset_version('oqmd')

Expand Down

0 comments on commit b9950f8

Please sign in to comment.