Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

global: add pre-commit with ruff #83

Merged
merged 1 commit into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/push-master.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,3 @@ jobs:
needs: [python2_tests, python3_tests]
uses: ./.github/workflows/bump-and-publish.yml
secrets: inherit


17 changes: 17 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- id: fix-byte-order-marker
- id: mixed-line-ending
- id: name-tests-test
args: [ --pytest-test-first ]
exclude: '^(?!factories/)'
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.6
hooks:
- id: ruff
args: [ --fix ]
4 changes: 2 additions & 2 deletions inspire_matcher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from __future__ import absolute_import, division, print_function

from .api import match # noqa: F401
from .ext import InspireMatcher # noqa: F401
from inspire_matcher.api import match # noqa: F401
from inspire_matcher.ext import InspireMatcher # noqa: F401

__version__ = "9.0.29"
41 changes: 27 additions & 14 deletions inspire_matcher/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,24 @@
from __future__ import absolute_import, division, print_function

from flask import current_app
from six import string_types
from werkzeug.utils import import_string

from invenio_search import current_search_client as es
from invenio_search.utils import prefix_index
from six import string_types
from werkzeug.utils import import_string

from .core import compile
from inspire_matcher.core import compile


def _get_validator(validator_param):

if callable(validator_param):
return validator_param

try:
validator = import_string(validator_param)
except (KeyError, ImportError, AttributeError):
current_app.logger.debug('No validator provided. Falling back to the default validator.')
current_app.logger.debug(
'No validator provided. Falling back to the default validator.'
)
validator = import_string('inspire_matcher.validators:default_validator')

return validator
Expand All @@ -56,7 +56,9 @@ def match(record, config=None):
out which record a reference should be pointing to.
"""
if config is None:
current_app.logger.debug('No configuration provided. Falling back to the default configuration.')
current_app.logger.debug(
'No configuration provided. Falling back to the default configuration.'
)
config = current_app.config['MATCHER_DEFAULT_CONFIGURATION']

try:
Expand All @@ -72,11 +74,17 @@ def match(record, config=None):
query_config['_source'] = source
match_deleted = config.get('match_deleted', False)
collections = config.get('collections')
if not (collections is None or (
isinstance(collections, (list, tuple)) and
all(isinstance(collection, string_types) for collection in collections)
)):
raise ValueError('Malformed collections. Expected a list of strings bug got: %s' % repr(collections))
if not (
collections is None
or (
isinstance(collections, (list, tuple))
and all(isinstance(collection, string_types) for collection in collections)
)
):
raise ValueError(
'Malformed collections. Expected a list of strings bug got: %s'
% repr(collections)
)

for i, step in enumerate(algorithm):
try:
Expand All @@ -95,9 +103,14 @@ def match(record, config=None):

for j, query in enumerate(queries):
try:
body = compile(query, record, collections=collections, match_deleted=match_deleted)
body = compile(
query, record, collections=collections, match_deleted=match_deleted
)
except Exception as e:
raise ValueError('Malformed query. Query %d of step %d does not compile: %s.' % (j, i, repr(e)))
raise ValueError(
'Malformed query. Query %d of step %d does not compile: %s.'
% (j, i, repr(e))
)

if not body:
continue
Expand Down
6 changes: 5 additions & 1 deletion inspire_matcher/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@
"validator": "inspire_matcher.validators:cds_identifier_validator",
},
],
"source": ["control_number", "external_system_identifiers", "persistent_identifiers"],
"source": [
"control_number",
"external_system_identifiers",
"persistent_identifiers",
],
"doc_type": "hep",
"index": "records-hep",
}
Expand Down
103 changes: 58 additions & 45 deletions inspire_matcher/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def _compile_filters(query, collections, match_deleted):
}

if collections:
result['query']['bool']['filter']['bool']['should'] = _compile_collections(collections)
result['query']['bool']['filter']['bool']['should'] = _compile_collections(
collections
)
if not match_deleted:
result['query']['bool']['filter']['bool']['must_not'] = {
'match': {
Expand Down Expand Up @@ -85,21 +87,31 @@ def _compile_inner(query, record):


def _compile_collections(collections):
return [{
'match': {
'_collections': collection,
},
} for collection in collections]
return [
{
'match': {
'_collections': collection,
},
}
for collection in collections
]


def _compile_exact(query, record):
if 'match' in query:
query['path'] = query.get('path', query['match'])
warnings.warn('The "match" key is deprecated. Use "path" instead.', DeprecationWarning)
warnings.warn(
'The "match" key is deprecated. Use "path" instead.', DeprecationWarning,
stacklevel=1,
)

if 'search' in query:
query['search_path'] = query.get('search_path', query['search'])
warnings.warn('The "search" key is deprecated. Use "search_path" instead.', DeprecationWarning)
warnings.warn(
'The "search" key is deprecated. Use "search_path" instead.',
DeprecationWarning,
stacklevel=1,
)

path, search_path = query['path'], query['search_path']

Expand All @@ -116,11 +128,13 @@ def _compile_exact(query, record):
}

for value in values:
result['query']['bool']['should'].append({
'match': {
search_path: value,
},
})
result['query']['bool']['should'].append(
{
'match': {
search_path: value,
},
}
)

return result

Expand Down Expand Up @@ -151,21 +165,23 @@ def _compile_fuzzy(query, record):
if '.' in path:
raise ValueError('the "path" key can\'t contain dots')
# TODO: This query should be refined instead of relying on validation to filter out irrelevant results.
result['query']['dis_max']['queries'].append({
'more_like_this': {
'boost': boost,
'like': [
{
'doc': {
path: values,
result['query']['dis_max']['queries'].append(
{
'more_like_this': {
'boost': boost,
'like': [
{
'doc': {
path: values,
},
},
},
],
'max_query_terms': 25,
'min_doc_freq': 1,
'min_term_freq': 1,
},
})
],
'max_query_terms': 25,
'min_doc_freq': 1,
'min_term_freq': 1,
},
}
)

if not result['query']['dis_max']['queries']:
return
Expand Down Expand Up @@ -206,14 +222,11 @@ def _compile_nested(query, record):
if not value:
return

nested_query['query']['nested']['query']['bool']['must'].append({
'match': {
search_path: {
'query': value,
'operator': query_operator
}
},
})
nested_query['query']['nested']['query']['bool']['must'].append(
{
'match': {search_path: {'query': value, 'operator': query_operator}},
}
)
if "inner_hits" in query:
nested_query['query']['nested']['inner_hits'] = query['inner_hits']

Expand All @@ -228,17 +241,17 @@ def _compile_nested_prefix(query, record):
if not value:
return
if prefix_field and prefix_field in search_path:
nested_query['query']['nested']['query']['bool']['must'].append({
'match_phrase_prefix': {
search_path: value
}
})
nested_query['query']['nested']['query']['bool']['must'].append(
{'match_phrase_prefix': {search_path: value}}
)
else:
nested_query['query']['nested']['query']['bool']['must'].append({
'match': {
search_path: value,
},
})
nested_query['query']['nested']['query']['bool']['must'].append(
{
'match': {
search_path: value,
},
}
)

if "inner_hits" in query:
nested_query['query']['nested']['inner_hits'] = query['inner_hits']
Expand Down
2 changes: 1 addition & 1 deletion inspire_matcher/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from __future__ import absolute_import, division, print_function

from . import config
from inspire_matcher import config


class InspireMatcher(object):
Expand Down
6 changes: 4 additions & 2 deletions inspire_matcher/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
# This file is part of INSPIRE.
# Copyright (C) 2014-2017 CERN.
#
# INSPIRE is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by
# INSPIRE is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
Expand Down Expand Up @@ -84,7 +85,8 @@ def compute_jaccard_index(x_set, y_set):
def get_tokenized_title(title):
"""Return the tokenised title.

The title is lowercased and split on the spaces. Then, duplicate tokens are removed by adding the tokens to a set.
The title is lowercased and split on the spaces. Then, duplicate
tokens are removed by adding the tokens to a set.

Args:
title (string): a title.
Expand Down
43 changes: 29 additions & 14 deletions inspire_matcher/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@

from inspire_utils.record import get_value

from .utils import (
compute_author_match_score,
compute_title_score,
)
from inspire_matcher.utils import compute_author_match_score, compute_title_score


def default_validator(record, result):
Expand All @@ -41,14 +38,18 @@ def default_validator(record, result):
def authors_titles_validator(record, result):
"""Compute a validation score for the possible match.

The score is based on a similarity score of the authors sets and the maximum Jaccard index found between 2 titles:
The score is based on a similarity score of the authors sets and
the maximum Jaccard index found between 2 titles:
one from the record and one from the result title sets.

If the computed score is higher than 0.5, then the match is valid, otherwise it is not.
If the computed score is higher than 0.5, then the match is valid,
otherwise it is not.

Args:
record (dict): the given record we are trying to match with similar ones in INSPIRE.
result (dict): possible match returned by the ES query that needs to be validated.
record (dict): the given record we are trying to match
with similar ones in INSPIRE.
result (dict): possible match returned by the ES query
that needs to be validated.

Returns:
bool: validation decision.
Expand All @@ -63,7 +64,9 @@ def authors_titles_validator(record, result):
result_titles = get_value(result, '_source.titles.title', [])

title_score = max(
compute_title_score(record_title, result_title, threshold=0.5, math_threshold=0.3)
compute_title_score(
record_title, result_title, threshold=0.5, math_threshold=0.3
)
for (record_title, result_title) in product(record_titles, result_titles)
)

Expand All @@ -79,19 +82,31 @@ def cds_identifier_validator(record, result):
``schema`` different from CDS.

Args:
record (dict): the given record we are trying to match with similar ones in INSPIRE.
result (dict): possible match returned by the ES query that needs to be validated.
record (dict): the given record we are trying to match with
similar ones in INSPIRE.
result (dict): possible match returned by the ES query
that needs to be validated.

Returns:
bool: validation decision.

"""

record_external_identifiers = get_value(record, 'external_system_identifiers', [])
result_external_identifiers = get_value(result, '_source.external_system_identifiers', [])
result_external_identifiers = get_value(
result, '_source.external_system_identifiers', []
)

record_external_identifiers = {external_id["value"] for external_id in record_external_identifiers if external_id["schema"] == 'CDS'}
result_external_identifiers = {external_id["value"] for external_id in result_external_identifiers if external_id["schema"] == 'CDS'}
record_external_identifiers = {
external_id["value"]
for external_id in record_external_identifiers
if external_id["schema"] == 'CDS'
}
result_external_identifiers = {
external_id["value"]
for external_id in result_external_identifiers
if external_id["schema"] == 'CDS'
}

return bool(record_external_identifiers & result_external_identifiers)

Expand Down
Loading
Loading